All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.classify.SVMLightClassifierFactory Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.classify;

import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.stats.*;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.LineSearcher;

import java.io.*;
import java.text.NumberFormat;
import java.util.*;
import java.util.function.Function;
import java.util.regex.Pattern;


import edu.stanford.nlp.util.logging.Redwood;

/**
 * This class is meant for training SVMs ({@link SVMLightClassifier}s).  It actually calls SVM Light, or
 * SVM Struct for multiclass SVMs, or SVM perf is the option is enabled, on the command line, reads in the produced
 * model file and creates a Linear Classifier.  A Platt model is also trained
 * (unless otherwise specified) on top of the SVM so that probabilities can
 * be produced. For multiclass classifier, you have to set C using setC otherwise the code will not run (by sonalg).
 *
 * @author Jenny Finkel
 * @author Aria Haghighi
 * @author Sarah Spikes ([email protected]) (templatization)
 */

public class SVMLightClassifierFactory implements ClassifierFactory>{ //extends AbstractLinearClassifierFactory {

  /**
   *
   */
  private static final long serialVersionUID = 1L;

  /**
   * C can be tuned using held-out set or cross-validation
   * For binary SVM, if C=0, svmlight uses default of 1/(avg x*x) 
   */
  protected double C = -1.0;
  private boolean useSigmoid = false;
  protected boolean verbose = true;
  private String svmLightLearn = "/u/nlp/packages/svm_light/svm_learn";
  private String svmStructLearn = "/u/nlp/packages/svm_multiclass/svm_multiclass_learn";
  private String svmPerfLearn = "/u/nlp/packages/svm_perf/svm_perf_learn";
  private String svmLightClassify = "/u/nlp/packages/svm_light/svm_classify";
  private String svmStructClassify = "/u/nlp/packages/svm_multiclass/svm_multiclass_classify";
  private String svmPerfClassify = "/u/nlp/packages/svm_perf/svm_perf_classify";

  private boolean useAlphaFile = false;
  protected File alphaFile;
  private boolean deleteTempFilesOnExit = true;
  private int svmLightVerbosity = 0;  // not verbose
  private boolean doEval = false;
  private boolean useSVMPerf = false;

  final static Redwood.RedwoodChannels logger = Redwood.channels(SVMLightClassifierFactory.class);

  /** @param svmLightLearn is the fullPathname of the training program of svmLight with default value "/u/nlp/packages/svm_light/svm_learn"
   * @param svmStructLearn is the fullPathname of the training program of svmMultiClass with default value "/u/nlp/packages/svm_multiclass/svm_multiclass_learn"
   * @param svmPerfLearn is the fullPathname of the training program of svmMultiClass with default value "/u/nlp/packages/svm_perf/svm_perf_learn"
   */
  public SVMLightClassifierFactory(String svmLightLearn, String svmStructLearn, String svmPerfLearn){
    this.svmLightLearn = svmLightLearn;
    this.svmStructLearn = svmStructLearn;
    this.svmPerfLearn = svmPerfLearn;
  }

  public SVMLightClassifierFactory(){
  }

  public SVMLightClassifierFactory(boolean useSVMPerf){
    this.useSVMPerf = useSVMPerf;
  }

  /**
   * Set the C parameter (for the slack variables) for training the SVM.
   */
  public void setC(double C) {
    this.C = C;
  }

  /**
   * Get the C parameter (for the slack variables) for training the SVM.
   */

  public double getC() {
    return C;
  }

  /**
   * Specify whether or not to train an overlying platt (sigmoid)
   * model for producing meaningful probabilities.
   */
  public void setUseSigmoid(boolean useSigmoid) {
    this.useSigmoid = useSigmoid;
  }

  /**
   * Get whether or not to train an overlying platt (sigmoid)
   * model for producing meaningful probabilities.
   */
  public boolean getUseSigma() {
    return useSigmoid;
  }


  public boolean getDeleteTempFilesOnExitFlag() {
    return deleteTempFilesOnExit;
  }

  public void setDeleteTempFilesOnExitFlag(boolean deleteTempFilesOnExit) {
    this.deleteTempFilesOnExit = deleteTempFilesOnExit;
  }

  /**
   * Reads in a model file in svm light format.  It needs to know if its multiclass or not
   * because it affects the number of header lines.  Maybe there is another way to tell and we
   * can remove this flag?
   */
  private static Pair> readModel(File modelFile, boolean multiclass) {
    int modelLineCount = 0;
    try {

      int numLinesToSkip = multiclass ? 13 : 10;
      String stopToken   = "#";

      BufferedReader in = new BufferedReader(new FileReader(modelFile));

      for (int i=0; i < numLinesToSkip; i++) { 
        in.readLine();
        modelLineCount ++;
      }

      List>> supportVectors = new ArrayList<>();
      // Read Threshold
      String thresholdLine = in.readLine();
      modelLineCount ++;
      String[] pieces = thresholdLine.split("\\s+");
      double threshold = Double.parseDouble(pieces[0]);
      // Read Support Vectors
      while (in.ready()) {
        String svLine = in.readLine();
        modelLineCount ++;
        pieces = svLine.split("\\s+");
        // First Element is the alpha_i * y_i
        double  alpha = Double.parseDouble(pieces[0]);
        ClassicCounter supportVector  = new ClassicCounter<>();
        for (int i=1; i < pieces.length; ++i) {
          String piece = pieces[i];
          if (piece.equals(stopToken)) break;
          // Each in featureIndex:num class
          String[] indexNum = piece.split(":");
          String featureIndex = indexNum[0];
          // mihai: we may see "qid" as indexNum[0]. just skip this piece. this is the block id useful only for reranking, which we don't do here.
          if(! featureIndex.equals("qid")){
            double count = Double.parseDouble(indexNum[1]);
            supportVector.incrementCount(Integer.valueOf(featureIndex), count);
          }
        }
        supportVectors.add(new Pair<>(alpha, supportVector));
      }

      in.close();

      return new Pair<>(threshold, getWeights(supportVectors));
    }
    catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException("Error reading SVM model (line " + modelLineCount + " in file " + modelFile.getAbsolutePath() + ")");
    }
  }

  /**
   * Takes all the support vectors, and their corresponding alphas, and computes a weight
   * vector that can be used in a vanilla LinearClassifier.  This only works because
   * we are using a linear kernel.  The Counter is over the feature indices (+1 cos for
   * some reason svm_light is 1-indexed), not features.
   */
  private static ClassicCounter getWeights(List>> supportVectors) {
    ClassicCounter weights = new ClassicCounter<>();
    for (Pair> sv : supportVectors) {
      ClassicCounter c = new ClassicCounter<>(sv.second());
      Counters.multiplyInPlace(c, sv.first());
      Counters.addInPlace(weights, c);
    }
    return weights;
  }

  /**
   * Converts the weight Counter to be from indexed, svm_light format, to a format
   * we can use in our LinearClassifier.
   */
  private ClassicCounter> convertWeights(ClassicCounter weights, Index featureIndex, Index labelIndex, boolean multiclass) {
    return multiclass ? convertSVMStructWeights(weights, featureIndex, labelIndex) : convertSVMLightWeights(weights, featureIndex, labelIndex);
  }

  /**
   * Converts the svm_light weight Counter (which uses feature indices) into a weight Counter
   * using the actual features and labels.  Because this is svm_light, and not svm_struct, the
   * weights for the +1 class (which correspond to labelIndex.get(0)) and the -1 class
   * (which correspond to labelIndex.get(1)) are just the negation of one another.
   */
  private ClassicCounter> convertSVMLightWeights(ClassicCounter weights, Index featureIndex, Index labelIndex) {
    ClassicCounter> newWeights = new ClassicCounter<>();
    for (int i : weights.keySet()) {
      F f = featureIndex.get(i-1);
      double w = weights.getCount(i);
      // the first guy in the labelIndex was the +1 class and the second guy
      // was the -1 class
      newWeights.incrementCount(new Pair<>(f, labelIndex.get(0)),w);
      newWeights.incrementCount(new Pair<>(f, labelIndex.get(1)),-w);
    }
    return newWeights;
  }

  /**
   * Converts the svm_struct weight Counter (in which the weight for a feature/label pair
   * correspondes to ((labelIndex * numFeatures)+(featureIndex+1))) into a weight Counter
   * using the actual features and labels.
   */
  private ClassicCounter> convertSVMStructWeights(ClassicCounter weights, Index featureIndex, Index labelIndex) {
    // int numLabels = labelIndex.size();
    int numFeatures = featureIndex.size();
    ClassicCounter> newWeights = new ClassicCounter<>();
    for (int i : weights.keySet()) {
      L l = labelIndex.get((i-1) / numFeatures); // integer division on purpose
      F f = featureIndex.get((i-1) % numFeatures);
      double w = weights.getCount(i);
      newWeights.incrementCount(new Pair<>(f, l),w);
    }

    return newWeights;
  }

  /**
   * Builds a sigmoid model to turn the classifier outputs into probabilities.
   */
  private LinearClassifier fitSigmoid(SVMLightClassifier classifier, GeneralDataset dataset) {
    RVFDataset plattDataset = new RVFDataset<>();
    for (int i = 0; i < dataset.size(); i++) {
      RVFDatum d = dataset.getRVFDatum(i);
      Counter scores = classifier.scoresOf((Datum)d);
      scores.incrementCount(null);
      plattDataset.add(new RVFDatum<>(scores, d.label()));
    }
    LinearClassifierFactory factory = new LinearClassifierFactory<>();
    factory.setPrior(new LogPrior(LogPrior.LogPriorType.NULL));
    return factory.trainClassifier(plattDataset);
  }

  /**
   * This method will cross validate on the given data and number of folds
   * to find the optimal C.  The scorer is how you determine what to
   * optimize for (F-score, accuracy, etc).  The C is then saved, so that
   * if you train a classifier after calling this method, that C will be used.
   */
  public void crossValidateSetC(GeneralDataset dataset, int numFolds, final Scorer scorer, LineSearcher minimizer) {
    System.out.println("in Cross Validate");

    useAlphaFile = true;
    boolean oldUseSigmoid = useSigmoid;
    useSigmoid = false;

    final CrossValidator crossValidator = new CrossValidator<>(dataset, numFolds);
    final Function,GeneralDataset,CrossValidator.SavedState>,Double> score =
        fold -> {
          GeneralDataset trainSet = fold.first();
          GeneralDataset devSet = fold.second();
          alphaFile = (File)fold.third().state;
          //train(trainSet,true,true);
          SVMLightClassifier classifier = trainClassifierBasic(trainSet);
          fold.third().state = alphaFile;
          return scorer.score(classifier,devSet);
        };

    Function negativeScorer =
        cToTry -> {
          C = cToTry;
          if (verbose) { System.out.print("C = "+cToTry+" "); }
          Double averageScore = crossValidator.computeAverage(score);
          if (verbose) { System.out.println(" -> average Score: "+averageScore); }
          return -averageScore;
        };

    C = minimizer.minimize(negativeScorer);

    useAlphaFile = false;
    useSigmoid = oldUseSigmoid;
  }

  public void heldOutSetC(GeneralDataset train, double percentHeldOut, final Scorer scorer, LineSearcher minimizer) {
    Pair, GeneralDataset> data = train.split(percentHeldOut);
    heldOutSetC(data.first(), data.second(), scorer, minimizer);
  }

  /**
   * This method will cross validate on the given data and number of folds
   * to find the optimal C.  The scorer is how you determine what to
   * optimize for (F-score, accuracy, etc).  The C is then saved, so that
   * if you train a classifier after calling this method, that C will be used.
   */
  public void heldOutSetC(final GeneralDataset trainSet, final GeneralDataset devSet, final Scorer scorer, LineSearcher minimizer) {

    useAlphaFile = true;
    boolean oldUseSigmoid = useSigmoid;
    useSigmoid = false;

    Function negativeScorer =
        cToTry -> {
          C = cToTry;
          SVMLightClassifier classifier = trainClassifierBasic(trainSet);
          double score = scorer.score(classifier,devSet);
          return -score;
        };

    C = minimizer.minimize(negativeScorer);

    useAlphaFile = false;
    useSigmoid = oldUseSigmoid;
  }

  private boolean tuneHeldOut = false;
  private boolean tuneCV = false;
  private Scorer scorer = new MultiClassAccuracyStats<>();
  private LineSearcher tuneMinimizer = new GoldenSectionLineSearch(true);
  private int folds;
  private double heldOutPercent;

  public double getHeldOutPercent() {
    return heldOutPercent;
  }

  public void setHeldOutPercent(double heldOutPercent) {
    this.heldOutPercent = heldOutPercent;
  }

  public int getFolds() {
    return folds;
  }

  public void setFolds(int folds) {
    this.folds = folds;
  }

  public LineSearcher getTuneMinimizer() {
    return tuneMinimizer;
  }

  public void setTuneMinimizer(LineSearcher minimizer) {
    this.tuneMinimizer = minimizer;
  }

  public Scorer getScorer() {
    return scorer;
  }

  public void setScorer(Scorer scorer) {
    this.scorer = scorer;
  }

  public boolean getTuneCV() {
    return tuneCV;
  }

  public void setTuneCV(boolean tuneCV) {
    this.tuneCV = tuneCV;
  }

  public boolean getTuneHeldOut() {
    return tuneHeldOut;
  }

  public void setTuneHeldOut(boolean tuneHeldOut) {
    this.tuneHeldOut = tuneHeldOut;
  }

  public int getSvmLightVerbosity() {
    return svmLightVerbosity;
  }

  public void setSvmLightVerbosity(int svmLightVerbosity) {
    this.svmLightVerbosity = svmLightVerbosity;
  }

  public SVMLightClassifier trainClassifier(GeneralDataset dataset) {
    if (tuneHeldOut) {
      heldOutSetC(dataset, heldOutPercent, scorer, tuneMinimizer);
    } else if (tuneCV) {
      crossValidateSetC(dataset, folds, scorer, tuneMinimizer);
    }
    return trainClassifierBasic(dataset);
  }

  Pattern whitespacePattern = Pattern.compile("\\s+");

  public SVMLightClassifier trainClassifierBasic(GeneralDataset dataset) {
    Index labelIndex = dataset.labelIndex();
    Index featureIndex = dataset.featureIndex;
    boolean multiclass = (dataset.numClasses() > 2);
    try {

      // this is the file that the model will be saved to
      File modelFile = File.createTempFile("svm-", ".model");
      if (deleteTempFilesOnExit) {
        modelFile.deleteOnExit();
      }

      // this is the file that the svm light formated dataset
      // will be printed to
      File dataFile = File.createTempFile("svm-", ".data");
      if (deleteTempFilesOnExit) {
        dataFile.deleteOnExit();
      }

      // print the dataset
      PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
      dataset.printSVMLightFormat(pw);
      pw.close();

      // -v 0 makes it not verbose
      // -m 400 gives it a larger cache, for faster training
      String cmd = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn)) + " -v " + svmLightVerbosity + " -m 400 ";

      // set the value of C if we have one specified
      if (C > 0.0) cmd = cmd + " -c " + C + " ";  // C value
      else if(useSVMPerf) cmd = cmd + " -c " + 0.01 + " "; //It's required to specify this parameter for SVM perf

      // Alpha File
      if (useAlphaFile) {
        File newAlphaFile = File.createTempFile("svm-", ".alphas");
        if (deleteTempFilesOnExit) {
          newAlphaFile.deleteOnExit();
        }
        cmd = cmd + " -a " + newAlphaFile.getAbsolutePath();
        if (alphaFile != null) {
          cmd = cmd + " -y " + alphaFile.getAbsolutePath();
        }
        alphaFile = newAlphaFile;
      }

      // File and Model Data
      cmd = cmd + " " + dataFile.getAbsolutePath() + " " + modelFile.getAbsolutePath();

      if (verbose) logger.info("<< "+cmd+" >>");

      /*Process p = Runtime.getRuntime().exec(cmd);

      p.waitFor();

      if (p.exitValue() != 0) throw new RuntimeException("Error Training SVM Light exit value: " + p.exitValue());
      p.destroy();   */
      SystemUtils.run(new ProcessBuilder(whitespacePattern.split(cmd)),
        new PrintWriter(System.err), new PrintWriter(System.err));

      if (doEval) {
        File predictFile = File.createTempFile("svm-", ".pred");
        if (deleteTempFilesOnExit) {
          predictFile.deleteOnExit();
        }
        String evalCmd = (multiclass ? svmStructClassify : (useSVMPerf ? svmPerfClassify : svmLightClassify)) + " "
                + dataFile.getAbsolutePath() + " " + modelFile.getAbsolutePath() + " " + predictFile.getAbsolutePath();
        if (verbose) logger.info("<< " + evalCmd + " >>");
        SystemUtils.run(new ProcessBuilder(whitespacePattern.split(evalCmd)),
                new PrintWriter(System.err), new PrintWriter(System.err));
      }
      // read in the model file
      Pair> weightsAndThresh = readModel(modelFile, multiclass);
      double threshold = weightsAndThresh.first();
      ClassicCounter> weights = convertWeights(weightsAndThresh.second(), featureIndex, labelIndex, multiclass);
      ClassicCounter thresholds = new ClassicCounter<>();
      if (!multiclass) {
        thresholds.setCount(labelIndex.get(0), -threshold);
        thresholds.setCount(labelIndex.get(1), threshold);
      }
      SVMLightClassifier classifier = new SVMLightClassifier<>(weights, thresholds);
      if (doEval) {
        File predictFile = File.createTempFile("svm-", ".pred2");
        if (deleteTempFilesOnExit) {
          predictFile.deleteOnExit();
        }
        PrintWriter pw2 = new PrintWriter(predictFile);
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(5);
        for (Datum datum:dataset) {
          Counter scores = classifier.scoresOf(datum);
          pw2.println(Counters.toString(scores, nf));
        }
        pw2.close();
      }

      if (useSigmoid) {
        if (verbose) System.out.print("fitting sigmoid...");
        classifier.setPlatt(fitSigmoid(classifier, dataset));
        if (verbose) System.out.println("done");
      }

      return classifier;
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy