opennlp.perceptron.SimplePerceptronSequenceTrainer

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package opennlp.perceptron;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import opennlp.model.AbstractModel;
import opennlp.model.DataIndexer;
import opennlp.model.Event;
import opennlp.model.IndexHashTable;
import opennlp.model.MutableContext;
import opennlp.model.OnePassDataIndexer;
import opennlp.model.Sequence;
import opennlp.model.SequenceStream;
import opennlp.model.SequenceStreamEventStream;

/**
 * Trains models for sequences using the perceptron algorithm.  Each outcome is represented as
 * a binary perceptron classifier.  This supports standard (integer) weighting as well
 * as averaged weighting.  Sequence information is used in a simplified way relative to that
 * described in: Discriminative Training Methods for Hidden Markov Models: Theory and Experiments
 * with the Perceptron Algorithm. Michael Collins, EMNLP 2002.
 * Specifically, updates are only applied to tokens which were incorrectly tagged by the sequence
 * tagger, rather than to all features across the sequence which differ from the training sequence.
 */
public class SimplePerceptronSequenceTrainer {

  private boolean printMessages = true;
  private int iterations;
  private SequenceStream sequenceStream;
  /** Number of events in the event set. */
  private int numEvents;

  /** Number of predicates. */
  private int numPreds; 
  private int numOutcomes;

  /** List of outcomes for each event i, in context[i]. */
  private int[] outcomeList;
  
  private String[] outcomeLabels;

  double[] modelDistribution;

  /** Stores the average parameter values of each predicate during iteration. */
  private MutableContext[] averageParams;
  
  /** Mapping between context and an integer */ 
  private IndexHashTable<String> pmap;

  private Map<String,Integer> omap;
  
  /** Stores the estimated parameter value of each predicate during iteration. */
  private MutableContext[] params;
  private boolean useAverage;
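  /** Bookkeeping for parameter averaging: for each predicate and outcome, the parameter value at
   *  its last update (VALUE) and the iteration (ITER) and sequence index (EVENT) of that update. */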
  private int[][][] updates;
  private int VALUE = 0;
  private int ITER = 1;
  private int EVENT = 2;
  
  private int[] allOutcomesPattern;
  private String[] predLabels;
  int numSequences;

  public AbstractModel trainModel(int iterations, SequenceStream sequenceStream, int cutoff, boolean useAverage) throws IOException {
    this.iterations = iterations;
    this.sequenceStream = sequenceStream;
    DataIndexer di = new OnePassDataIndexer(new SequenceStreamEventStream(sequenceStream),cutoff,false);
    numSequences = 0;
    for (Sequence s : sequenceStream) {
      numSequences++;
    }
    outcomeList  = di.getOutcomeList();
    predLabels = di.getPredLabels();
    pmap = new IndexHashTable<String>(predLabels, 0.7d);
      
    display("Incorporating indexed data for training...  \n");
    this.useAverage = useAverage;
    numEvents = di.getNumEvents();

    this.iterations = iterations;
    outcomeLabels = di.getOutcomeLabels();
    omap = new HashMap<String,Integer>();
    for (int oli=0;oli<outcomeLabels.length;oli++) {
      omap.put(outcomeLabels[oli], oli);
    }

    numPreds = predLabels.length;
    numOutcomes = outcomeLabels.length;
    if (useAverage) {
      updates = new int[numPreds][numOutcomes][3];
    }

    display("done.\n");

    display("\tNumber of Event Tokens: " + numEvents + "\n");
    display("\t    Number of Outcomes: " + numOutcomes + "\n");
    display("\t  Number of Predicates: " + numPreds + "\n");

    params = new MutableContext[numPreds];
    if (useAverage) averageParams = new MutableContext[numPreds];

    allOutcomesPattern = new int[numOutcomes];
    for (int oi = 0; oi < numOutcomes; oi++) {
      allOutcomesPattern[oi] = oi;
    }

    for (int pi = 0; pi < numPreds; pi++) {
      params[pi] = new MutableContext(allOutcomesPattern,new double[numOutcomes]);
      if (useAverage) averageParams[pi] = new MutableContext(allOutcomesPattern,new double[numOutcomes]);
      for (int aoi=0;aoi<numOutcomes;aoi++) {
        params[pi].setParameter(aoi, 0.0);
        if (useAverage) averageParams[pi].setParameter(aoi, 0.0);
      }
    }
    modelDistribution = new double[numOutcomes];

    display("Computing model parameters...\n");
    findParameters(iterations);
    display("...done.\n");

    // Return a model built from the averaged parameters if averaging was requested.
    if (useAverage) {
      return new PerceptronModel(averageParams, predLabels, pmap, outcomeLabels);
    }
    else {
      return new PerceptronModel(params, predLabels, pmap, outcomeLabels);
    }
  }

  private void findParameters(int iterations) {
    display("Performing " + iterations + " iterations.\n");
    for (int i = 1; i <= iterations; i++) {
      if (i < 10)
        display("  " + i + ":  ");
      else if (i < 100)
        display(" " + i + ":  ");
      else
        display(i + ":  ");
      nextIteration(i);
    }
    if (useAverage) {
      trainingStats(averageParams);
    }
    else {
      trainingStats(params);
    }
  }

  private void display(String s) {
    if (printMessages)
      System.out.print(s);
  }

  public void nextIteration(int iteration) {
    iteration--; //move to 0-based index
    int numCorrect = 0;
    int oei = 0;
    int si = 0;
    Map<String,Float>[] featureCounts = new Map[numOutcomes];
    for (int oi=0;oi<numOutcomes;oi++) {
      featureCounts[oi] = new HashMap<String,Float>();
    }
    PerceptronModel model = new PerceptronModel(params,predLabels,pmap,outcomeLabels);
    for (Sequence sequence : sequenceStream) {
      Event[] taggerEvents = sequenceStream.updateContext(sequence, model);
      Event[] events = sequence.getEvents();
      boolean update = false;
      for (int ei=0;ei<events.length;ei++,oei++) {
        if (!taggerEvents[ei].getOutcome().equals(events[ei].getOutcome())) {
          update = true;
        }
        else {
          numCorrect++;
        }
      }
      if (update) {
        //clear the feature counts accumulated for the previous sequence
        for (int oi=0;oi<numOutcomes;oi++) {
          featureCounts[oi].clear();
        }
        //training feature count computation
        for (int ei=0;ei<events.length;ei++) {
          String[] contextStrings = events[ei].getContext();
          float values[] = events[ei].getValues();
          int oi = omap.get(events[ei].getOutcome());
          for (int ci=0;ci<contextStrings.length;ci++) {
            float value = 1;
            if (values != null) {
              value = values[ci];
            }
            Float c = featureCounts[oi].get(contextStrings[ci]);
            if (c == null) {
              c = value;
            }
            else {
              c += value;
            }
            featureCounts[oi].put(contextStrings[ci], c);
          }
        }
        //evaluation feature count computation
        for (int ei=0;ei<taggerEvents.length;ei++) {
          String[] contextStrings = taggerEvents[ei].getContext();
          float values[] = taggerEvents[ei].getValues();
          int oi = omap.get(taggerEvents[ei].getOutcome());
          for (int ci=0;ci<contextStrings.length;ci++) {
            float value = 1;
            if (values != null) {
              value = values[ci];
            }
            Float c = featureCounts[oi].get(contextStrings[ci]);
            if (c == null) {
              c = -1*value;
            }
            else {
              c -= value;
            }
            if (c == 0f) {
              featureCounts[oi].remove(contextStrings[ci]);
            }
            else {
              featureCounts[oi].put(contextStrings[ci], c);
            }
          }
        }
        //apply the difference between training and predicted feature counts to the parameters
        for (int oi=0;oi<numOutcomes;oi++) {
          for (String feature : featureCounts[oi].keySet()) {
            int pi = pmap.get(feature);
            if (pi != -1) {
              params[pi].updateParameter(oi, featureCounts[oi].get(feature));
              if (useAverage) {
                if (updates[pi][oi][VALUE] != 0) {
                  averageParams[pi].updateParameter(oi,updates[pi][oi][VALUE]*(numSequences*(iteration-updates[pi][oi][ITER])+(si-updates[pi][oi][EVENT])));
                }
                updates[pi][oi][VALUE] = (int) params[pi].getParameters()[oi];
                updates[pi][oi][ITER] = iteration;
                updates[pi][oi][EVENT] = si;
              }
            }
          }
        }
        model = new PerceptronModel(params,predLabels,pmap,outcomeLabels);
      }
      si++;
    }
    //finish average computation
    double totIterations = (double) iterations*si;
    if (useAverage && iteration == iterations-1) {
      for (int pi = 0; pi < numPreds; pi++) {
        double[] predParams = averageParams[pi].getParameters();
        for (int oi = 0;oi<numOutcomes;oi++) {
          if (updates[pi][oi][VALUE] != 0) {
            predParams[oi] += updates[pi][oi][VALUE]*(numSequences*(iterations-updates[pi][oi][ITER])-updates[pi][oi][EVENT]);
          }
          if (predParams[oi] != 0) {
            predParams[oi] /= totIterations;
            averageParams[pi].setParameter(oi, predParams[oi]);
          }
        }
      }
    }
    display(". ("+numCorrect+"/"+numEvents+") "+((double) numCorrect / numEvents) + "\n");
  }
  
  private void trainingStats(MutableContext[] params) {
    int numCorrect = 0;
    int oei=0;
    for (Sequence sequence : sequenceStream) {
      Event[] taggerEvents = sequenceStream.updateContext(sequence, new PerceptronModel(params,predLabels,pmap,outcomeLabels));
      for (int ei=0;ei<taggerEvents.length;ei++,oei++) {
        int max = omap.get(taggerEvents[ei].getOutcome());
        if (max == outcomeList[oei]) {
          numCorrect ++;
        }
      }
    }
    display(". ("+numCorrect+"/"+numEvents+") "+((double) numCorrect / numEvents) + "\n");
  }
}
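A minimal usage sketch (not part of the original source file): given any SequenceStream over the training data, the trainer above is invoked as shown below. The class name TrainSequenceModel and the iteration/cutoff settings are illustrative assumptions, not part of the library.

import java.io.IOException;

import opennlp.model.AbstractModel;
import opennlp.model.SequenceStream;
import opennlp.perceptron.SimplePerceptronSequenceTrainer;

public class TrainSequenceModel {

  /**
   * Trains a perceptron sequence model from any SequenceStream implementation,
   * e.g. the stream a POS tagger or chunker builds over its training sentences.
   */
  public static AbstractModel train(SequenceStream sequences) throws IOException {
    // 100 iterations and a feature cutoff of 5 are illustrative settings;
    // the final argument requests averaged parameter weights.
    return new SimplePerceptronSequenceTrainer().trainModel(100, sequences, 5, true);
  }
}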