opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml.perceptron;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.tools.ml.AbstractEventModelSequenceTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Sequence;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.ml.model.SequenceStreamEventStream;
/**
* Trains models for sequences using the perceptron algorithm. Each outcome is represented as
* a binary perceptron classifier. This supports standard (integer) weighting as well as
* average weighting. Sequence information is used in a simplified way compared to that described in:
* Discriminative Training Methods for Hidden Markov Models: Theory and Experiments
* with the Perceptron Algorithm. Michael Collins, EMNLP 2002.
* Specifically, updates are only applied to tokens which were incorrectly tagged by a sequence tagger
* rather than to all features across the sequence which differ from the training sequence.
*/
public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceTrainer {
/** Algorithm name that {@code validate()} accepts for this trainer. */
public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";

/** Number of training iterations passed to the most recent {@code trainModel} call. */
private int iterations;

/** Sequence stream passed to the most recent {@code trainModel} call. */
private SequenceStream sequenceStream;

/** Number of events in the event set. */
private int numEvents;

/** Number of predicates. */
private int numPreds;

/** Number of outcomes (length of {@link #outcomeLabels}). */
private int numOutcomes;

/** List of outcomes for each event i, in context[i]. */
private int[] outcomeList;

/** Outcome labels as reported by the data indexer. */
private String[] outcomeLabels;

/** Stores the average parameter values of each predicate during iteration. */
private MutableContext[] averageParams;

/** Maps each predicate label to its index in {@link #predLabels}. */
private Map<String, Integer> pmap;

/** Maps each outcome label to its index in {@link #outcomeLabels}. */
private Map<String, Integer> omap;

/** Stores the estimated parameter value of each predicate during iteration. */
private MutableContext[] params;

/** Whether averaged parameters are maintained and returned by training. */
private boolean useAverage;

/**
 * Bookkeeping array sized {@code [numPreds][numOutcomes][3]}; only allocated when
 * {@link #useAverage} is set. The last dimension is addressed with {@link #VALUE},
 * {@link #ITER} and {@link #EVENT}.
 */
private int[][][] updates;

/** Slot 0 in the last dimension of {@link #updates}. */
private static final int VALUE = 0;

/** Slot 1 in the last dimension of {@link #updates}. */
private static final int ITER = 1;

/** Slot 2 in the last dimension of {@link #updates}. */
private static final int EVENT = 2;

/** Predicate labels as reported by the data indexer. */
private String[] predLabels;

/** Number of sequences counted by reading the sequence stream once. */
private int numSequences;
/**
 * Creates a trainer without any training state; the fields are populated later by
 * {@code trainModel}.
 */
public SimplePerceptronSequenceTrainer() {
  super();
}
/**
 * Runs the superclass validation and then checks the configured algorithm name.
 * A missing ({@code null}) algorithm name is accepted.
 *
 * @throws IllegalArgumentException if an algorithm name is configured and it is not
 *         {@link #PERCEPTRON_SEQUENCE_VALUE}
 */
@Override
public void validate() {
  super.validate();

  final String configured = getAlgorithm();
  final boolean mismatch =
      configured != null && !PERCEPTRON_SEQUENCE_VALUE.equals(configured);
  if (mismatch) {
    throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE");
  }
}
/**
 * Boolean wrapper around {@link #validate()}.
 *
 * @return {@code true} if {@link #validate()} completes, {@code false} if it throws an
 *         {@link IllegalArgumentException}
 */
@Deprecated
@Override
public boolean isValid() {
  boolean valid = true;
  try {
    validate();
  } catch (IllegalArgumentException invalidConfiguration) {
    // A validation failure is reported through the return value, not rethrown.
    valid = false;
  }
  return valid;
}
/**
 * Trains a model from the given sequences using the configured iteration count, cutoff
 * and the {@code "UseAverage"} training parameter (default {@code true}).
 *
 * @param events the training sequences
 * @return the model produced by {@code trainModel}
 * @throws IOException if {@code trainModel} fails with an I/O error
 */
public AbstractModel doTrain(SequenceStream events) throws IOException {
  // Arguments are evaluated left to right: iterations, cutoff, then the averaging flag.
  return trainModel(
      getIterations(),
      events,
      getCutoff(),
      trainingParameters.getBooleanParameter("UseAverage", true));
}
// << members related to AbstractSequenceTrainer
/**
 * Trains a perceptron model from a stream of sequences.
 *
 * <p>Steps performed, in order:
 * <ol>
 *   <li>stores the cutoff and a disabled sort flag in the trainer's training parameters
 *       (this mutates {@code trainingParameters}) and indexes the events of the stream
 *       with a {@link OnePassDataIndexer};</li>
 *   <li>resets the stream and reads it to the end once to count the sequences;</li>
 *   <li>builds the label-to-index maps for predicates and outcomes;</li>
 *   <li>creates one {@link MutableContext} per predicate covering every outcome, with all
 *       parameters set to {@code 0.0} (and a second set when averaging);</li>
 *   <li>runs {@code findParameters(iterations)} and wraps the result in a
 *       {@link PerceptronModel}.</li>
 * </ol>
 *
 * @param iterations number of training iterations to run
 * @param sequenceStream the training sequences; it is reset and fully read by this method
 * @param cutoff value stored under {@code AbstractDataIndexer.CUTOFF_PARAM} before indexing
 * @param useAverage if {@code true}, averaged parameters are tracked and returned;
 *        otherwise the plain parameters are returned
 * @return a {@link PerceptronModel} over the indexed predicate and outcome labels
 * @throws IOException if indexing or reading the stream fails
 */
public AbstractModel trainModel(int iterations, SequenceStream sequenceStream,
    int cutoff, boolean useAverage) throws IOException {
  this.iterations = iterations;
  this.sequenceStream = sequenceStream;

  trainingParameters.put(AbstractDataIndexer.CUTOFF_PARAM, cutoff);
  trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
  DataIndexer di = new OnePassDataIndexer();
  di.init(trainingParameters, reportMap);
  di.index(new SequenceStreamEventStream(sequenceStream));

  // Indexing consumed the stream as flat events; read it again to count whole sequences.
  numSequences = 0;
  sequenceStream.reset();
  while (sequenceStream.read() != null) {
    numSequences++;
  }

  outcomeList = di.getOutcomeList();
  predLabels = di.getPredLabels();
  // Built through a typed local so the put calls are checked regardless of the field type.
  Map<String, Integer> predicateIndex = new HashMap<>();
  for (int i = 0; i < predLabels.length; i++) {
    predicateIndex.put(predLabels[i], i);
  }
  pmap = predicateIndex;

  display("Incorporating indexed data for training... \n");

  this.useAverage = useAverage;
  numEvents = di.getNumEvents();

  outcomeLabels = di.getOutcomeLabels();
  Map<String, Integer> outcomeIndex = new HashMap<>();
  for (int oli = 0; oli < outcomeLabels.length; oli++) {
    outcomeIndex.put(outcomeLabels[oli], oli);
  }
  omap = outcomeIndex;

  numPreds = predLabels.length;
  numOutcomes = outcomeLabels.length;
  if (useAverage) {
    updates = new int[numPreds][numOutcomes][3];
  }

  display("done.\n");

  display("\tNumber of Event Tokens: " + numEvents + "\n");
  display("\t Number of Outcomes: " + numOutcomes + "\n");
  display("\t Number of Predicates: " + numPreds + "\n");

  params = new MutableContext[numPreds];
  if (useAverage) {
    averageParams = new MutableContext[numPreds];
  }

  // Every predicate is associated with every outcome: the pattern is simply 0..numOutcomes-1.
  int[] allOutcomesPattern = new int[numOutcomes];
  for (int oi = 0; oi < numOutcomes; oi++) {
    allOutcomesPattern[oi] = oi;
  }

  for (int pi = 0; pi < numPreds; pi++) {
    params[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
    if (useAverage) {
      averageParams[pi] = new MutableContext(allOutcomesPattern, new double[numOutcomes]);
    }
    // NOTE(review): fresh double[] arrays are already zero-filled, so this explicit reset
    // looks redundant; it is kept in case MutableContext does not use the array as-is.
    for (int aoi = 0; aoi < numOutcomes; aoi++) {
      params[pi].setParameter(aoi, 0.0);
      if (useAverage) {
        averageParams[pi].setParameter(aoi, 0.0);
      }
    }
  }

  display("Computing model parameters...\n");
  findParameters(iterations);
  display("...done.\n");

  /* Create and return the model ****/
  if (useAverage) {
    return new PerceptronModel(averageParams, predLabels, outcomeLabels);
  }
  return new PerceptronModel(params, predLabels, outcomeLabels);
}
/**
 * Runs {@code iterations} training passes via {@code nextIteration} and then reports
 * training statistics for the parameter set that will be returned (averaged parameters
 * when {@code useAverage} is set, plain parameters otherwise).
 *
 * <p>Before each pass the 1-based iteration number is displayed, left-padded with spaces
 * to a width of three characters so that the progress lines align.
 *
 * @param iterations number of passes to run; no pass is run if it is less than 1
 * @throws IOException if a training pass or the statistics computation fails on I/O
 */
private void findParameters(int iterations) throws IOException {
  display("Performing " + iterations + " iterations.\n");
  for (int i = 1; i <= iterations; i++) {
    // Pad to three characters: two spaces for one digit, one space for two digits.
    final String pad;
    if (i < 10) {
      pad = "  ";
    } else if (i < 100) {
      pad = " ";
    } else {
      pad = "";
    }
    display(pad + i + ": ");
    nextIteration(i);
  }
  if (useAverage) {
    trainingStats(averageParams);
  } else {
    trainingStats(params);
  }
}
public void nextIteration(int iteration) throws IOException {
iteration--; //move to 0-based index
int numCorrect = 0;
int oei = 0;
int si = 0;
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy