edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-corenlp Show documentation
Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.
package edu.stanford.nlp.optimization;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
/**
*
* @author Alex Kleeman
*/
public abstract class AbstractStochasticCachingDiffFunction extends AbstractCachingDiffFunction {
/** A logger for this class */
private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractStochasticCachingDiffFunction.class);
/**
 * Reset to false by incrementBatch. While false (and the batch cursor has moved), the cached
 * x, v and batch are compared against the new request so an identical evaluation can be skipped.
 */
public boolean hasNewVals = true;
/** When true, the next evaluation reuses the previous batch of indices; reset by incrementBatch. */
public boolean recalculatePrevBatch = false;
/** When true, the next evaluation returns the cached results without recomputing; reset by incrementBatch. */
public boolean returnPreviousValues = false;
/** Batch size used for the most recent stochastic calculation. */
protected int lastBatchSize = 0;
/** Copy of the data indices used for the most recent stochastic calculation. */
protected int[] lastBatch = null;
/** Data indices for the current batch; (re)allocated by getBatch when the batch size changes. */
protected int[] thisBatch = null;
/** Copy of x at the most recent stochastic calculation; element 0 is set to NaN by clearCache. */
protected double[] lastXBatch = null;
/** Copy of v at the most recent stochastic calculation; element 0 is set to NaN by clearCache. */
protected double[] lastVBatch = null;
/** Compared with curElement before cached results are reused. NOTE(review): its update site is not visible here. */
protected int lastElement = 0;
/** Approximation to the Hessian-vector product H.v filled in by calculateStochastic. */
protected double[] HdotV = null;
/** Scratch gradient, presumably for finite-difference H.v at x + h*v. TODO confirm against the full file. */
protected double[] gradPerturbed = null;
/** Scratch point, presumably x + h*v for finite differences. TODO confirm against the full file. */
protected double[] xPerturbed = null;
/** Cursor into the data marking where the next batch starts; moved by incrementBatch/decrementBatch. */
protected int curElement = 0;
/** Indices into the data set used by index-list sampling; shuffled with randGenerator in getBatch. */
protected List<Integer> allIndices = null;
/** Seeded with the constant 1 so that batch sampling is reproducible from run to run. */
protected Random randGenerator = new Random(1);
/** When true, stochastic results are rescaled by dataDimension()/batchSize to match the full function in expectation. */
protected boolean scaleUp = false;
/** Index array built from ArrayMath.range(0, dataDimension()) for the Shuffled sampling method. */
private int[] shuffledArray = null;
/**
 * Strategies for choosing which data indices form each stochastic batch.
 * The constant order is significant only to the extent callers rely on ordinal().
 */
public enum SamplingMethod {
  /** No strategy chosen. NOTE(review): getBatch appears to throw IllegalStateException for an unsupported choice. */
  NoneSpecified,
  /** Samples indices uniformly from all possible indices, with replacement. */
  RandomWithReplacement,
  /** Samples indices without replacement, restarting after each full pass over the data. */
  RandomWithoutReplacement,
  /** Walks the indices in order. */
  Ordered,
  /** Draws from an index array over 0..dataDimension()-1 (presumably shuffled; see getBatch). */
  Shuffled
}
/**
 * Advances the sampling random number generator by drawing and discarding {@code numTimes}
 * values from {@code randGenerator.nextInt(dataDimension())}.
 * The generator ends up in the same state as if that many index draws had been made.
 * A non-positive {@code numTimes} draws nothing.
 *
 * @param numTimes how many draws to discard
 */
public void incrementRandom(int numTimes) {
  log.info("incrementing random "+numTimes+" times.");
  int remaining = numTimes;
  while (remaining-- > 0) {
    // The bound is re-read on every draw, exactly as before.
    randGenerator.nextInt(this.dataDimension());
  }
}
/**
 * Enables or disables rescaling of stochastic evaluations by dataDimension()/batchSize.
 * When enabled, a batch evaluation matches the full function in expectation.
 *
 * @param toScaleUp true to rescale batch results, false to leave them unscaled
 */
public void scaleUp(boolean toScaleUp) {
  this.scaleUp = toScaleUp;
}
/**
 * Strategy for the stochastic calculations; defaults to ExternalFiniteDifference.
 * NOTE(review): the code that reads this field is not visible in this excerpt.
 */
public StochasticCalculateMethods method = StochasticCalculateMethods.ExternalFiniteDifference;
/** Batch sampling strategy consulted by getBatch; defaults to RandomWithoutReplacement. */
public SamplingMethod sampleMethod = SamplingMethod.RandomWithoutReplacement;
/**
* Fixed step size for the finite difference approximation.
* A few tests run with the SMD minimizer suggested that step sizes of 1e-4 to 1e-3 work best.
* (akleeman)
*/
protected double finiteDifferenceStepSize = 1e-4;
/**
* Computes a stochastic approximation to the value and derivative of the function over one
* batch of the data.
* Implementations must store the derivative approximation in the array {@code derivative} and
* the value approximation in {@code value}.
* They must also store the approximation to the Hessian-vector product H.v in the array
* {@code HdotV}.
* The Hessian-vector product is used primarily by the Stochastic Meta Descent optimization
* routine {@code SMDMinimizer}.
*
* <p>Important: the stochastic approximation must be such that the sum of the stochastic
* calculations over all batches of the data equals the full calculation.
* For example, for a data set of size 100, the sum of the gradients for batches
* 1-10, 11-20, 21-30, ..., 91-100 must equal the gradient for the full calculation (at the
* very least in expectation).
* Be sure to take the priors into account.
*
* @param x the point at which to evaluate
* @param v the vector for the Hessian-vector product H.v; the caller only caches it when it
*     is non-null, so implementations may receive {@code null} (confirm against callers)
* @param batch the indices of the data to use in the calculation; this array is computed
*     internally by the abstract class, so implementations only need to read it, not generate it
*/
public abstract void calculateStochastic(double[] x, double[] v, int[] batch);
/**
* Returns the size of the data used by the function.
* This is the number of data items that batch indices range over, i.e. 0..dataDimension()-1.
* It is also used as the numerator when scaling batch results up to the full data set.
*
* @return the number of data items
*/
public abstract int dataDimension();
/**
 * Invalidates cached results without reallocating any arrays.
 * After delegating to the superclass, the first entry of the stored x and v batch points is
 * overwritten with NaN.
 * A later Arrays.equals comparison against a finite point therefore will not match.
 */
@Override
protected void clearCache() {
  super.clearCache();
  invalidateFirst(lastXBatch);
  invalidateFirst(lastVBatch);
}

/** Marks {@code cached} as stale by setting its first entry to NaN; a null array is ignored. */
private static void invalidateFirst(double[] cached) {
  if (cached != null) {
    cached[0] = Double.NaN;
  }
}
/**
 * Returns a fresh all-zero starting point of length domainDimension().
 * Java zero-initializes new arrays, so no explicit fill is needed.
 *
 * @return a new zero vector
 */
@Override
public double[] initial() {
  return new double[domainDimension()];
}
/**
 * Moves the batch cursor {@code curElement} back by {@code batchSize}, never going below zero.
 * Decrementing and then re-evaluating recalculates over the previous batch.
 *
 * @param batchSize number of elements to step back
 */
public void decrementBatch(int batchSize) {
  curElement = Math.max(0, curElement - batchSize);
}
/**
 * Advances the batch cursor {@code curElement} by {@code batchSize} to mark the next batch.
 * It also resets the per-batch flags {@code hasNewVals}, {@code recalculatePrevBatch} and
 * {@code returnPreviousValues} to {@code false}.
 *
 * @param batchSize number of elements to advance
 */
public void incrementBatch(int batchSize) {
  // Clear every per-batch flag before moving the cursor on.
  hasNewVals = false;
  recalculatePrevBatch = false;
  returnPreviousValues = false;
  curElement += batchSize;
}
/**
* getBatch is used to generate the next sequence of indices to be passed to the actual function.
* Depending on the current sample method this is done by:
* Ordered - simply generates the indices 1,2,3,4,....
* RandomWithReplacement - Samples uniformly from the set of possible indices
* RandomWithoutReplacement - Samples from the set of possible indices removing each used index, then restarting
* after each pass
*/
//private int numCalls = 0;
protected void getBatch(int batchSize){
// if (numCalls == 0) {
// for (int i = 0; i < 1538*\15; i++) {
// randGenerator.nextInt(this.dataDimension());
// }
// }
// numCalls++;
if (thisBatch == null || thisBatch.length != batchSize){
thisBatch = new int[batchSize];
}
//-----------------------------
//Shuffled
//-----------------------------
if (sampleMethod == SamplingMethod.Shuffled) {
if (shuffledArray == null) {
shuffledArray = ArrayMath.range(0, this.dataDimension());
}
for(int i = 0; i();
for(int i=0;i this.dataDimension()){
Collections.shuffle(Collections.singletonList(allIndices),randGenerator); //Shuffle if we got to the end of the list
}
//watch out for overflow
curElement = (curElement + batchSize) % allIndices.size(); //Rollover
} else {
throw new IllegalStateException(" NO SAMPLING METHOD SELECTED"); } } void stochasticEnsure(double[] x, double[] v, int batchSize){ if (lastXBatch="=" null) { lastXBatch="new" double[domainDimension()]; log.info("Setting previous position (x)."); } if (lastVBatch="=" null) { lastVBatch="new" double[domainDimension()]; log.info("Setting previous gain (v)"); } if (derivative="=" null) { derivative="new" double[domainDimension()]; log.info("Setting Derivative."); } if (HdotV="=" null) { HdotV="new" double[domainDimension()]; log.info("Setting HdotV."); } if (lastBatch="=" null){ lastBatch="new" int[batchSize]; log.info("Setting last batch"); } If we want to recalculate using the previous batch if(recalculatePrevBatch && batchSize="=" lastBatch.length){ thisBatch="lastBatch;" }else{ * If we dont want to calculate anything we just want the last values. This is especially usefull if you know the values have already been calculated, and you don't want to waste time comparing the entire array of x's and v's. * if(returnPreviousValues){ returnPreviousValues="false;" return; } If we dont know there are new values, and we havnt asked to recalculate then compare to avoid needing to recalculate if( !hasNewVals && lastElement!="curElement" ){ if ((lastBatchSize="=batchSize)" && Arrays.equals(x, lastXBatch) && Arrays.equals(v,lastVBatch) && Arrays.equals(thisBatch,lastBatch)) { return; } } getBatch(batchSize); } copy(lastXBatch,x); if(lastBatch.length !="batchSize){" lastBatch="new" int[batchSize]; } System.arraycopy(thisBatch,0,lastBatch,0,thisBatch.length); if(v!="null){copy(lastVBatch,v);}" lastBatchSize="batchSize;" calculateStochastic(x,v,thisBatch); This is used to make the function evaluation equal the true function in expectation. if(scaleUp){ double ratio="(" (double) this.dataDimension()) ( (double)batchSize) ; for(int i="0;i x = x + h*v
for( int i = 0;i