// edu.stanford.nlp.classify.AbstractLinearClassifierFactory (Maven / Gradle / Ivy)
package edu.stanford.nlp.classify;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.HashIndex;
import java.lang.ref.Reference;
import java.util.Collection;
import java.util.List;
/**
* Shared methods for training a {@link LinearClassifier}.
* Inheriting classes need to implement the
* trainWeights
method.
*
* @author Dan Klein
*
* @author Sarah Spikes ([email protected]) (Templatization)
*
* @param The type of the labels in the Dataset and Datum
* @param The type of the features in the Dataset and Datum
*/
public abstract class AbstractLinearClassifierFactory implements ClassifierFactory> {
private static final long serialVersionUID = 1L;
Index labelIndex = new HashIndex();
Index featureIndex = new HashIndex();
public AbstractLinearClassifierFactory() {
}
int numFeatures() {
return featureIndex.size();
}
int numClasses() {
return labelIndex.size();
}
public Classifier trainClassifier(List> examples) {
Dataset dataset = new Dataset();
dataset.addAll(examples);
return trainClassifier(dataset);
}
protected abstract double[][] trainWeights(GeneralDataset dataset) ;
/**
* Takes a {@link Collection} of {@link Datum} objects and gives you back a
* {@link Classifier} trained on it.
*
* @param examples {@link Collection} of {@link Datum} objects to train the
* classifier on
* @return A {@link Classifier} trained on it.
*/
public LinearClassifier trainClassifier(Collection> examples) {
Dataset dataset = new Dataset();
dataset.addAll(examples);
return trainClassifier(dataset);
}
/**
* Takes a {@link Reference} to a {@link Collection} of {@link Datum}
* objects and gives you back a {@link Classifier} trained on them
*
* @param ref {@link Reference} to a {@link Collection} of {@link
* Datum} objects to train the classifier on
* @return A Classifier trained on a collection of Datum
*/
public LinearClassifier trainClassifier(Reference>> ref) {
Collection> examples = ref.get();
return trainClassifier(examples);
}
/**
* Trains a {@link Classifier} on a {@link Dataset}.
*
* @return A {@link Classifier} trained on the data.
*/
public LinearClassifier trainClassifier(GeneralDataset data) {
labelIndex = data.labelIndex();
featureIndex = data.featureIndex();
double[][] weights = trainWeights(data);
return new LinearClassifier(weights, featureIndex, labelIndex);
}
}