
cc.mallet.classify.PRAuxClassifier — part of the mallet artifact (available for Maven / Gradle / Ivy).
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import cc.mallet.classify.constraints.pr.MaxEntPRConstraint;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
/**
 * Auxiliary model (q) for the E-step / I-projection in posterior
 * regularization (PR) training.
 * <p>
 * The model keeps one parameter vector per {@link MaxEntPRConstraint}. The
 * score of label {@code li} for an instance is the sum, over constraints
 * {@code ci}, of {@code constraints.get(ci).getScore(input, li, parameters[ci])}.
 *
 * @author Gregory Druck
 */
public class PRAuxClassifier extends Classifier {

  private static final long serialVersionUID = 1L;

  /** Number of labels, taken from the pipe's target alphabet. */
  private final int numLabels;

  /**
   * One parameter vector per constraint, parallel to {@link #constraints};
   * vector {@code ci} has length {@code constraints.get(ci).numDimensions()}.
   */
  private final double[][] parameters;

  /** Constraint features whose scores are summed to form each label score. */
  private final ArrayList<MaxEntPRConstraint> constraints;

  /**
   * Creates an auxiliary model whose parameters are all zero.
   *
   * @param pipe pipe whose target alphabet defines the label set
   * @param constraints constraint features; stored by reference, not copied
   */
  public PRAuxClassifier(Pipe pipe, ArrayList<MaxEntPRConstraint> constraints) {
    super(pipe);
    this.constraints = constraints;
    this.parameters = new double[constraints.size()][];
    for (int ci = 0; ci < constraints.size(); ci++) {
      // New double arrays are zero-filled, so q starts from all-zero parameters.
      this.parameters[ci] = new double[constraints.get(ci).numDimensions()];
    }
    this.numLabels = pipe.getTargetAlphabet().size();
  }

  /**
   * Adds this model's per-label scores for {@code instance} into {@code scores}.
   * <p>
   * Scores are accumulated with {@code +=}, not assigned. Pass a zero-filled
   * array to obtain q's scores alone. Any values already present in the array
   * are kept and added to. Each constraint's {@code preProcess} is invoked once
   * on the input before any label is scored.
   *
   * @param instance instance whose data is a {@link FeatureVector} (it is cast to one)
   * @param scores array of length at least the number of labels; updated in place
   */
  public void getClassificationScores(Instance instance, double[] scores) {
    FeatureVector input = (FeatureVector) instance.getData();
    for (MaxEntPRConstraint constraint : constraints) {
      constraint.preProcess(input);
    }
    for (int li = 0; li < numLabels; li++) {
      // Index ci selects both the constraint and its parallel parameter vector.
      for (int ci = 0; ci < constraints.size(); ci++) {
        scores[li] += constraints.get(ci).getScore(input, li, parameters[ci]);
      }
    }
  }

  /**
   * Computes q's label distribution for {@code instance} in place.
   * <p>
   * It first accumulates scores into {@code scores}, exactly as
   * {@link #getClassificationScores(Instance, double[])} does. It then applies
   * {@link MatrixOps#expNormalize(double[])} to the result.
   *
   * @param instance instance whose data is a {@link FeatureVector}
   * @param scores array of length at least the number of labels; overwritten with the result
   */
  public void getClassificationProbs(Instance instance, double[] scores) {
    getClassificationScores(instance, scores);
    MatrixOps.expNormalize(scores);
  }

  /**
   * Classifies {@code instance} with q.
   * <p>
   * The returned {@link LabelVector} wraps the raw accumulated scores from
   * {@link #getClassificationScores(Instance, double[])}. They are not
   * exp-normalized.
   * NOTE(review): other classifiers may place probabilities here. Consider
   * {@link #getClassificationProbs(Instance, double[])} after checking callers.
   */
  @Override
  public Classification classify(Instance instance) {
    double[] scores = new double[numLabels];
    getClassificationScores(instance, scores);
    return new Classification(instance, this, new LabelVector(getLabelAlphabet(), scores));
  }

  /**
   * Returns the live parameter arrays, one per constraint, without copying.
   * Modifications made by the caller (presumably the PR optimizer) change this
   * model directly.
   */
  public double[][] getParameters() {
    return parameters;
  }

  /** Returns the live list of constraint features, without copying. */
  public ArrayList<MaxEntPRConstraint> getConstraintFeatures() {
    return constraints;
  }

  /** Calls {@code zeroExpectations()} on every constraint feature. */
  public void zeroExpectations() {
    for (MaxEntPRConstraint constraint : constraints) {
      constraint.zeroExpectations();
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy