// opennlp.perceptron.SimplePerceptronSequenceTrainer Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package opennlp.perceptron;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import opennlp.model.AbstractModel;
import opennlp.model.DataIndexer;
import opennlp.model.Event;
import opennlp.model.MutableContext;
import opennlp.model.OnePassDataIndexer;
import opennlp.model.Sequence;
import opennlp.model.SequenceStream;
import opennlp.model.SequenceStreamEventStream;
/**
* Trains models for sequences using the perceptron algorithm. Each outcome is represented as
* a binary perceptron classifier. This supports standard (integer) weighting as well as
* average weighting. Sequence information is used in a way that is simplified relative to that described in:
* Discriminative Training Methods for Hidden Markov Models: Theory and Experiments
* with the Perceptron Algorithm. Michael Collins, EMNLP 2002.
* Specifically, updates are applied only to tokens which were incorrectly tagged by a sequence tagger,
* rather than to all features across the sequence which differ from the training sequence.
*/
public class SimplePerceptronSequenceTrainer {
/**
* Message flag, on by default.
* NOTE(review): presumably gates progress output; the code that reads it is not visible here.
*/
private boolean printMessages = true;
/** Iteration count, taken from the {@code iterations} argument of {@code trainModel}. */
private int iterations;
/** The training sequences; stored by {@code trainModel} and re-iterated on later passes. */
private SequenceStream sequenceStream;
/** Number of events in the event set. */
private int numEvents;
/** Number of predicates. */
private int numPreds;
/** Number of outcomes; used to size per-outcome arrays. */
private int numOutcomes;
/** List of outcomes for each event i, in context[i]. */
private int[] outcomeList;
/** Outcome labels, handed to {@code PerceptronModel} together with the parameters. */
private String[] outcomeLabels;
/** NOTE(review): package-private scratch array; its use is not visible in this view. */
double[] modelDistribution;
/** Stores the average parameter values of each predicate during iteration. */
private MutableContext[] averageParams;
/**
* Mapping between context and an integer.
* NOTE(review): raw type - the type arguments look to have been lost from this copy of the
* file; confirm the intended key/value types before parameterizing.
*/
private Map pmap;
/**
* NOTE(review): raw type, as for {@code pmap}; presumably the outcome-side counterpart
* of {@code pmap} - confirm.
*/
private Map omap;
/** Stores the estimated parameter value of each predicate during iteration. */
private MutableContext[] params;
/** Whether the averaged parameters are finalized at the end of the last iteration. */
private boolean useAverage;
/**
* Per predicate and per outcome bookkeeping of the most recent update, indexed as
* {@code updates[predicate][outcome][slot]} where slot is one of {@link #VALUE},
* {@link #ITER} or {@link #EVENT}.
*/
private int[][][] updates;
/** Slot of {@code updates} holding the parameter value (truncated to int) at the last update. */
private static final int VALUE = 0;
/** Slot of {@code updates} holding the iteration in which the last update happened. */
private static final int ITER = 1;
/** Slot of {@code updates} holding the sequence index at which the last update happened. */
private static final int EVENT = 2;
/** NOTE(review): presumably the outcome index pattern shared by all contexts; use not visible here. */
private int[] allOutcomesPattern;
/** Predicate labels as reported by the data indexer. */
private String[] predLabels;
/** Number of sequences in {@code sequenceStream}, counted by {@code trainModel}. */
int numSequences;
/**
* Entry point for training a perceptron model on a stream of sequences.
*
* The readable part of this method records the iteration count and the stream, indexes the
* stream's events with a {@link OnePassDataIndexer}, counts the sequences, and fetches the
* outcome list and predicate labels from the indexer before creating the predicate map.
*
* NOTE(review): from the loop that follows the creation of {@code pmap} onward, the text of
* this file is truncated - every span between a '&lt;' and the next '&gt;' appears to have
* been dropped - so the rest of this method and the bodies that followed it are fused together
* and cannot be read reliably. Restore them from the upstream source before changing any logic.
*
* @param iterations the number of training iterations; stored in the {@code iterations} field
* @param sequenceStream the training sequences; iterated more than once
* @param cutoff passed through to the {@link OnePassDataIndexer} constructor
* @param useAverage presumably selects averaged parameters; the assignment that consumes
* this argument is not intact in this copy - confirm
* @return the trained model (the return statement is not visible in this copy)
* @throws IOException declared by the signature; the throwing call is not identifiable here
*/
public AbstractModel trainModel(int iterations, SequenceStream sequenceStream, int cutoff, boolean useAverage) throws IOException {
this.iterations = iterations;
this.sequenceStream = sequenceStream;
// Index the events of all sequences in a single pass.
DataIndexer di = new OnePassDataIndexer(new SequenceStreamEventStream(sequenceStream),cutoff,false);
// Walk the stream once purely to count the sequences it yields.
numSequences = 0;
for (Sequence s : sequenceStream) {
numSequences++;
}
outcomeList = di.getOutcomeList();
predLabels = di.getPredLabels();
pmap = new HashMap();
for (int pli=0;pli();
for (int oli=0;oli[] featureCounts = new Map[numOutcomes];
for (int oi=0;oi();
}
PerceptronModel model = new PerceptronModel(params,predLabels,pmap,outcomeLabels);
for (Sequence sequence : sequenceStream) {
Event[] taggerEvents = sequenceStream.updateContext(sequence, model);
Event[] events = sequence.getEvents();
boolean update = false;
for (int ei=0;ei "+averageParams[pi].getParameters()[oi]);
updates[pi][oi][VALUE] = (int) params[pi].getParameters()[oi];
updates[pi][oi][ITER] = iteration;
updates[pi][oi][EVENT] = si;
}
}
}
}
model = new PerceptronModel(params,predLabels,pmap,outcomeLabels);
}
si++;
}
//finish average computation
double totIterations = (double) iterations*si;
if (useAverage && iteration == iterations-1) {
for (int pi = 0; pi < numPreds; pi++) {
double[] predParams = averageParams[pi].getParameters();
for (int oi = 0;oi "+averageParams[pi].getParameters()[oi]);
}
}
}
}
display(". ("+numCorrect+"/"+numEvents+") "+((double) numCorrect / numEvents) + "\n");
}
private void trainingStats(MutableContext[] params) {
int numCorrect = 0;
int oei=0;
for (Sequence sequence : sequenceStream) {
Event[] taggerEvents = sequenceStream.updateContext(sequence, new PerceptronModel(params,predLabels,outcomeLabels));
for (int ei=0;ei