All Downloads are FREE. Search and download functionalities are using the official Maven repository.

opennlp.tools.ml.maxent.GISTrainer Maven / Gradle / Ivy

There is a newer version: 2.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.ml.maxent;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.EvalParameters;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Prior;
import opennlp.tools.ml.model.UniformPrior;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/**
 * An implementation of Generalized Iterative Scaling (GIS).
 * 

* The reference paper for this implementation was Adwait Ratnaparkhi's tech report at the * University of Pennsylvania's Institute for Research in Cognitive Science, * and is available at * ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z. *

* The slack parameter used in the above implementation has been removed by default * from the computation and a method for updating with Gaussian smoothing has been * added per Investigating GIS and Smoothing for Maximum Entropy Taggers, Clark and Curran (2002). * http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf. *

* The slack parameter can be used by setting {@code useSlackParameter} to {@code true}. * Gaussian smoothing can be used by setting {@code useGaussianSmoothing} to {@code true}. *

* A {@link Prior} can be used to train models which converge to the distribution which minimizes the * relative entropy between the distribution specified by the empirical constraints of the training * data and the specified prior. By default, the uniform distribution is used as the prior. */ public class GISTrainer extends AbstractEventTrainer { private static final Logger logger = LoggerFactory.getLogger(GISTrainer.class); public static final String LOG_LIKELIHOOD_THRESHOLD_PARAM = "LLThreshold"; public static final double LOG_LIKELIHOOD_THRESHOLD_DEFAULT = 0.0001; private double llThreshold = 0.0001; /** * Specifies whether unseen context/outcome pairs should be estimated as occur very infrequently. */ private boolean useSimpleSmoothing = false; /** * Specified whether parameter updates should prefer a distribution of parameters which * is gaussian. */ private boolean useGaussianSmoothing = false; private double sigma = 2.0; // If we are using smoothing, this is used as the "number" of // times we want the trainer to imagine that it saw a feature that it // actually didn't see. Defaulted to 0.1. private double _smoothingObservation = 0.1; /** * Number of unique events which occurred in the event set. */ private int numUniqueEvents; /** * Number of predicates. */ private int numPreds; /** * Number of outcomes. */ private int numOutcomes; /** * Records the array of predicates seen in each event. */ private int[][] contexts; /** * The value associated with each context. If null then context values are assumes to be 1. */ private float[][] values; /** * List of outcomes for each event i, in context[i]. */ private int[] outcomeList; /** * Records the num of times an event has been seen for each event i, in context[i]. */ private int[] numTimesEventsSeen; /** * Stores the String names of the outcomes. The GIS only tracks outcomes as * ints, and so this array is needed to save the model to disk and thereby * allow users to know what the outcome was in human understandable terms. */ private String[] outcomeLabels; /** * Stores the String names of the predicates. The GIS only tracks predicates * as ints, and so this array is needed to save the model to disk and thereby * allow users to know what the outcome was in human understandable terms. */ private String[] predLabels; /** * Stores the observed expected values of the features based on training data. */ private MutableContext[] observedExpects; /** * Stores the estimated parameter value of each predicate during iteration */ private MutableContext[] params; /** * Stores the expected values of the features based on the current models */ private MutableContext[][] modelExpects; /** * This is the prior distribution that the model uses for training. */ private Prior prior; /** * Initial probability for all outcomes. */ private EvalParameters evalParams; public static final String MAXENT_VALUE = "MAXENT"; /** * If we are using smoothing, this is used as the "number" of times we want * the trainer to imagine that it saw a feature that it actually didn't see. * Defaulted to 0.1. */ private static final String SMOOTHING_PARAM = "Smoothing"; private static final boolean SMOOTHING_DEFAULT = false; private static final String SMOOTHING_OBSERVATION_PARAM = "SmoothingObservation"; private static final double SMOOTHING_OBSERVATION = 0.1; private static final String GAUSSIAN_SMOOTHING_PARAM = "GaussianSmoothing"; private static final boolean GAUSSIAN_SMOOTHING_DEFAULT = false; private static final String GAUSSIAN_SMOOTHING_SIGMA_PARAM = "GaussianSmoothingSigma"; private static final double GAUSSIAN_SMOOTHING_SIGMA_DEFAULT = 2.0; /** * Initializes a {@link GISTrainer}. *

* Note:
* The resulting instance does not print progress messages about training to STDOUT. */ public GISTrainer() { } /** * {@inheritDoc} */ @Override public boolean isSortAndMerge() { return true; } /** * {@inheritDoc} */ @Override public void init(TrainingParameters trainingParameters, Map reportMap) { super.init(trainingParameters, reportMap); llThreshold = trainingParameters.getDoubleParameter(LOG_LIKELIHOOD_THRESHOLD_PARAM, LOG_LIKELIHOOD_THRESHOLD_DEFAULT); useSimpleSmoothing = trainingParameters.getBooleanParameter(SMOOTHING_PARAM, SMOOTHING_DEFAULT); if (useSimpleSmoothing) { _smoothingObservation = trainingParameters.getDoubleParameter(SMOOTHING_OBSERVATION_PARAM, SMOOTHING_OBSERVATION); } useGaussianSmoothing = trainingParameters.getBooleanParameter(GAUSSIAN_SMOOTHING_PARAM, GAUSSIAN_SMOOTHING_DEFAULT); if (useGaussianSmoothing) { sigma = trainingParameters.getDoubleParameter( GAUSSIAN_SMOOTHING_SIGMA_PARAM, GAUSSIAN_SMOOTHING_SIGMA_DEFAULT); } if (useSimpleSmoothing && useGaussianSmoothing) throw new RuntimeException("Cannot set both Gaussian smoothing and Simple smoothing"); } /** * {@inheritDoc} */ @Override public MaxentModel doTrain(DataIndexer indexer) throws IOException { int iterations = getIterations(); int threads = trainingParameters.getIntParameter(TrainingParameters.THREADS_PARAM, 1); return trainModel(iterations, indexer, threads); } /** * Sets whether this trainer will use smoothing while training the model. *

* Note:
* This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. * * @param smooth {@code true} if smoothing is desired, {@code false} if not. */ public void setSmoothing(boolean smooth) { useSimpleSmoothing = smooth; } /** * Sets whether this trainer will use smoothing while training the model. *

* Note:
* This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. * * @param timesSeen The "number" of times we want the trainer to imagine * it saw a feature that it actually didn't see */ public void setSmoothingObservation(double timesSeen) { _smoothingObservation = timesSeen; } /** * Sets whether this trainer will use smoothing while training the model. *

* Note:
* This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. * * @param sigmaValue The Gaussian sigma value used for smoothing. */ public void setGaussianSigma(double sigmaValue) { useGaussianSmoothing = true; sigma = sigmaValue; } /** * Trains a model using the GIS algorithm, assuming 100 iterations and no * cutoff. * * @param eventStream The {@link ObjectStream eventStream} holding the data * on which this model will be trained. * * @return A trained {@link GISModel} which can be used immediately or saved to * disk using an {@link opennlp.tools.ml.maxent.io.GISModelWriter}. */ public GISModel trainModel(ObjectStream eventStream) throws IOException { return trainModel(eventStream, 100, 0); } /** * Trains a GIS model on the event in the specified event stream, using the specified number * of iterations and the specified count cutoff. * * @param eventStream A {@link ObjectStream stream} of all events. * @param iterations The number of iterations to use for GIS. * @param cutoff The number of times a feature must occur to be included. * * @return A trained {@link GISModel} which can be used immediately or saved to * disk using an {@link opennlp.tools.ml.maxent.io.GISModelWriter}. */ public GISModel trainModel(ObjectStream eventStream, int iterations, int cutoff) throws IOException { DataIndexer indexer = new OnePassDataIndexer(); TrainingParameters indexingParameters = new TrainingParameters(); indexingParameters.put(GISTrainer.CUTOFF_PARAM, cutoff); indexingParameters.put(GISTrainer.ITERATIONS_PARAM, iterations); Map reportMap = new HashMap<>(); indexer.init(indexingParameters, reportMap); indexer.index(eventStream); return trainModel(iterations, indexer); } /** * Trains a model using the GIS algorithm. * * @param iterations The number of GIS iterations to perform. * @param di The {@link DataIndexer} used to compress events in memory. * * @return A trained {@link GISModel} which can be used immediately or saved to * disk using an {@link opennlp.tools.ml.maxent.io.GISModelWriter}. * @throws IllegalArgumentException Thrown if parameters were invalid. */ public GISModel trainModel(int iterations, DataIndexer di) { return trainModel(iterations, di, new UniformPrior(), 1); } /** * Trains a model using the GIS algorithm. * * @param iterations The number of GIS iterations to perform. * @param di The {@link DataIndexer} used to compress events in memory. * @param threads The number of thread to train with. Must be greater than {@code 0}. * * @return A trained {@link GISModel} which can be used immediately or saved to * disk using an {@link opennlp.tools.ml.maxent.io.GISModelWriter}. * @throws IllegalArgumentException Thrown if parameters were invalid. */ public GISModel trainModel(int iterations, DataIndexer di, int threads) { return trainModel(iterations, di, new UniformPrior(), threads); } /** * Trains a model using the GIS algorithm. * * @param iterations The number of GIS iterations to perform. * @param di The {@link DataIndexer} used to compress events in memory. * @param modelPrior The {@link Prior} distribution used to train this model. * * @return A trained {@link GISModel} which can be used immediately or saved to * disk using an {@link opennlp.tools.ml.maxent.io.GISModelWriter}. * @throws IllegalArgumentException Thrown if parameters were invalid. */ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int threads) { if (threads <= 0) { throw new IllegalArgumentException("threads must be at least one or greater but is " + threads + "!"); } modelExpects = new MutableContext[threads][]; /* Incorporate all of the needed info *****/ logger.info("Incorporating indexed data for training..."); contexts = di.getContexts(); values = di.getValues(); /* The number of times a predicate occurred in the training data. */ int[] predicateCounts = di.getPredCounts(); numTimesEventsSeen = di.getNumTimesEventsSeen(); numUniqueEvents = contexts.length; this.prior = modelPrior; //printTable(contexts); // determine the correction constant and its inverse double correctionConstant = 0; for (int ci = 0; ci < contexts.length; ci++) { if (values == null || values[ci] == null) { if (contexts[ci].length > correctionConstant) { correctionConstant = contexts[ci].length; } } else { float cl = values[ci][0]; for (int vi = 1; vi < values[ci].length; vi++) { cl += values[ci][vi]; } if (cl > correctionConstant) { correctionConstant = cl; } } } logger.info("done."); outcomeLabels = di.getOutcomeLabels(); outcomeList = di.getOutcomeList(); numOutcomes = outcomeLabels.length; predLabels = di.getPredLabels(); prior.setLabels(outcomeLabels, predLabels); numPreds = predLabels.length; logger.info("\tNumber of Event Tokens: {} " + "\n\t Number of Outcomes: {} " + "\n\t Number of Predicates: {}", numUniqueEvents, numOutcomes, numPreds); // set up feature arrays float[][] predCount = new float[numPreds][numOutcomes]; for (int ti = 0; ti < numUniqueEvents; ti++) { for (int j = 0; j < contexts[ti].length; j++) { if (values != null && values[ti] != null) { predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti] * values[ti][j]; } else { predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti]; } } } // A fake "observation" to cover features which are not detected in // the data. The default is to assume that we observed "1/10th" of a // feature during training. final double smoothingObservation = _smoothingObservation; // Get the observed expectations of the features. Strictly speaking, // we should divide the counts by the number of Tokens, but because of // the way the model's expectations are approximated in the // implementation, this is cancelled out when we compute the next // iteration of a parameter, making the extra divisions wasteful. params = new MutableContext[numPreds]; for (int i = 0; i < modelExpects.length; i++) { modelExpects[i] = new MutableContext[numPreds]; } observedExpects = new MutableContext[numPreds]; // The model does need the correction constant and the correction feature. The correction constant // is only needed during training, and the correction feature is not necessary. // For compatibility reasons the model contains form now on a correction constant of 1, // and a correction param 0. evalParams = new EvalParameters(params, numOutcomes); int[] activeOutcomes = new int[numOutcomes]; int[] outcomePattern; int[] allOutcomesPattern = new int[numOutcomes]; for (int oi = 0; oi < numOutcomes; oi++) { allOutcomesPattern[oi] = oi; } int numActiveOutcomes; for (int pi = 0; pi < numPreds; pi++) { numActiveOutcomes = 0; if (useSimpleSmoothing) { numActiveOutcomes = numOutcomes; outcomePattern = allOutcomesPattern; } else { //determine active outcomes for (int oi = 0; oi < numOutcomes; oi++) { if (predCount[pi][oi] > 0) { activeOutcomes[numActiveOutcomes] = oi; numActiveOutcomes++; } } if (numActiveOutcomes == numOutcomes) { outcomePattern = allOutcomesPattern; } else { outcomePattern = new int[numActiveOutcomes]; System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes); } } params[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]); for (int i = 0; i < modelExpects.length; i++) { modelExpects[i][pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]); } observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]); for (int aoi = 0; aoi < numActiveOutcomes; aoi++) { int oi = outcomePattern[aoi]; params[pi].setParameter(aoi, 0.0); for (MutableContext[] modelExpect : modelExpects) { modelExpect[pi].setParameter(aoi, 0.0); } if (predCount[pi][oi] > 0) { observedExpects[pi].setParameter(aoi, predCount[pi][oi]); } else if (useSimpleSmoothing) { observedExpects[pi].setParameter(aoi, smoothingObservation); } } } logger.info("...done."); /* Find the parameters *****/ if (threads == 1) { logger.info("Computing model parameters ..."); } else { logger.info("Computing model parameters in {} threads...", threads); } findParameters(iterations, correctionConstant); // Create and return the model return new GISModel(params, predLabels, outcomeLabels); } /* Estimate and return the model parameters. */ private void findParameters(int iterations, double correctionConstant) { int threads = modelExpects.length; ExecutorService executor = Executors.newFixedThreadPool(threads, runnable -> { Thread thread = new Thread(runnable); thread.setName("opennlp.tools.ml.maxent.ModelExpectationComputeTask.nextIteration()"); thread.setDaemon(true); return thread; }); CompletionService completionService = new ExecutorCompletionService<>(executor); double prevLL = 0.0; double currLL; logger.info("Performing {} iterations.", iterations); for (int i = 1; i <= iterations; i++) { currLL = nextIteration(correctionConstant, completionService, i); if (i > 1) { if (prevLL > currLL) { logger.warn("Model Diverging: loglikelihood decreased"); break; } if (currLL - prevLL < llThreshold) { break; } } prevLL = currLL; } // kill a bunch of these big objects now that we don't need them observedExpects = null; modelExpects = null; numTimesEventsSeen = null; contexts = null; executor.shutdown(); } //modeled on implementation in Zhang Le's maxent kit private double gaussianUpdate(int predicate, int oid, double correctionConstant) { double param = params[predicate].getParameters()[oid]; double x0 = 0.0; double modelValue = modelExpects[0][predicate].getParameters()[oid]; double observedValue = observedExpects[predicate].getParameters()[oid]; for (int i = 0; i < 50; i++) { double tmp = modelValue * StrictMath.exp(correctionConstant * x0); double f = tmp + (param + x0) / sigma - observedValue; double fp = tmp * correctionConstant + 1 / sigma; if (fp == 0) { break; } double x = x0 - f / fp; if (StrictMath.abs(x - x0) < 0.000001) { x0 = x; break; } x0 = x; } return x0; } /* Compute one iteration of GIS and return log-likelihood.*/ private double nextIteration(double correctionConstant, CompletionService completionService, int iteration) { // compute contribution of p(a|b_i) for each feature and the new // correction parameter double loglikelihood = 0.0; int numEvents = 0; int numCorrect = 0; // Each thread gets equal number of tasks, if the number of tasks // is not divisible by the number of threads, the first "leftOver" // threads have one extra task. int numberOfThreads = modelExpects.length; int taskSize = numUniqueEvents / numberOfThreads; int leftOver = numUniqueEvents % numberOfThreads; // submit all tasks to the completion service. for (int i = 0; i < numberOfThreads; i++) { if (i < leftOver) { completionService.submit(new ModelExpectationComputeTask(i, i * taskSize + i, taskSize + 1)); } else { completionService.submit(new ModelExpectationComputeTask(i, i * taskSize + leftOver, taskSize)); } } for (int i = 0; i < numberOfThreads; i++) { ModelExpectationComputeTask finishedTask; try { finishedTask = completionService.take().get(); } catch (InterruptedException e) { // TODO: We got interrupted, but that is currently not really supported! // For now we fail hard. We hopefully soon // handle this case properly! throw new IllegalStateException("Interruption is not supported!", e); } catch (ExecutionException e) { // Only runtime exception can be thrown during training, if one was thrown // it should be re-thrown. That could for example be a NullPointerException // which is caused through a bug in our implementation. throw new RuntimeException("Exception during training: " + e.getMessage(), e); } // When they are done, retrieve the results ... numEvents += finishedTask.getNumEvents(); numCorrect += finishedTask.getNumCorrect(); loglikelihood += finishedTask.getLoglikelihood(); } // merge the results of the two computations for (int pi = 0; pi < numPreds; pi++) { int[] activeOutcomes = params[pi].getOutcomes(); for (int aoi = 0; aoi < activeOutcomes.length; aoi++) { for (int i = 1; i < modelExpects.length; i++) { modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]); } } } // compute the new parameter values for (int pi = 0; pi < numPreds; pi++) { double[] observed = observedExpects[pi].getParameters(); double[] model = modelExpects[0][pi].getParameters(); int[] activeOutcomes = params[pi].getOutcomes(); for (int aoi = 0; aoi < activeOutcomes.length; aoi++) { if (useGaussianSmoothing) { params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, correctionConstant)); } else { if (model[aoi] == 0) { logger.warn("Model expects == 0 for {} {}", predLabels[pi], outcomeLabels[aoi]); } //params[pi].updateParameter(aoi,(StrictMath.log(observed[aoi]) - StrictMath.log(model[aoi]))); params[pi].updateParameter(aoi, ((StrictMath.log(observed[aoi]) - StrictMath.log(model[aoi])) / correctionConstant)); } for (MutableContext[] modelExpect : modelExpects) { modelExpect[pi].setParameter(aoi, 0.0); // re-initialize to 0.0's } } } logger.info("{} - loglikelihood={}\t{}", iteration, loglikelihood, ((double) numCorrect / numEvents)); return loglikelihood; } private class ModelExpectationComputeTask implements Callable { private final int startIndex; private final int length; final private int threadIndex; private double loglikelihood = 0; private int numEvents = 0; private int numCorrect = 0; // startIndex to compute, number of events to compute ModelExpectationComputeTask(int threadIndex, int startIndex, int length) { this.startIndex = startIndex; this.length = length; this.threadIndex = threadIndex; } public ModelExpectationComputeTask call() { final double[] modelDistribution = new double[numOutcomes]; for (int ei = startIndex; ei < startIndex + length; ei++) { // TODO: check interruption status here, if interrupted set a poisoned flag and return if (values != null) { prior.logPrior(modelDistribution, contexts[ei], values[ei]); GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams); } else { prior.logPrior(modelDistribution, contexts[ei]); GISModel.eval(contexts[ei], modelDistribution, evalParams); } for (int j = 0; j < contexts[ei].length; j++) { int pi = contexts[ei][j]; int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes(); for (int aoi = 0; aoi < activeOutcomes.length; aoi++) { int oi = activeOutcomes[aoi]; // numTimesEventsSeen must also be thread safe if (values != null && values[ei] != null) { modelExpects[threadIndex][pi].updateParameter(aoi, modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]); } else { modelExpects[threadIndex][pi].updateParameter(aoi, modelDistribution[oi] * numTimesEventsSeen[ei]); } } } loglikelihood += StrictMath.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei]; numEvents += numTimesEventsSeen[ei]; int max = ArrayMath.argmax(modelDistribution); if (max == outcomeList[ei]) { numCorrect += numTimesEventsSeen[ei]; } } return this; } synchronized int getNumEvents() { return numEvents; } synchronized int getNumCorrect() { return numCorrect; } synchronized double getLoglikelihood() { return loglikelihood; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy