opennlp.perceptron.PerceptronTrainer

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package opennlp.perceptron;

import opennlp.model.AbstractModel;
import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
import opennlp.model.MutableContext;

/**
 * Trains models using the perceptron algorithm.  Each outcome is represented as
 * a binary perceptron classifier.  This supports standard (integer) weighting as well
 * as averaged weighting, as described in:
 * Discriminative Training Methods for Hidden Markov Models: Theory and Experiments
 * with Perceptron Algorithms. Michael Collins, EMNLP 2002.
 *
 */
public class PerceptronTrainer {

  public static final double TOLERANCE_DEFAULT = .00001;
  
  /** Number of unique events which occurred in the event set. */
  private int numUniqueEvents;
  /** Number of events in the event set. */
  private int numEvents;
  
  /** Number of predicates. */
  private int numPreds; 
  /** Number of outcomes. */
  private int numOutcomes; 
  /** Records the array of predicates seen in each event. */
  private int[][] contexts;

  /** The value associated with each context. If null then context values are assumed to be 1. */
  private float[][] values;

  /** List of outcomes for each event i, in context[i]. */
  private int[] outcomeList;

  /** Records the number of times an event has been seen for each event i, in contexts[i]. */
  private int[] numTimesEventsSeen;
  
  /** Stores the String names of the outcomes.  The trainer only tracks outcomes
  as ints, and so this array is needed to save the model to disk and
  thereby allow users to know what the outcomes were in human
  understandable terms. */
  private String[] outcomeLabels;

  /** Stores the String names of the predicates. The trainer only tracks
  predicates as ints, and so this array is needed to save the model to
  disk and thereby allow users to know what the predicates were in human
  understandable terms. */
  private String[] predLabels;

  private boolean printMessages = true;
  
  private double tolerance = TOLERANCE_DEFAULT;
  
  private Double stepSizeDecrease;
  
  private boolean useSkippedAveraging;
  
  /**
   * Specifies the tolerance. If the change in training set accuracy
   * is less than this, stop iterating.
   * 
   * @param tolerance the tolerance value; must not be negative
   */
  public void setTolerance(double tolerance) {
    
    if (tolerance < 0) {
      throw new
          IllegalArgumentException("tolerance must be a non-negative number but is " + tolerance + "!");
    }
    
    this.tolerance = tolerance;
  }

  /**
   * Enables and sets step size decrease. The step size is
   * decreased every iteration by the specified value.
   * 
   * @param decrease step size decrease in percent (between 0 and 100)
   */
  public void setStepSizeDecrease(double decrease) {
    
    if (decrease < 0 || decrease > 100) {
      throw new
          IllegalArgumentException("decrease must be between 0 and 100 but is " + decrease + "!");
    }
    
    stepSizeDecrease = decrease;
  }
  
  /**
   * Enables skipped averaging; this flag changes the standard
   * averaging to special averaging instead.
   * <p>
   * If we are doing averaging, and the current iteration is one
   * of the first 20 or it is a perfect square, then update the
   * summed parameters.
   * <p>
   * The reason we don't take all of them is that the parameters change
   * less toward the end of training, so they drown out the contributions
   * of the more volatile early iterations. The use of perfect
   * squares allows us to sample from successively farther apart iterations.
   * For example, over 100 iterations the summed parameters come from
   * iterations 1-19 and then 25, 36, 49, 64, 81, and 100.
   *
   * @param averaging whether to use skipped averaging
   */
  public void setSkippedAveraging(boolean averaging) {
    useSkippedAveraging = averaging;
  }

  public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff) {
    return trainModel(iterations, di, cutoff, true);
  }

  public AbstractModel trainModel(int iterations, DataIndexer di, int cutoff, boolean useAverage) {
    display("Incorporating indexed data for training... \n");
    contexts = di.getContexts();
    values = di.getValues();
    numTimesEventsSeen = di.getNumTimesEventsSeen();
    numEvents = di.getNumEvents();
    numUniqueEvents = contexts.length;

    outcomeLabels = di.getOutcomeLabels();
    outcomeList = di.getOutcomeList();

    predLabels = di.getPredLabels();
    numPreds = predLabels.length;
    numOutcomes = outcomeLabels.length;

    display("done.\n");

    display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
    display("\t    Number of Outcomes: " + numOutcomes + "\n");
    display("\t  Number of Predicates: " + numPreds + "\n");

    display("Computing model parameters...\n");

    MutableContext[] finalParameters = findParameters(iterations, useAverage);

    display("...done.\n");

    /*************** Create and return the model ******************/
    return new PerceptronModel(finalParameters, predLabels, outcomeLabels);
  }

  private MutableContext[] findParameters (int iterations, boolean useAverage) {

    display("Performing " + iterations + " iterations.\n");

    int[] allOutcomesPattern = new int[numOutcomes];
    for (int oi = 0; oi < numOutcomes; oi++)
      allOutcomesPattern[oi] = oi;

    /** Stores the estimated parameter value of each predicate during iteration. */
    MutableContext[] params = new MutableContext[numPreds];
    for (int pi = 0; pi < numPreds; pi++) {
      params[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
      for (int aoi = 0; aoi < numOutcomes; aoi++)
        params[pi].setParameter(aoi, 0.0);
    }

    EvalParameters evalParams = new EvalParameters(params, numOutcomes);

    /** Stores the sum of the parameter values of each predicate over the summed iterations. */
    MutableContext[] summedParams = new MutableContext[numPreds];
    if (useAverage) {
      for (int pi = 0; pi < numPreds; pi++) {
        summedParams[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
        for (int aoi = 0; aoi < numOutcomes; aoi++)
          summedParams[pi].setParameter(aoi, 0.0);
      }
    }

    // Keep track of the previous three training set accuracies. Their
    // difference from the current accuracy is compared against the
    // tolerance to decide whether to stop.
    double prevAccuracy1 = 0.0;
    double prevAccuracy2 = 0.0;
    double prevAccuracy3 = 0.0;

    // Denominator for parameter averaging.
    int numTimesSummed = 0;

    double stepsize = 1;
    for (int i = 1; i <= iterations; i++) {

      // Decrease the step size by the configured percentage, if any.
      if (stepSizeDecrease != null)
        stepsize *= 1 - stepSizeDecrease / 100;

      displayIteration(i);

      int numCorrect = 0;

      for (int ei = 0; ei < numUniqueEvents; ei++) {
        int targetOutcome = outcomeList[ei];

        for (int ni = 0; ni < numTimesEventsSeen[ei]; ni++) {

          // Compute the model's prediction for this event with the current parameters.
          double[] modelDistribution = new double[numOutcomes];
          if (values != null)
            PerceptronModel.eval(contexts[ei], values[ei], modelDistribution, evalParams, false);
          else
            PerceptronModel.eval(contexts[ei], null, modelDistribution, evalParams, false);

          int maxOutcome = maxIndex(modelDistribution);

          // Standard perceptron update: if the prediction is wrong, boost the
          // parameters of the target outcome and penalize those of the predicted one.
          if (maxOutcome != targetOutcome) {
            for (int ci = 0; ci < contexts[ei].length; ci++) {
              int pi = contexts[ei][ci];
              if (values == null) {
                params[pi].updateParameter(targetOutcome, stepsize);
                params[pi].updateParameter(maxOutcome, -stepsize);
              } else {
                params[pi].updateParameter(targetOutcome, stepsize * values[ei][ci]);
                params[pi].updateParameter(maxOutcome, -stepsize * values[ei][ci]);
              }
            }
          } else {
            numCorrect++;
          }
        }
      }

      // Report the training set accuracy for this iteration.
      double trainingAccuracy = (double) numCorrect / numEvents;
      if (i < 10 || (i % 10) == 0)
        display("(" + numCorrect + "/" + numEvents + ") " + trainingAccuracy + "\n");

      // Sum the parameters, either every iteration (standard averaging) or only
      // for the first 20 iterations and perfect squares (skipped averaging).
      if (useAverage && (!useSkippedAveraging || i < 20 || isPerfectSquare(i))) {
        numTimesSummed++;
        for (int pi = 0; pi < numPreds; pi++)
          for (int aoi = 0; aoi < numOutcomes; aoi++)
            summedParams[pi].updateParameter(aoi, params[pi].getParameters()[aoi]);
      }

      // Stop if the training set accuracy has changed by less than the
      // tolerance relative to each of the previous three iterations.
      if (Math.abs(prevAccuracy1 - trainingAccuracy) < tolerance
          && Math.abs(prevAccuracy2 - trainingAccuracy) < tolerance
          && Math.abs(prevAccuracy3 - trainingAccuracy) < tolerance) {
        display("Stopping: change in training set accuracy less than " + tolerance + "\n");
        break;
      }

      prevAccuracy1 = prevAccuracy2;
      prevAccuracy2 = prevAccuracy3;
      prevAccuracy3 = trainingAccuracy;
    }

    // Report the final training set accuracy.
    trainingStats(evalParams);

    if (useAverage) {
      // Divide the summed parameters by the number of summed iterations.
      for (int pi = 0; pi < numPreds; pi++)
        for (int aoi = 0; aoi < numOutcomes; aoi++)
          summedParams[pi].setParameter(aoi, summedParams[pi].getParameters()[aoi] / numTimesSummed);
      return summedParams;
    }
    else {
      return params;
    }
  }

  private double trainingStats (EvalParameters evalParams) {
    int numCorrect = 0;

    for (int ei = 0; ei < numUniqueEvents; ei++) {
      for (int ni = 0; ni < numTimesEventsSeen[ei]; ni++) {

        double[] modelDistribution = new double[numOutcomes];
        if (values != null)
          PerceptronModel.eval(contexts[ei], values[ei], modelDistribution, evalParams, false);
        else
          PerceptronModel.eval(contexts[ei], null, modelDistribution, evalParams, false);

        int maxOutcome = maxIndex(modelDistribution);
        if (maxOutcome == outcomeList[ei])
          numCorrect++;
      }
    }

    double trainingAccuracy = (double) numCorrect / numEvents;
    display("Stats: (" + numCorrect + "/" + numEvents + ") " + trainingAccuracy + "\n");
    return trainingAccuracy;
  }

  private static int maxIndex (double[] values) {
    int max = 0;
    for (int i = 1; i < values.length; i++)
      if (values[i] > values[max])
        max = i;
    return max;
  }

  private void display (String s) {
    if (printMessages)
      System.out.print(s);
  }

  private void displayIteration (int i) {
    if (i > 10 && (i % 10) != 0)
      return;

    if (i < 10)
      display("  " + i + ":  ");
    else if (i < 100)
      display(" " + i + ":  ");
    else
      display(i + ":  ");
  }

  // See whether a number is a perfect square. Inefficient, but fine
  // for our purposes.
  private final static boolean isPerfectSquare (int n) {
    int root = (int) Math.sqrt(n);
    return root * root == n;
  }
}
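
Below is a minimal usage sketch, assuming an already populated opennlp.model.DataIndexer is available for the training events (how the indexer is built is not shown, and the example class name is illustrative); only PerceptronTrainer's own methods are exercised.

import opennlp.model.AbstractModel;
import opennlp.model.DataIndexer;
import opennlp.perceptron.PerceptronTrainer;

public class PerceptronTrainingExample {

  // 'indexer' is assumed to be an already populated DataIndexer implementation.
  public static AbstractModel train(DataIndexer indexer) {
    PerceptronTrainer trainer = new PerceptronTrainer();

    // Stop early once the change in training set accuracy falls below the tolerance.
    trainer.setTolerance(PerceptronTrainer.TOLERANCE_DEFAULT);

    // Sum parameters only for iterations 1-19 and perfect-square iterations.
    trainer.setSkippedAveraging(true);

    // Train for up to 100 iterations with a cutoff of 0; the cutoff argument is
    // not used inside trainModel itself (cutoffs are typically applied when the
    // data is indexed).
    return trainer.trainModel(100, indexer, 0);
  }
}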