All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.optimization.SGDToQNMinimizer Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.optimization; 
import edu.stanford.nlp.util.logging.Redwood;

import java.io.Serializable;

/**
 * Stochastic Gradient Descent To Quasi Newton Minimizer
 *
 * An experimental minimizer which takes a stochastic function (one implementing AbstractStochasticCachingDiffFunction)
 * and executes SGD for the first couple passes.  During the final iterations a series of approximate hessian vector
 * products are built up.  These are then passed to the QNminimizer so that it can start right up without the typical
 * delay.
 *
 * Note [2012] The basic idea here is good, but the original ScaledSGDMinimizer wasn't efficient, and so this would
 * be much more useful if rewritten to use the good StochasticInPlaceMinimizer instead.
 *
 * @author Alex Kleeman
 * @version 1.0
 * @since 1.0
 */
public class SGDToQNMinimizer implements Minimizer, Serializable   {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(SGDToQNMinimizer.class);

  private static final long serialVersionUID = -7551807670291500396L;

  // private int k;
  private final int bSize;
  private boolean quiet = false;

  public boolean outputIterationsToFile = false;
  // public int outputFrequency = 10;
  public double gain = 0.1;
  // private List gradList = null;
  // private List yList = null;
  // private List sList = null;
  // private List tmpYList = null;
  // private List tmpSList = null;
  // private int memory = 5;
  public int SGDPasses = -1;
  public int QNPasses = -1;
  private final int hessSampleSize;
  private final int QNMem;


  public SGDToQNMinimizer(double SGDGain, int batchSize, int SGDPasses, int QNPasses){
    this(SGDGain, batchSize, SGDPasses, QNPasses, 50, 10);
  }

  public SGDToQNMinimizer(double SGDGain, int batchSize, int sgdPasses, int qnPasses, int hessSamples, int QNMem) {
    this(SGDGain, batchSize, sgdPasses, qnPasses, hessSamples, QNMem, false);
  }

  public SGDToQNMinimizer(double SGDGain, int batchSize, int sgdPasses, int qnPasses, int hessSamples, int QNMem, boolean outputToFile) {
    this.gain = SGDGain;
    this.bSize = batchSize;
    this.SGDPasses = sgdPasses;
    this.QNPasses = qnPasses;
    this.hessSampleSize = hessSamples;
    this.QNMem = QNMem;
    this.outputIterationsToFile = outputToFile;
  }


 public void shutUp() {
    this.quiet = true;
  }

  protected String getName() {
    int g = (int) (gain * 1000);
    return "SGD2QN" + bSize + "_g" + g;
  }

  public double[] minimize(DiffFunction function, double functionTolerance, double[] initial) {
    return minimize(function,functionTolerance,initial,-1);
  }

  public double[] minimize(DiffFunction function, double functionTolerance, double[] initial, int maxIterations) {
    sayln("SGDToQNMinimizer called on function of " + function.domainDimension() + " variables;");

    // check for stochastic derivatives
    if (!(function instanceof AbstractStochasticCachingDiffFunction)) {
      throw new UnsupportedOperationException();
    }
    AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction) function;

    dfunction.method = StochasticCalculateMethods.GradientOnly;

    ScaledSGDMinimizer sgd = new ScaledSGDMinimizer(this.gain,this.bSize,this.SGDPasses,1,this.outputIterationsToFile);
    QNMinimizer qn = new QNMinimizer(this.QNMem,true);

    double[] x = sgd.minimize(dfunction, functionTolerance, initial, this.SGDPasses);

    QNMinimizer.QNInfo qnInfo = qn.new QNInfo(sgd.sList , sgd.yList);
    qnInfo.d = sgd.diag;

    qn.minimize(dfunction, functionTolerance, x, this.QNPasses, qnInfo);

    log.info("");
    log.info("Minimization complete.");
    log.info("");
    log.info("Exiting for Debug");
    return x;
  }


  private void sayln(String s) {
    if (!quiet) {
      log.info(s);
    }
  }

  private void say(String s) {
    if (!quiet) {
      log.info(s);
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy