All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.logging.Redwood;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;


/**
 *
 * @author Alex Kleeman
 */
public abstract class AbstractStochasticCachingDiffFunction extends AbstractCachingDiffFunction  {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractStochasticCachingDiffFunction.class);

  /** Flag reset to {@code false} by {@link #incrementBatch(int)}; other uses are outside this view. */
  public boolean hasNewVals = true;
  /** Flag reset to {@code false} by {@link #incrementBatch(int)}. */
  public boolean recalculatePrevBatch = false;
  /** Flag reset to {@code false} by {@link #incrementBatch(int)}. */
  public boolean returnPreviousValues = false;

  // NOTE(review): lastBatchSize and lastBatch are not read or written in the visible code;
  // presumably they describe the previously evaluated batch -- confirm against the rest of the file.
  protected int lastBatchSize = 0;
  protected int[] lastBatch = null;
  /** Index buffer for the current batch; getBatch reallocates it whenever the requested batch size changes. */
  protected int[] thisBatch = null;
  /** Cached batch point; {@link #clearCache()} overwrites element 0 with NaN when this is non-null. */
  protected double[] lastXBatch = null;
  /** Cached Hessian-vector direction; {@link #clearCache()} overwrites element 0 with NaN when this is non-null. */
  protected double[] lastVBatch = null;
  // protected double[] extFiniteDiffDerivative = null;

  // NOTE(review): lastElement is not used in the visible code -- confirm its meaning elsewhere.
  protected int lastElement = 0;
  /** Where implementations of calculateStochastic store the Hessian-vector product approximation H.v. */
  protected double[] HdotV = null;
  // NOTE(review): gradPerturbed / xPerturbed are not used in the visible code; presumably scratch
  // buffers for a finite-difference approximation -- verify.
  protected double[] gradPerturbed = null;
  protected double[] xPerturbed = null;
  /** Position of the next batch; advanced by incrementBatch and moved back (floored at 0) by decrementBatch. */
  protected int curElement = 0;

  // NOTE(review): raw List type; the element type is not visible here -- confirm before parameterizing.
  protected List allIndices = null;
  /** Random source for sampling, created with the fixed seed 1 (a time-based seed is left commented out). */
  protected Random randGenerator = new Random(1);//System.currentTimeMillis());

  /** Set via {@link #scaleUp(boolean)}; how it affects the calculation is not visible in this chunk. */
  protected boolean scaleUp = false;

  /** Lazily filled by getBatch with ArrayMath.range(0, dataDimension()) when sampleMethod is Shuffled. */
  private int[] shuffledArray = null;

  /**
   * Strategies for choosing which data indices make up the next batch.
   */
  public enum SamplingMethod {
    /** Presumably a placeholder meaning no sampling strategy was chosen. */
    NoneSpecified,
    /** Samples uniformly from the set of possible indices. */
    RandomWithReplacement,
    /** Samples indices, removing each used one, and restarts after each pass over the data. */
    RandomWithoutReplacement,
    /** Takes the indices in order. */
    Ordered,
    /** Reads indices from an array of all data indices (presumably shuffled; the shuffling is not shown here). */
    Shuffled
  }

  /**
   * Advances the random number generator by drawing and discarding {@code numTimes} values,
   * each bounded by {@link #dataDimension()}. Does nothing beyond logging when
   * {@code numTimes} is not positive.
   *
   * @param numTimes number of draws to discard
   */
  public void incrementRandom(int numTimes) {
    log.info("incrementing random " + numTimes + " times.");
    int remaining = numTimes;
    while (remaining > 0) {
      randGenerator.nextInt(this.dataDimension());
      remaining--;
    }
  }

  /**
   * Sets the {@code scaleUp} flag.
   *
   * @param toScaleUp the new value of the flag
   */
  public void scaleUp(boolean toScaleUp) {
    this.scaleUp = toScaleUp;
  }

  /** Strategy for the stochastic calculation; defaults to ExternalFiniteDifference. */
  public StochasticCalculateMethods method = StochasticCalculateMethods.ExternalFiniteDifference;
  /** How getBatch selects the indices of each batch; defaults to RandomWithoutReplacement. */
  public SamplingMethod sampleMethod = SamplingMethod.RandomWithoutReplacement;

  /**
   * The fixed step size for the finite-difference approximation.
   * A few tests run with the SMD minimizer suggested that step sizes between 1e-4 and 1e-3 work best.
   * (akleeman)
   */
  protected double finiteDifferenceStepSize = 1e-4;


  /**
   * Calculates a stochastic approximation to the derivative and value of the function for a
   * given batch of the data. The approximation to the derivative must be stored in the array
   * {@code derivative}, the approximation to the value in {@code value}, and the approximation
   * to the Hessian-vector product H.v in the array {@code HdotV}. Note that the Hessian-vector
   * product is used primarily by the Stochastic Meta Descent optimization routine
   * {@code SMDMinimizer}.
   *
   * <p>Important: the stochastic approximation must be such that the sum of the stochastic
   * calculations over all batches of the data equals the full calculation. For example, for a
   * data set of size 100, the sum of the gradients for batches 1-10, 11-20, 21-30, ..., 91-100
   * must equal the gradient for the full calculation (at the very least in expectation). Be sure
   * to take the priors into account.
   *
   * @param x     the point at which to evaluate the function
   * @param v     the vector for the Hessian-vector product H.v
   * @param batch the indices of the data to use in the calculation; this array is generated
   *              by this abstract class, so implementations only need to use it, not build it
   */
  public abstract void calculateStochastic(double[] x, double[] v, int[] batch);


  /**
   * Returns the size of the data used by the function.
   * Batch indices are drawn from {@code [0, dataDimension())}; see incrementRandom and the
   * Shuffled branch of getBatch.
   *
   * @return the number of data items
   */
  public abstract int dataDimension();

  /**
   * Invalidates cached results without reallocating any arrays. After delegating to the
   * superclass, the first entry of {@code lastXBatch} and {@code lastVBatch} (where allocated)
   * is overwritten with NaN. Presumably this is so that a later comparison against a new point
   * cannot match; verify against the cache-check code.
   */
  @Override
  protected void clearCache() {
    super.clearCache();
    if (lastXBatch != null) {
      lastXBatch[0] = Double.NaN;
    }
    if (lastVBatch != null) {
      lastVBatch[0] = Double.NaN;
    }
  }

  /**
   * Returns the starting point: a freshly allocated vector of length {@code domainDimension()}
   * whose entries are all zero (Java zero-initializes new arrays).
   */
  @Override
  public double[] initial() {
    return new double[domainDimension()];
  }

  /**
   * Moves {@code curElement} back by {@code batchSize}, never going below zero.
   * Decrementing the batch and then calling calculate recomputes over the previous batch.
   *
   * @param batchSize the amount to move back
   */
  public void decrementBatch(int batchSize) {
    curElement = Math.max(0, curElement - batchSize);
  }

  /**
   * Advances {@code curElement} by {@code batchSize} to mark the next batch, and resets the
   * flags {@code hasNewVals}, {@code recalculatePrevBatch} and {@code returnPreviousValues}
   * to {@code false}.
   *
   * @param batchSize the amount to advance
   */
  public void incrementBatch(int batchSize) {
    hasNewVals = false;
    recalculatePrevBatch = false;
    returnPreviousValues = false;
    curElement += batchSize;
  }


  /**
   * getBatch is used to generate the next sequence of indices to be passed to the actual function.
   *  Depending on the current sample method this is done by:
   *    Ordered - simply generates the indices 1,2,3,4,....
   *    RandomWithReplacement - Samples uniformly from the set of possible indices
   *    RandomWithoutReplacement - Samples from the set of possible indices removing each used index, then restarting
   *          after each pass
   */

  //private int numCalls = 0;
  protected void getBatch(int batchSize){

//      if (numCalls == 0) {
//        for (int i = 0; i < 1538*\15; i++) {
//          randGenerator.nextInt(this.dataDimension());
//        }
//      }
//      numCalls++;

    if (thisBatch == null || thisBatch.length != batchSize){
      thisBatch = new int[batchSize];
    }

    //-----------------------------
    //Shuffled
    //-----------------------------
    if (sampleMethod == SamplingMethod.Shuffled) {
      if (shuffledArray == null) {
        shuffledArray = ArrayMath.range(0, this.dataDimension());
      }
      for(int i = 0; i  x = x + h*v
    for( int i = 0;i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy