All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.optimization.AbstractCachingDiffFunction Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

There is a newer version: 3.9.2
Show newest version
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.util.logging.Redwood;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import java.util.Random;
import java.util.Set;

/**
 * A differentiable function that caches the last evaluation of its value and
 * derivative.
 *
 * <p>Subclasses implement {@link #calculate(double[])}, which must store the function
 * value in {@link #value} and the gradient in {@link #derivative}. Repeated calls to
 * {@link #valueAt(double[])} / {@link #derivativeAt(double[])} at the same point reuse the
 * cached result instead of recomputing it.
 *
 * <p>The {@code derivative} array returned by {@link #derivativeAt(double[])} is the shared
 * internal array and is overwritten by the next evaluation at a different point; copy it if
 * it must be kept. No synchronization is performed by this class.
 *
 * @author Dan Klein
 */
public abstract class AbstractCachingDiffFunction implements DiffFunction, HasInitial  {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractCachingDiffFunction.class);

  /** The point at which {@link #value} and {@link #derivative} were last computed. */
  private double[] lastX; // = null;
  /** Whether {@link #lastX} describes a completed, still-valid evaluation. */
  private boolean cacheValid; // = false;
  /** Number of times {@link #calculate(double[])} has been started (not read in this class). */
  private int fEvaluations; // = 0;
  /** Gradient at the last evaluated point; filled in by {@link #calculate(double[])}. */
  protected double[] derivative; // = null;
  /** Function value at the last evaluated point; filled in by {@link #calculate(double[])}. */
  protected double value; // = 0.0;
  /** Fixed-seed generator, so gradient checks and random initial points are reproducible. */
  private final Random generator = new Random(2147483647L);

  /**
   * Runs {@link #gradientCheck(int, int, double[])} at {@link #initial()} with 100
   * systematic and 50 random coordinate checks.
   *
   * @return true iff no checked partial derivative disagreed with its finite-difference estimate
   */
  public boolean gradientCheck() {
    return gradientCheck(100, 50, initial());
  }

  /**
   * Compares the analytic derivative at {@code x} with the central finite-difference estimate
   * {@code (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)} (eps = 1e-5) for a subset of
   * coordinates {@code i}.
   *
   * <p>A coordinate fails when the absolute difference exceeds 0.01 <em>and</em> the
   * difference relative to the smaller of the two magnitudes exceeds 0.1. One line per checked
   * coordinate is printed to {@code System.err}; failing indices are summarized via the logger.
   *
   * <p>{@code x} is perturbed in place during the check but every coordinate is restored,
   * even if evaluating the function throws.
   *
   * @param numOfChecks controls the systematic checks: every {@code x.length / numOfChecks}-th
   *     index, indices {@code 0..numOfChecks-1}, and the last {@code numOfChecks-1} indices.
   *     Values {@code <= 0} disable the systematic checks.
   * @param numOfRandomChecks number of uniformly random indices to check (duplicates collapse);
   *     values {@code <= 0} disable random checks
   * @param x the point at which to check the gradient
   * @return true iff no checked coordinate failed
   */
  public boolean gradientCheck(int numOfChecks, int numOfRandomChecks, double[] x) {
    final double epsilon = 1e-5;
    final double diffThreshold = 0.01;
    final double diffPctThreshold = 0.1;
    final double twoEpsilon = epsilon * 2;
    final int xLen = x.length;

    // Snapshot the analytic gradient: the valueAt() calls below overwrite the shared array.
    derivativeAt(x);
    double[] savedDeriv = Arrays.copyOf(derivative, xLen);

    boolean returnVal = true;
    List<Integer> badIndices = new ArrayList<>();
    for (int paramIndex : indicesToCheck(xLen, numOfChecks, numOfRandomChecks)) {
      double oldX = x[paramIndex];
      double plusVal;
      double minusVal;
      try {
        x[paramIndex] = oldX + epsilon;
        plusVal = valueAt(x);
        x[paramIndex] = oldX - epsilon;
        minusVal = valueAt(x);
      } finally {
        // Never leave the caller's array perturbed, even if valueAt throws.
        x[paramIndex] = oldX;
      }
      double appDeriv = (plusVal - minusVal) / twoEpsilon;
      double calcDeriv = savedDeriv[paramIndex];
      double diff = Math.abs(appDeriv - calcDeriv);
      // Infinite/NaN when one/both derivatives are 0; the absolute threshold then decides.
      double pct = diff / Math.min(Math.abs(appDeriv), Math.abs(calcDeriv));
      if (diff > diffThreshold && pct > diffPctThreshold) {
        System.err.printf("Grad fail at %2d, appGrad=%9.7f, calcGrad=%9.7f, diff=%9.7f, pct=%9.7f\n", paramIndex,appDeriv,calcDeriv,diff,pct);
        badIndices.add(paramIndex);
        returnVal = false;
      } else {
        System.err.printf("Grad good at %2d, appGrad=%9.7f, calcGrad=%9.7f, diff=%9.7f, pct=%9.7f\n", paramIndex,appDeriv,calcDeriv,diff,pct);
      }
    }
    if (returnVal) {
      System.err.printf("ALL gradients passed. Yay!\n");
    } else {
      log.info(describeBadIndices(badIndices, 10));
    }
    return returnVal;
  }

  /**
   * Chooses the coordinates examined by {@link #gradientCheck(int, int, double[])}, in
   * ascending order without duplicates. Random indices are drawn after the systematic ones.
   */
  private Set<Integer> indicesToCheck(int xLen, int numOfChecks, int numOfRandomChecks) {
    Set<Integer> indices = new TreeSet<>();
    if (xLen == 0) {
      return indices; // nothing to check, and Random.nextInt(0) would throw
    }
    if (numOfChecks > 0) {
      int interval = Math.max(1, xLen / numOfChecks);
      for (int paramIndex = 0; paramIndex < xLen; paramIndex += interval) {
        indices.add(paramIndex);
      }
      for (int i = xLen - 1; i >= 0 && i > xLen - numOfChecks; i--) {
        indices.add(i);
      }
      for (int i = 1; i < xLen && i < numOfChecks; i++) {
        indices.add(i);
      }
    }
    for (int i = 0; i < numOfRandomChecks; i++) {
      indices.add(generator.nextInt(xLen));
    }
    return indices;
  }

  /** Formats at most {@code limit} bad indices on one line, appending " (...)" if any were omitted. */
  private static String describeBadIndices(List<Integer> badIndices, int limit) {
    StringBuilder sb = new StringBuilder("Bad indices:");
    int shown = Math.min(limit, badIndices.size());
    for (int i = 0; i < shown; i++) {
      sb.append(' ').append(badIndices.get(i));
    }
    if (badIndices.size() > limit) {
      sb.append(" (...)");
    }
    return sb.toString();
  }

  /**
   * Calculate the value at x and the derivative
   * and save them in the respective fields.
   *
   * @param x The point at which to calculate the function
   */
  protected abstract void calculate(double[] x);

  /**
   * Invalidates the cached evaluation so that the next {@link #valueAt(double[])} /
   * {@link #derivativeAt(double[])} call recomputes, without reallocating any arrays.
   */
  protected void clearCache() {
    cacheValid = false;
  }

  /**
   * Returns the all-zeros vector of length {@link #domainDimension()}.
   *
   * @return a freshly allocated zero vector
   */
  @Override
  public double[] initial() {
    // Java zero-fills new arrays, so no explicit fill is needed.
    return new double[domainDimension()];
  }

  /**
   * Returns a vector of length {@link #domainDimension()} whose entries are drawn from
   * {@link Random#nextDouble()} (i.e. in [0, 1)) using this object's fixed-seed generator.
   *
   * @return a freshly allocated random vector
   */
  public double[] randomInitial() {
    double[] initial = new double[domainDimension()];
    for (int i = 0; i < initial.length; i++) {
      initial[i] = generator.nextDouble();
    }
    return initial;
  }

  /**
   * Copies all of {@code orig} into the start of {@code copy}.
   *
   * @param copy destination; must be at least as long as {@code orig}
   * @param orig source array
   */
  protected static void copy(double[] copy, double[] orig) {
    System.arraycopy(orig, 0, copy, 0, orig.length);
  }

  /**
   * Makes sure {@link #value} and {@link #derivative} correspond to {@code x}, calling
   * {@link #calculate(double[])} only if {@code x} differs from the cached point or the
   * cache has been cleared.
   *
   * @param x the point to evaluate; it is copied, so the caller may mutate it afterwards
   */
  public void ensure(double[] x) {
    if (cacheValid && Arrays.equals(x, lastX)) {
      return;
    }
    if (lastX == null) {
      lastX = new double[domainDimension()];
    }
    if (derivative == null) {
      derivative = new double[domainDimension()];
    }
    copy(lastX, x);
    fEvaluations += 1;
    // Mark valid *before* calculate() so a clearCache() issued from inside calculate() still
    // wins; invalidate again if calculate() fails so a partial result is never served.
    cacheValid = true;
    boolean completed = false;
    try {
      calculate(x);
      completed = true;
    } finally {
      if (!completed) {
        cacheValid = false;
      }
    }
  }

  /**
   * Returns the function value at {@code x}, recomputing only if needed.
   *
   * @param x the point to evaluate
   * @return the function value at {@code x}
   */
  @Override
  public double valueAt(double[] x) {
    ensure(x);
    return value;
  }

  /**
   * Returns the gradient at {@code x}, recomputing only if needed.
   *
   * @param x the point to evaluate
   * @return the internal derivative array (not a copy); overwritten by later evaluations
   */
  @Override
  public double[] derivativeAt(double[] x) {
    ensure(x);
    return derivative;
  }

  /** @return the value computed by the most recent evaluation */
  public double lastValue() {
    return value;
  }

  /** @return the internal derivative array from the most recent evaluation (not a copy) */
  public double[] getDerivative() {
    return derivative;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy