All downloads are free. Search and download functionality uses the official Maven repository.

opennlp.perceptron.PerceptronTrainer Maven / Gradle / Ivy

There is a newer version: 3.0.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package opennlp.perceptron;

import opennlp.model.AbstractModel;
import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
import opennlp.model.MutableContext;

/**
 * Trains models using the perceptron algorithm.  Each outcome is represented as
 * a binary perceptron classifier.  This supports standard (integer) weighting as well
 * as average weighting, as described in:
 * Discriminative Training Methods for Hidden Markov Models: Theory and Experiments
 * with the Perceptron Algorithm. Michael Collins, EMNLP 2002.
 *
 */
public class PerceptronTrainer {

  /** Number of unique events which occurred in the event set. */
  private int numUniqueEvents;
  /** Number of events in the event set. */
  private int numEvents;

  /** Number of predicates. */
  private int numPreds;
  /** Number of outcomes. */
  private int numOutcomes;
  /** Records the array of predicates seen in each event. */
  private int[][] contexts;

  /** The value associated with each context. If null then context values are assumed to be 1. */
  private float[][] values;

  /** List of outcomes for each event i, in context[i]. */
  private int[] outcomeList;

  /** Records the number of times an event has been seen for each event i, in context[i]. */
  private int[] numTimesEventsSeen;

  /**
   * Stores the String names of the outcomes. The trainer only tracks outcomes
   * as ints, so this array is needed to save the model to disk and thereby
   * allow users to know what the outcome was in human understandable terms.
   */
  private String[] outcomeLabels;

  /**
   * Stores the String names of the predicates. The trainer only tracks
   * predicates as ints, so this array is needed to save the model to disk and
   * thereby allow users to know what the predicate was in human understandable terms.
   */
  private String[] predLabels;

  /** Stores the estimated parameter value of each predicate during iteration. */
  private MutableContext[] params;

  /**
   * Bookkeeping table shaped [numPreds][numOutcomes][3]; allocated only when
   * averaging is requested. The last dimension is indexed by the {@link #VALUE},
   * {@link #ITER} and {@link #EVENT} slots (presumably the last parameter value and
   * the iteration/event at which it was last updated — TODO confirm).
   */
  private int[][][] updates;
  /** Slot index into the last dimension of {@link #updates}. Constant, hence static final. */
  private static final int VALUE = 0;
  /** Slot index into the last dimension of {@link #updates}. Constant, hence static final. */
  private static final int ITER = 1;
  /** Slot index into the last dimension of {@link #updates}. Constant, hence static final. */
  private static final int EVENT = 2;

  /** Stores the average parameter values of each predicate during iteration (averaging only). */
  private MutableContext[] averageParams;

  /** Evaluation view built over {@link #params} for the current number of outcomes. */
  private EvalParameters evalParams;

  /** Presumably controls whether progress messages are printed — TODO confirm against display(). */
  private boolean printMessages = true;

  /** Scratch array of per-outcome scores, presumably sized numOutcomes — TODO confirm. */
  double[] modelDistribution;

  /** Iteration count passed to the most recent trainModel call. */
  private int iterations;
  /** Whether the most recent trainModel call requested averaged parameters. */
  private boolean useAverage;
  
  /**
   * Trains a perceptron model with parameter averaging enabled.
   *
   * <p>Equivalent to {@code trainModel(iterations, di, cutoff, true)}.
   *
   * @param iterations number of iterations, forwarded to the four-argument overload
   * @param di the indexed training data
   * @param cutoff forwarded unchanged to the four-argument overload
   * @return the model returned by the four-argument overload
   */
  public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff) {
    // The four-argument overload stores 'iterations' in the field itself, so
    // assigning it here as well would be redundant.
    return trainModel(iterations, di, cutoff, true);
  }
  
  /**
   * Trains a perceptron model from indexed training data.
   *
   * <p>The visible portion copies the event data out of {@code di}, logs summary
   * counts, and allocates one zero-initialized {@link MutableContext} per predicate.
   * When {@code useAverage} is true, it also allocates a parallel array of averaged
   * parameters and the {@code updates} bookkeeping table.
   *
   * <p>NOTE(review): this method is damaged in this copy of the file. Starting at the
   * {@code for (int aoi=0;aoi ...} line, text that sat between angle brackets appears
   * to have been stripped during extraction. As a result, the rest of this method
   * (including its return statement) and the start of at least one following method
   * are missing. The trailing lines compute a training accuracy and return a
   * {@code double}, so they cannot belong to this {@code AbstractModel}-returning
   * method. Restore from the upstream source before compiling.
   *
   * @param iterations number of iterations; stored in the {@code iterations} field
   * @param di indexed training data (contexts, values, outcome/predicate labels, event counts)
   * @param cutoff not referenced in the visible portion of this method — TODO confirm its use
   * @param useAverage when true, averaged parameters and update bookkeeping are allocated
   * @return the trained model (the return statement is not visible in this copy)
   */
  public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff, boolean useAverage) {
    display("Incorporating indexed data for training...  \n");
    this.useAverage = useAverage;
    contexts = di.getContexts();
    values = di.getValues();
    numTimesEventsSeen = di.getNumTimesEventsSeen();
    numEvents = di.getNumEvents();
    // One entry per distinct context; numEvents above is the indexer's total event count.
    numUniqueEvents = contexts.length;

    this.iterations = iterations;
    outcomeLabels = di.getOutcomeLabels();
    outcomeList = di.getOutcomeList();

    predLabels = di.getPredLabels();
    numPreds = predLabels.length;
    numOutcomes = outcomeLabels.length;
    // Last dimension has 3 slots, presumably indexed by VALUE / ITER / EVENT (0..2) — confirm.
    if (useAverage) updates = new int[numPreds][numOutcomes][3];
    
    display("done.\n");
    
    display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
    display("\t    Number of Outcomes: " + numOutcomes + "\n");
    display("\t  Number of Predicates: " + numPreds + "\n");
    

    params = new MutableContext[numPreds];
    if (useAverage) averageParams = new MutableContext[numPreds];
    // evalParams wraps the params array itself, so later writes to params[pi] are visible through it.
    evalParams = new EvalParameters(params,numOutcomes);
    
    // Every predicate's context covers all outcomes: pattern is [0, 1, ..., numOutcomes-1].
    int[] allOutcomesPattern= new int[numOutcomes];
    for (int oi = 0; oi < numOutcomes; oi++) {
      allOutcomesPattern[oi] = oi;
    }
    
    for (int pi = 0; pi < numPreds; pi++) {
      // Java zero-initializes the new double[] so every parameter starts at 0.0.
      params[pi] = new MutableContext(allOutcomesPattern,new double[numOutcomes]);
      // NOTE(review): this 'if' has no braces, so only the next statement is conditional;
      // the indentation of the loop that follows suggests the author may have intended otherwise.
      if (useAverage) 
        averageParams[pi] = new MutableContext(allOutcomesPattern,new double[numOutcomes]);
        // NOTE(review): garbled from here on by extraction (see method Javadoc); the
        // loop headers below have lost their bounds and bodies.
        for (int aoi=0;aoi modelDistribution[max]) 
            max = oi;

        for (int oi = 0;oi modelDistribution[max])
            max = oi;
        if (max == outcomeList[oei])
          numCorrect++;
      }
    }
    // Accuracy is normalized by numEvents (total events), not numUniqueEvents.
    double trainingAccuracy = (double) numCorrect / numEvents;
    display(". (" + numCorrect + "/" + numEvents+") " + trainingAccuracy + "\n");
    return trainingAccuracy;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy