/*
 * edu.stanford.nlp.optimization.AbstractCachingDiffFunction (Maven / Gradle / Ivy artifact: stanford-parser)
 *
 * Stanford Parser processes raw text in English, Chinese, German, Arabic, and French,
 * and extracts constituency parse trees.
 */
package edu.stanford.nlp.optimization;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import java.util.Random;
import java.util.Set;
/**
 * A differentiable function that caches the last evaluation of its value and
 * derivative.
 *
 * <p>Subclasses implement {@link #calculate(double[])}, which is expected to set the
 * protected {@code value} field and fill the protected {@code derivative} array for the
 * given point. {@link #valueAt(double[])} and {@link #derivativeAt(double[])} only call
 * {@code calculate} when the argument differs element-wise from the last point that was
 * successfully evaluated (or after {@link #clearCache()}).
 *
 * @author Dan Klein
 */
public abstract class AbstractCachingDiffFunction implements DiffFunction, HasInitial {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractCachingDiffFunction.class);

  /** Maximum number of failing coordinates listed in the summary of a failed gradient check. */
  private static final int MAX_BAD_INDICES_REPORTED = 10;

  /** Copy of the last point passed to calculate; allocated lazily and reused afterwards. */
  private double[] lastX; // = null;
  /**
   * True only when {@code value}/{@code derivative} hold the result of a calculate call
   * that completed for the point stored in {@code lastX}.
   */
  private boolean cacheValid; // = false;
  /** Number of times calculate has been invoked through ensure. */
  private int fEvaluations; // = 0;
  protected double[] derivative; // = null;
  protected double value; // = 0.0;
  /** Fixed seed so random gradient checks and randomInitial() are reproducible across runs. */
  private final Random generator = new Random(2147483647L);

  /**
   * Runs {@link #gradientCheck(int, int, double[])} at the {@link #initial()} point with
   * 100 structured checks and 50 random checks.
   *
   * @return true if every probed coordinate passed
   */
  public boolean gradientCheck() {
    return gradientCheck(100, 50, initial());
  }

  /**
   * Compares the analytic derivative at {@code x} with a central finite-difference
   * estimate on a subset of coordinates. The coordinates are: evenly spaced ones, the
   * first and last {@code numOfChecks - 1}, and {@code numOfRandomChecks} random ones.
   * A coordinate fails only when both the absolute difference exceeds 0.01 and the
   * relative difference exceeds 0.1. One line per probed coordinate is printed to
   * standard error.
   *
   * <p>{@code x} is perturbed in place during the check and is always restored before
   * this method returns or throws.
   *
   * @param numOfChecks controls the spacing and the number of leading/trailing coordinates probed
   * @param numOfRandomChecks number of uniformly random coordinates to probe (may repeat)
   * @param x the point at which to check the gradient
   * @return true if every probed coordinate passed
   */
  public boolean gradientCheck(int numOfChecks, int numOfRandomChecks, double[] x) {
    double epsilon = 1e-5;
    double diffThreshold = 0.01;
    double diffPctThreshold = 0.1;
    double twoEpsilon = epsilon * 2;
    int xLen = x.length;

    derivativeAt(x);
    // Snapshot the analytic gradient: the valueAt calls below re-run calculate,
    // which may overwrite the shared derivative array.
    double[] savedDeriv = new double[xLen];
    System.arraycopy(derivative, 0, savedDeriv, 0, derivative.length);

    Set<Integer> indicesToCheck = selectIndicesToCheck(xLen, numOfChecks, numOfRandomChecks);

    boolean returnVal = true;
    List<Integer> badIndices = new ArrayList<>();
    for (int paramIndex : indicesToCheck) {
      double oldX = x[paramIndex];
      double plusVal;
      double minusVal;
      try {
        x[paramIndex] = oldX + epsilon;
        plusVal = valueAt(x);
        x[paramIndex] = oldX - epsilon;
        minusVal = valueAt(x);
      } finally {
        // Never leave the caller's array perturbed, even if valueAt throws.
        x[paramIndex] = oldX;
      }
      double appDeriv = (plusVal - minusVal) / twoEpsilon;
      double calcDeriv = savedDeriv[paramIndex];
      double diff = Math.abs(appDeriv - calcDeriv);
      // Relative error against the smaller magnitude; Infinity/NaN when one side is 0.
      double pct = diff / Math.min(Math.abs(appDeriv), Math.abs(calcDeriv));
      if (diff > diffThreshold && pct > diffPctThreshold) {
        System.err.printf("Grad fail at %2d, appGrad=%9.7f, calcGrad=%9.7f, diff=%9.7f, pct=%9.7f\n", paramIndex,appDeriv,calcDeriv,diff,pct);
        badIndices.add(paramIndex);
        returnVal = false;
      } else {
        System.err.printf("Grad good at %2d, appGrad=%9.7f, calcGrad=%9.7f, diff=%9.7f, pct=%9.7f\n", paramIndex,appDeriv,calcDeriv,diff,pct);
      }
    }

    if (returnVal) {
      System.err.printf("ALL gradients passed. Yay!\n");
    } else {
      // Emit the summary as a single log line.
      StringBuilder summary = new StringBuilder("Bad indices:");
      int shown = Math.min(badIndices.size(), MAX_BAD_INDICES_REPORTED);
      for (int i = 0; i < shown; ++i) {
        summary.append(' ').append(badIndices.get(i));
      }
      // Only mark truncation when some bad indices were actually omitted.
      if (badIndices.size() > MAX_BAD_INDICES_REPORTED) {
        summary.append(" (...)");
      }
      log.info(summary.toString());
    }
    return returnVal;
  }

  /**
   * Chooses the sorted, de-duplicated coordinates probed by gradientCheck.
   *
   * @param xLen length of the point being checked
   * @param numOfChecks spacing divisor and count of leading/trailing coordinates
   * @param numOfRandomChecks number of random draws from [0, xLen)
   * @return the coordinates to probe, in ascending order
   */
  private Set<Integer> selectIndicesToCheck(int xLen, int numOfChecks, int numOfRandomChecks) {
    Set<Integer> indices = new TreeSet<>();
    // Evenly spaced coordinates; guard the divisor so numOfChecks == 0 cannot divide by zero.
    int interval = Math.max(1, xLen / Math.max(1, numOfChecks));
    for (int paramIndex = 0; paramIndex < xLen; paramIndex += interval) {
      indices.add(paramIndex);
    }
    // The last numOfChecks - 1 coordinates.
    for (int i = xLen - 1; i >= 0 && i > xLen - numOfChecks; i--) {
      indices.add(i);
    }
    // Coordinates 1 .. numOfChecks - 1.
    for (int i = 1; i < xLen && i < numOfChecks; i++) {
      indices.add(i);
    }
    // Random coordinates; Random.nextInt(0) throws, so skip them for an empty vector.
    if (xLen > 0) {
      for (int i = 0; i < numOfRandomChecks; i++) {
        indices.add(generator.nextInt(xLen));
      }
    }
    return indices;
  }

  /**
   * Calculate the value at x and the derivative
   * and save them in the respective fields.
   *
   * @param x The point at which to calculate the function
   */
  protected abstract void calculate(double[] x);

  /**
   * Clears the cache in a way that doesn't require reallocation :-)
   * The next valueAt/derivativeAt call always re-runs calculate, whatever the argument.
   */
  protected void clearCache() {
    // A flag rather than a NaN sentinel in lastX: Arrays.equals(double[], double[])
    // treats NaN as equal to NaN, and a zero-length lastX has no slot to poison.
    cacheValid = false;
  }

  /**
   * Returns an all-zero starting point of length {@link #domainDimension()}.
   *
   * @return a new zero-filled array
   */
  @Override
  public double[] initial() {
    // Java zero-initializes new arrays, so no explicit fill is needed.
    return new double[domainDimension()];
  }

  /**
   * Returns a starting point of length {@link #domainDimension()} whose entries are
   * drawn from this object's fixed-seed generator, each in [0, 1).
   *
   * @return a new array of random values
   */
  public double[] randomInitial() {
    double[] initial = new double[domainDimension()];
    for (int i = 0; i < initial.length; i++) {
      initial[i] = generator.nextDouble();
    }
    return initial;
  }

  /**
   * Copies all of {@code orig} into the start of {@code copy}.
   *
   * @param copy destination; must be at least as long as {@code orig}
   * @param orig source array
   */
  protected static void copy(double[] copy, double[] orig) {
    System.arraycopy(orig, 0, copy, 0, orig.length);
  }

  /**
   * Makes sure {@code value} and {@code derivative} correspond to {@code x}. It calls
   * calculate only if the cache is invalid or {@code x} differs from the last evaluated point.
   *
   * @param x the point at which the function is needed
   */
  public void ensure(double[] x) {
    if (cacheValid && Arrays.equals(x, lastX)) {
      return;
    }
    if (lastX == null) {
      lastX = new double[domainDimension()];
    }
    if (derivative == null) {
      derivative = new double[domainDimension()];
    }
    // Invalidate first: if calculate throws, a later call with the same x must recompute
    // rather than return stale value/derivative.
    cacheValid = false;
    copy(lastX, x);
    fEvaluations += 1;
    calculate(x);
    cacheValid = true;
  }

  /**
   * Returns the function value at {@code x}, reusing the cached result when possible.
   *
   * @param x the point at which to evaluate
   * @return the function value at {@code x}
   */
  @Override
  public double valueAt(double[] x) {
    ensure(x);
    return value;
  }

  /**
   * Returns the derivative at {@code x}, reusing the cached result when possible.
   * The returned array is this object's internal buffer (not a copy) and is
   * overwritten by the next evaluation at a different point.
   *
   * @param x the point at which to evaluate
   * @return the internal derivative array
   */
  @Override
  public double[] derivativeAt(double[] x) {
    ensure(x);
    return derivative;
  }

  /** @return the value stored by the most recent calculate call (0.0 before any evaluation) */
  public double lastValue() {
    return value;
  }

  /** @return the internal derivative array from the most recent evaluation, or null before any */
  public double[] getDerivative() {
    return derivative;
  }
}