All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.sentiment.RNNTrainOptions Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.sentiment;

import java.io.Serializable;

/**
 * Holder for the training-time options of the sentiment RNN.
 * <br>
 * Every option has a default and can be overridden from a command
 * line style argument array via {@link #setOption(String[], int)}.
 * Flag names are matched case-insensitively.  {@link #toString()}
 * produces a one-option-per-line dump of the current values.
 */
public class RNNTrainOptions implements Serializable {

  /** Value of the -batchSize flag. */
  public int batchSize = 27;

  /** Number of times through all the trees */
  public int epochs = 400;

  /** Value of the -debugOutputEpochs flag. */
  public int debugOutputEpochs = 8;

  /** Value of the -maxTrainTimeSeconds flag; the default is 24 hours expressed in seconds. */
  public int maxTrainTimeSeconds = 60 * 60 * 24;

  /** Value of the -learningRate flag. */
  public double learningRate = 0.01;

  /** Value of the -scalingForInit flag. */
  public double scalingForInit = 1.0;

  /**
   * Per-class weights parsed from the -classWeights flag, or
   * {@code null} if the flag was never given.
   */
  private double[] classWeights = null;

  /**
   * The classWeights can be passed in as a comma separated list of
   * weights using the -classWeights flag.  If the classWeights are
   * not specified, the value is assumed to be 1.0.  classWeights only
   * apply at train time; we do not weight the classes at all during
   * test time.
   *
   * @param i index of the class whose weight is wanted
   * @return 1.0 if no weights were specified, otherwise the i-th weight
   * @throws ArrayIndexOutOfBoundsException if weights were specified
   *         but {@code i} is outside the specified list
   */
  public double getClassWeight(int i) {
    if (classWeights == null) {
      return 1.0;
    }
    return classWeights[i];
  }

  /** Regularization cost for the transform matrix  */
  public double regTransformMatrix = 0.001;

  /** Regularization cost for the classification matrices */
  public double regClassification = 0.0001;

  /** Regularization cost for the word vectors */
  public double regWordVector = 0.0001;

  /**
   * The value to set the learning rate for each parameter when initializing adagrad.
   */
  public double initialAdagradWeight = 0.0;

  /**
   * How many epochs between resets of the adagrad learning rates.
   * Set to 0 to never reset.
   */
  public int adagradResetFrequency = 1;

  /** Regularization cost for the transform tensor  */
  public double regTransformTensor = 0.001;

  /**
   * Shuffle matrices when training.  Usually should be true.  Set to
   * false to compare training across different implementations, such
   * as with the original Matlab version
   */
  public boolean shuffleMatrices = true;

  /**
   * If set, the initial matrices are logged to this location as a single file
   * using SentimentModel.toString()
   */
  public String initialMatrixLogPath = null;

  /** Value of the -nThreads (or -numThreads) flag. */
  public int nThreads = 1;

  /**
   * Renders every option as a {@code name=value} line under a
   * "TRAIN OPTIONS" header.  Class weights are written as a comma
   * separated list, or as {@code null} when none were specified.
   */
  @Override
  public String toString() {
    StringBuilder result = new StringBuilder();
    result.append("TRAIN OPTIONS\n");
    result.append("batchSize=").append(batchSize).append('\n');
    result.append("epochs=").append(epochs).append('\n');
    result.append("debugOutputEpochs=").append(debugOutputEpochs).append('\n');
    result.append("maxTrainTimeSeconds=").append(maxTrainTimeSeconds).append('\n');
    result.append("learningRate=").append(learningRate).append('\n');
    result.append("scalingForInit=").append(scalingForInit).append('\n');
    if (classWeights == null) {
      result.append("classWeights=null\n");
    } else {
      result.append("classWeights=");
      // Loop from 0 so a zero-length array (e.g. from "-classWeights ,")
      // prints an empty list instead of throwing on classWeights[0].
      for (int i = 0; i < classWeights.length; ++i) {
        if (i > 0) {
          result.append(',');
        }
        result.append(classWeights[i]);
      }
      result.append('\n');
    }
    result.append("regTransformMatrix=").append(regTransformMatrix).append('\n');
    result.append("regTransformTensor=").append(regTransformTensor).append('\n');
    result.append("regClassification=").append(regClassification).append('\n');
    result.append("regWordVector=").append(regWordVector).append('\n');
    result.append("initialAdagradWeight=").append(initialAdagradWeight).append('\n');
    result.append("adagradResetFrequency=").append(adagradResetFrequency).append('\n');
    result.append("shuffleMatrices=").append(shuffleMatrices).append('\n');
    result.append("initialMatrixLogPath=").append(initialMatrixLogPath).append('\n');
    result.append("nThreads=").append(nThreads).append('\n');
    return result.toString();
  }

  /**
   * Returns the value that follows the flag at {@code args[argIndex]}.
   *
   * @throws IllegalArgumentException if the flag is the last element of
   *         {@code args} and therefore has no value
   */
  private static String optionValue(String[] args, int argIndex) {
    if (argIndex + 1 >= args.length) {
      throw new IllegalArgumentException("Missing value for option " + args[argIndex]);
    }
    return args[argIndex + 1];
  }

  /**
   * Tries to interpret {@code args[argIndex]} as one of the training
   * flags and, if it is one, updates the corresponding option.
   * <br>
   * Flags are matched case-insensitively.  -shuffleMatrices and
   * -noShuffleMatrices take no value; every other recognized flag
   * consumes the following argument as its value.
   *
   * @param args the full argument array
   * @param argIndex index of the flag to examine
   * @return the index of the first argument not consumed: {@code argIndex}
   *         itself if the flag is not recognized, {@code argIndex + 1} for
   *         a flag without a value, {@code argIndex + 2} otherwise
   * @throws IllegalArgumentException if a recognized flag requires a value
   *         but none follows it
   * @throws NumberFormatException if a numeric value cannot be parsed
   */
  public int setOption(String[] args, int argIndex) {
    final String flag = args[argIndex];
    if (flag.equalsIgnoreCase("-batchSize")) {
      batchSize = Integer.parseInt(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-epochs")) {
      epochs = Integer.parseInt(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-debugOutputEpochs")) {
      debugOutputEpochs = Integer.parseInt(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-maxTrainTimeSeconds")) {
      maxTrainTimeSeconds = Integer.parseInt(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-learningRate")) {
      learningRate = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-scalingForInit")) {
      scalingForInit = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-regTransformMatrix")) {
      regTransformMatrix = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-regTransformTensor")) {
      regTransformTensor = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-regClassification")) {
      regClassification = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-regWordVector")) {
      regWordVector = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-initialAdagradWeight")) {
      initialAdagradWeight = Double.parseDouble(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-adagradResetFrequency")) {
      adagradResetFrequency = Integer.parseInt(optionValue(args, argIndex));
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-classWeights")) {
      // Comma separated list of doubles, e.g. "0.5,1.0,2.0".
      String[] pieces = optionValue(args, argIndex).split(",");
      classWeights = new double[pieces.length];
      for (int i = 0; i < pieces.length; ++i) {
        classWeights[i] = Double.parseDouble(pieces[i]);
      }
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-shuffleMatrices")) {
      shuffleMatrices = true;
      return argIndex + 1;
    } else if (flag.equalsIgnoreCase("-noShuffleMatrices")) {
      shuffleMatrices = false;
      return argIndex + 1;
    } else if (flag.equalsIgnoreCase("-initialMatrixLogPath")) {
      initialMatrixLogPath = optionValue(args, argIndex);
      return argIndex + 2;
    } else if (flag.equalsIgnoreCase("-nThreads") || flag.equalsIgnoreCase("-numThreads")) {
      nThreads = Integer.parseInt(optionValue(args, argIndex));
      return argIndex + 2;
    } else {
      // Not one of ours: leave the index untouched so the caller can tell.
      return argIndex;
    }
  }

  private static final long serialVersionUID = 1;
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy