
cc.mallet.classify.ConfidencePredictingClassifier Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum [email protected]
*/
package cc.mallet.classify;
import java.util.ArrayList;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.LabelVector;
public class ConfidencePredictingClassifier extends Classifier
{
Classifier underlyingClassifier;
Classifier confidencePredictingClassifier;
double totalCorrect;
double totalIncorrect;
double totalIncorrectIncorrect;
double totalIncorrectCorrect;
int numCorrectInstances;
int numIncorrectInstances;
int numConfidenceCorrect;
int numFalsePositive;
int numFalseNegative;
public ConfidencePredictingClassifier (Classifier underlyingClassifier, Classifier confidencePredictingClassifier)
{
super (underlyingClassifier.getInstancePipe());
this.underlyingClassifier = underlyingClassifier;
this.confidencePredictingClassifier = confidencePredictingClassifier;
// for testing confidence accuracy
totalCorrect = 0.0;
totalIncorrect = 0.0;
totalIncorrectIncorrect = 0.0;
totalIncorrectCorrect = 0.0;
numCorrectInstances = 0;
numIncorrectInstances = 0;
numConfidenceCorrect = 0;
numFalsePositive = 0;
numFalseNegative = 0;
}
public Classification classify (Instance instance)
{
Classification c = underlyingClassifier.classify (instance);
Classification cpc = confidencePredictingClassifier.classify (c);
LabelVector lv = c.getLabelVector();
int bestIndex = lv.getBestIndex();
double [] values = new double[lv.numLocations()];
//// Put score of "correct" into score of the winning class...
// xxx Can't set lv - it's immutable.
// Must create copy and new classification object
// lv.set (bestIndex, cpc.getLabelVector().value("correct"));
//for (int i = 0; i < lv.numLocations(); i++)
// if (i != bestIndex)
// lv.set (i, 0.0);
// Put score of "correct" in winning class and
// set rest to 0
for (int i = 0; i < lv.numLocations(); i++) {
if (i != bestIndex)
values[i] = 0.0;
else values[i] = cpc.getLabelVector().value("correct");
}
//return c;
if(c.bestLabelIsCorrect()){
numCorrectInstances++;
totalCorrect+=cpc.getLabelVector().value("correct");
totalIncorrectCorrect+=cpc.getLabelVector().value("incorrect");
String correct = new String("correct");
if(correct.equals(cpc.getLabelVector().getBestLabel().toString()))
numConfidenceCorrect++;
else numFalseNegative++;
}
else{
numIncorrectInstances++;
totalIncorrect+=cpc.getLabelVector().value("correct");
totalIncorrectIncorrect+=cpc.getLabelVector().value("incorrect");
if((new String("incorrect")).equals(cpc.getLabelVector().getBestLabel().toString()))
numConfidenceCorrect++;
else numFalsePositive++;
}
return new Classification(instance, this, new LabelVector(lv.getLabelAlphabet(), values));
// return cpc;
}
public void printAverageScores() {
System.out.println("Mean score of correct for correct instances = " + meanCorrect());
System.out.println("Mean score of correct for incorrect instances = " + meanIncorrect());
System.out.println("Mean score of incorrect for correct instances = " +
this.totalIncorrectCorrect/this.numCorrectInstances);
System.out.println("Mean score of incorrect for incorrect instances = " +
this.totalIncorrectIncorrect/this.numIncorrectInstances);
}
public void printConfidenceAccuracy() {
System.out.println("Confidence predicting accuracy = " +
((double)numConfidenceCorrect/(numIncorrectInstances + numCorrectInstances))+ " false negatives: "+ numFalseNegative + "/"+numCorrectInstances + " false positives: "+ numFalsePositive +" / " +numIncorrectInstances);
}
public double meanCorrect()
{
if(this.numCorrectInstances==0)
return 0.0;
return (this.totalCorrect/(double)this.numCorrectInstances);
}
public double meanIncorrect()
{
if(this.numIncorrectInstances==0)
return 0.0;
return (this.totalIncorrect/(double)this.numIncorrectInstances);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy