All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.classify.AbstractLinearClassifierFactory Maven / Gradle / Ivy

package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.HashIndex;

import java.lang.ref.Reference;
import java.util.Collection;
import java.util.List;

/**
 * Shared methods for training a {@link LinearClassifier}.
 * Inheriting classes need to implement the
 * trainWeights method.
 *
 * @author Dan Klein
 *
 * @author Sarah Spikes ([email protected]) (Templatization)
 *
 * @param  The type of the labels in the Dataset and Datum
 * @param  The type of the features in the Dataset and Datum
 */

public abstract class AbstractLinearClassifierFactory implements ClassifierFactory> {

  private static final long serialVersionUID = 1L;

  Index labelIndex = new HashIndex();
  Index featureIndex = new HashIndex();

  public AbstractLinearClassifierFactory() {
  }

  int numFeatures() {
    return featureIndex.size();
  }

  int numClasses() {
    return labelIndex.size();
  }

  public Classifier trainClassifier(List> examples) {
    Dataset dataset = new Dataset();
    dataset.addAll(examples);
    return trainClassifier(dataset);
  }

  protected abstract double[][] trainWeights(GeneralDataset dataset) ;

  /**
   * Takes a {@link Collection} of {@link Datum} objects and gives you back a
   * {@link Classifier} trained on it.
   *
   * @param examples {@link Collection} of {@link Datum} objects to train the
   *                 classifier on
   * @return A {@link Classifier} trained on it.
   */
  public LinearClassifier trainClassifier(Collection> examples) {
    Dataset dataset = new Dataset();
    dataset.addAll(examples);
    return trainClassifier(dataset);
  }

  /**
   * Takes a {@link Reference} to a {@link Collection} of {@link Datum}
   * objects and gives you back a {@link Classifier} trained on them
   *
   * @param ref {@link Reference} to a {@link Collection} of {@link
   *            Datum} objects to train the classifier on
   * @return A Classifier trained on a collection of Datum
   */
  public LinearClassifier trainClassifier(Reference>> ref) {
    Collection> examples = ref.get();
    return trainClassifier(examples);
  }


  /**
   * Trains a {@link Classifier} on a {@link Dataset}.
   *
   * @return A {@link Classifier} trained on the data.
   */
  public LinearClassifier trainClassifier(GeneralDataset data) {
    labelIndex = data.labelIndex();
    featureIndex = data.featureIndex();
    double[][] weights = trainWeights(data);
    return new LinearClassifier(weights, featureIndex, labelIndex);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy