All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.classify.MaxEntGERangeTrainer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.classify.constraints.ge.MaxEntRangeL2FLGEConstraints;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;

/**
 * Training of MaxEnt models with labeled features using
 * Generalized Expectation Criteria, where constraints are supplied as
 * ranges (see {@code MaxEntRangeL2FLGEConstraints}).
 *
 * Based on:
 * "Learning from Labeled Features using Generalized Expectation Criteria"
 * Gregory Druck, Gideon Mann, Andrew McCallum
 * SIGIR 2008
 *
 * Better explanations of the parameters are given in MaxEntOptimizableByGE.
 *
 * @author Gregory Druck
 */

public class MaxEntGERangeTrainer extends ClassifierTrainer<MaxEnt>
    implements ClassifierTrainer.ByOptimization<MaxEnt>, Boostable, Serializable {

  private static final long serialVersionUID = 1L;
  private static final Logger logger = MalletLogger.getLogger(MaxEntGERangeTrainer.class.getName());
  private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntGERangeTrainer.class.getName()+"-pl");

  // these are for using this code from the command line
  private boolean normalize = true;
  private boolean useValues = false;
  private String constraintsFile;

  // Sum of the iteration budgets passed to optimize() calls that returned
  // normally; saturates at Integer.MAX_VALUE instead of wrapping around.
  private int numIterations = 0;
  private int maxIterations = Integer.MAX_VALUE;
  private double temperature = 1;
  private double gaussianPriorVariance = 1;

  protected ArrayList<MaxEntGEConstraint> constraints;
  private InstanceList trainingList = null;
  private MaxEnt classifier = null;
  private MaxEntOptimizableByGE ge = null;
  // The instance list that 'ge' was built from; a different list forces a rebuild.
  private InstanceList geTrainingList = null;
  private Optimizer opt = null;

  /** Creates a trainer; constraints must be set via a constraints file before training. */
  public MaxEntGERangeTrainer() {}

  /**
   * @param constraints GE constraints to train against
   */
  public MaxEntGERangeTrainer(ArrayList<MaxEntGEConstraint> constraints) {
    this.constraints = constraints;
  }

  /**
   * @param constraints GE constraints to train against
   * @param classifier classifier passed to the GE objective when it is built
   */
  public MaxEntGERangeTrainer(ArrayList<MaxEntGEConstraint> constraints, MaxEnt classifier) {
    this.constraints = constraints;
    this.classifier = classifier;
  }

  /**
   * Sets the file from which range constraints are read when no constraints
   * were supplied to the constructor.
   */
  public void setConstraintsFile(String filename) {
    this.constraintsFile = filename;
  }

  /** Sets the temperature forwarded to the GE objective when it is built. */
  public void setTemperature(double temp) {
    this.temperature = temp;
  }

  /** Sets the Gaussian prior variance forwarded to the GE objective when it is built. */
  public void setGaussianPriorVariance(double variance) {
    this.gaussianPriorVariance = variance;
  }

  /** Returns the most recently trained classifier, or the initial one if not yet trained. */
  public MaxEnt getClassifier () {
    return classifier;
  }

  /** Controls the {@code useValues} flag of constraints read from a constraints file. */
  public void setUseValues(boolean flag) {
    this.useValues = flag;
  }

  /** Controls the {@code normalize} flag of constraints read from a constraints file. */
  public void setNormalize(boolean normalize) {
    this.normalize = normalize;
  }

  /**
   * Returns the GE objective for {@code trainingList}, building it on first
   * use or whenever a different list is passed. Rebuilding discards the
   * current optimizer, because it is bound to the previous objective.
   *
   * @param trainingList data the objective is computed over
   * @return the GE objective
   */
  public Optimizable.ByGradientValue getOptimizable (InstanceList trainingList) {
    if (ge == null || trainingList != geTrainingList) {
      ge = new MaxEntOptimizableByGE(trainingList,constraints,classifier);
      ge.setTemperature(temperature);
      ge.setGaussianPriorVariance(gaussianPriorVariance);
      geTrainingList = trainingList;
      opt = null;
    }
    return ge;
  }

  /**
   * Returns the optimizer, creating a {@link LimitedMemoryBFGS} over the GE
   * objective if none has been set. Uses the list from the last call to
   * train(); before that, the objective is built from a null list.
   */
  public Optimizer getOptimizer () {
    getOptimizable(trainingList);
    if (opt == null) {
      opt = new LimitedMemoryBFGS(ge);
    }
    return opt;
  }

  /**
   * Replaces the optimizer. It should operate on the objective returned by
   * {@link #getOptimizable(InstanceList)} for the list that will be trained on.
   */
  public void setOptimizer(Optimizer opt) {
    this.opt = opt;
  }

  /**
   * Specifies the maximum number of iterations to run during a single call
   * to {@link #train(InstanceList)}.
   *
   * @param iter maximum number of iterations
   */
  public void setMaxIterations (int iter) {
    maxIterations = iter;
  }

  /**
   * Returns the cumulative iteration budget of all optimizer runs that
   * finished without throwing. This is an upper bound on the iterations
   * actually performed, capped at Integer.MAX_VALUE.
   */
  public int getIteration () {
    return numIterations;
  }

  /** Trains on {@code trainingList} using the configured maximum iteration count. */
  public MaxEnt train (InstanceList trainingList) {
    return train (trainingList, maxIterations);
  }

  /**
   * Trains a MaxEnt classifier on {@code train} with the GE objective.
   * If no constraints were supplied, they are read from the constraints file.
   *
   * @param train training data
   * @param maxIterations iteration budget for each optimizer run
   * @return the trained classifier
   */
  public MaxEnt train (InstanceList train, int maxIterations) {
    trainingList = train;

    if (constraints == null && constraintsFile != null) {
      constraints = readConstraintsFromFile(train);
    }

    getOptimizable(trainingList);
    getOptimizer();

    if (opt instanceof LimitedMemoryBFGS) {
      ((LimitedMemoryBFGS)opt).reset();
    }

    logger.fine ("trainingList.size() = "+trainingList.size());

    runOptimizer(maxIterations);

    if (maxIterations == Integer.MAX_VALUE && opt instanceof LimitedMemoryBFGS) {
      // Run it again because in our and Sam Roweis' experience, BFGS can still
      // eke out more likelihood after first convergence by re-running without
      // being restricted by its gradient history.
      ((LimitedMemoryBFGS)opt).reset();
      runOptimizer(maxIterations);
    }
    progressLogger.info("\n"); //  progress messages are on one line; move on.

    classifier = ge.getClassifier();
    return classifier;
  }

  /**
   * Reads range constraints from {@code constraintsFile} and packs them into
   * a single {@link MaxEntRangeL2FLGEConstraints}, skipping entries whose
   * first value is infinite.
   */
  private ArrayList<MaxEntGEConstraint> readConstraintsFromFile(InstanceList train) {
    HashMap<Integer,double[][]> constraintsMap =
      FeatureConstraintUtil.readRangeConstraintsFromFile(constraintsFile, train);

    logger.info("number of constraints: " + constraintsMap.size());
    ArrayList<MaxEntGEConstraint> result = new ArrayList<MaxEntGEConstraint>();

    MaxEntRangeL2FLGEConstraints geConstraints = new MaxEntRangeL2FLGEConstraints(train.getDataAlphabet().size(),
      train.getTargetAlphabet().size(),useValues,normalize);
    for (Map.Entry<Integer,double[][]> entry : constraintsMap.entrySet()) {
      int fi = entry.getKey();
      double[][] dist = entry.getValue();
      for (int li = 0; li < dist.length; li++) {
        // NOTE(review): an infinite first value appears to mark "no constraint
        // for this label"; dist[li][0] / dist[li][1] presumably are the range bounds.
        if (!Double.isInfinite(dist[li][0])) {
          geConstraints.addConstraint(fi, li, dist[li][0], dist[li][1], 1);
        }
      }
    }
    result.add(geConstraints);
    return result;
  }

  /**
   * Runs the optimizer for up to {@code iterations} iterations. An exception
   * is logged with its stack trace and treated as convergence.
   */
  private void runOptimizer(int iterations) {
    try {
      opt.optimize(iterations);
      numIterations = saturatingAdd(numIterations, iterations);
    } catch (Exception e) {
      logger.log(Level.INFO, "Catching exception; saying converged.", e);
    }
  }

  /** Adds two ints, clamping the result to the int range instead of wrapping. */
  private static int saturatingAdd(int a, int b) {
    long sum = (long) a + (long) b;
    if (sum > Integer.MAX_VALUE) {
      return Integer.MAX_VALUE;
    }
    if (sum < Integer.MIN_VALUE) {
      return Integer.MIN_VALUE;
    }
    return (int) sum;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy