com.github.chen0040.data.evaluators.ClassifierEvaluator (java-data-frame)
Some common patterns of data frame in Java
package com.github.chen0040.data.evaluators;

import com.github.chen0040.data.utils.NumberUtils;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Created by xschen on 11/16/16.
 */
public class ClassifierEvaluator implements Serializable {
    private static final long serialVersionUID = -6691826271325237852L;

    // Multi-class confusion matrix: rows are indexed by actual label, columns by predicted label.
    private ConfusionMatrix confusionMatrix = new ConfusionMatrix();
    // Record one (actual, predicted) pair in the confusion matrix.
    public void evaluate(String actual, String predicted) {
        confusionMatrix.incCount(actual, predicted);
    }

    public List<String> classLabels() {
        return confusionMatrix.getLabels();
    }

    public void reset() {
        confusionMatrix.reset();
    }

    public ConfusionMatrix getConfusionMatrix() {
        return confusionMatrix;
    }

    public void setConfusionMatrix(ConfusionMatrix confusionMatrix) {
        this.confusionMatrix = confusionMatrix;
    }
    // Accuracy is the fraction of all recorded cases whose predicted label equals the actual
    // label, i.e. the sum of the matrix diagonal divided by the total count.
    public double getAccuracy() {
        double accuracy = 0;
        List<String> list = confusionMatrix.getLabels();
        int correctCount = 0;
        int totalCount = 0;
        for (int i = 0; i < list.size(); ++i) {
            String actual = list.get(i);
            for (int j = 0; j < list.size(); ++j) {
                String predicted = list.get(j);
                int value = confusionMatrix.getCount(actual, predicted);
                correctCount += (i == j) ? value : 0;
                totalCount += value;
            }
        }
        if (totalCount > 0) {
            accuracy = (double) correctCount / totalCount;
        }
        return accuracy;
    }
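    // Worked example (illustrative numbers, not from the library): suppose two labels A and B
    // were evaluated 100 times, giving the confusion matrix below (rows = actual, columns = predicted):
    //
    //                predicted A   predicted B
    //    actual A         40            10
    //    actual B          5            45
    //
    // Then correctCount = 40 + 45 = 85 and totalCount = 100, so getAccuracy() returns 0.85
    // and getMisclassificationRate() returns 0.15.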
    public double getMisclassificationRate() {
        return 1 - getAccuracy();
    }

    // True positives for a class are its diagonal entry; false positives are everything else
    // in that class's column (cases of other classes that were predicted as this class).
    public int getTruePositiveCount(String classLabel) {
        return confusionMatrix.getCount(classLabel, classLabel);
    }

    public int getFalsePositiveCount(String classLabel) {
        return confusionMatrix.getColumnSum(classLabel) - getTruePositiveCount(classLabel);
    }
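    // Continuing the illustrative matrix above: getTruePositiveCount("A") = 40 and
    // getFalsePositiveCount("A") = columnSum(A) - TP(A) = 45 - 40 = 5.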
    public double avgTruePositive() {
        List<String> labels = classLabels();
        if (labels.isEmpty()) return 0;
        int sum = 0;
        for (String label : labels) {
            sum += getTruePositiveCount(label);
        }
        return (double) sum / labels.size();
    }

    public double avgFalsePositive() {
        List<String> labels = classLabels();
        if (labels.isEmpty()) return 0;
        int sum = 0;
        for (String label : labels) {
            sum += getFalsePositiveCount(label);
        }
        return (double) sum / labels.size();
    }
    // Precision for class c is the proportion of cases correctly identified as belonging to c
    // among all cases that the classifier claims belong to c (the column sum for c).
    public Map<String, Double> getPrecisionByClass() {
        Map<String, Double> result = new HashMap<>();
        List<String> list = classLabels();
        for (int i = 0; i < list.size(); ++i) {
            String label = list.get(i);
            int correctCount = confusionMatrix.getCount(label, label);
            int totalPredictedCount = confusionMatrix.getColumnSum(label);
            double precision = 0;
            if (totalPredictedCount > 0) {
                precision = (double) correctCount / totalPredictedCount;
            }
            result.put(label, precision);
        }
        return result;
    }
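    // With the illustrative matrix above: precision(A) = 40 / (40 + 5) ≈ 0.889 and
    // precision(B) = 45 / (10 + 45) ≈ 0.818.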
    // Recall for class c is the proportion of cases correctly identified as belonging to c
    // among all cases that truly belong to c (the row sum for c).
    public Map<String, Double> getRecallByClass() {
        Map<String, Double> result = new HashMap<>();
        List<String> list = classLabels();
        for (int i = 0; i < list.size(); ++i) {
            String label = list.get(i);
            int correctCount = confusionMatrix.getCount(label, label);
            int totalTrueCount = confusionMatrix.getRowSum(label);
            double recall = 0;
            if (totalTrueCount > 0) {
                recall = (double) correctCount / totalTrueCount;
            }
            result.put(label, recall);
        }
        return result;
    }
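    // With the illustrative matrix above: recall(A) = 40 / (40 + 10) = 0.8 and
    // recall(B) = 45 / (5 + 45) = 0.9.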
    // Fallout for class c is the proportion of cases incorrectly identified as belonging to c
    // among all cases that truly do not belong to c; it is the false-positive rate.
    public Map<String, Double> getFalloutByClass() {
        Map<String, Double> result = new HashMap<>();
        List<String> list = classLabels();
        for (int i = 0; i < list.size(); ++i) {
            String label = list.get(i);
            int totalNegativeCount = 0;
            int falsePositiveCount = 0;
            for (int j = 0; j < list.size(); ++j) {
                if (i == j) continue;
                String notTrueLabel = list.get(j);
                falsePositiveCount += confusionMatrix.getCount(notTrueLabel, label);
                totalNegativeCount += confusionMatrix.getRowSum(notTrueLabel);
            }
            double fallout = 0;
            if (totalNegativeCount > 0) {
                fallout = (double) falsePositiveCount / totalNegativeCount;
            }
            result.put(label, fallout);
        }
        return result;
    }
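    // With the illustrative matrix above: fallout(A) = 5 / 50 = 0.1 (the B cases predicted
    // as A, over all cases whose actual class is not A) and fallout(B) = 10 / 50 = 0.2.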
    // F1 score is the harmonic mean of precision and recall; classes whose precision and
    // recall are both zero are skipped to avoid division by zero.
    public Map<String, Double> getF1ScoreByClass() {
        Map<String, Double> precisions = getPrecisionByClass();
        Map<String, Double> recalls = getRecallByClass();
        List<String> labels = classLabels();
        Map<String, Double> result = new HashMap<>();
        for (String label : labels) {
            double precision = precisions.get(label);
            double recall = recalls.get(label);
            if (NumberUtils.isZero(precision + recall)) {
                continue;
            }
            double f1score = 2 * (precision * recall) / (precision + recall);
            result.put(label, f1score);
        }
        return result;
    }
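    // With the illustrative matrix above: f1(A) = 2 * (0.889 * 0.8) / (0.889 + 0.8) ≈ 0.842
    // and f1(B) = 2 * (0.818 * 0.9) / (0.818 + 0.9) ≈ 0.857.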
    // Macro-F1 is the unweighted mean of the per-class F1 scores; the concept is described at
    // http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf
    public double getMacroF1Score() {
        double sum = 0;
        int count = 0;
        Map<String, Double> data = getF1ScoreByClass();
        for (Map.Entry<String, Double> entry : data.entrySet()) {
            sum += entry.getValue();
            count++;
        }
        if (count == 0) return 0;
        return sum / count;
    }
    // Micro-F1 as implemented here is the harmonic mean of the class-averaged precision and
    // the class-averaged recall; the concept is described at
    // http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf
    public double getMicroF1Score() {
        Map<String, Double> precisions = getPrecisionByClass();
        Map<String, Double> recalls = getRecallByClass();
        List<String> labels = classLabels();
        if (labels.isEmpty()) return 0;
        double precisionAvg = 0;
        double recallAvg = 0;
        for (String label : labels) {
            double precision = precisions.get(label);
            double recall = recalls.get(label);
            precisionAvg += precision;
            recallAvg += recall;
        }
        precisionAvg /= labels.size();
        recallAvg /= labels.size();
        if (NumberUtils.isZero(precisionAvg + recallAvg)) return 0;
        return 2 * (precisionAvg * recallAvg) / (precisionAvg + recallAvg);
    }
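    // With the illustrative matrix above: macro-F1 = (0.842 + 0.857) / 2 ≈ 0.850, while the
    // class-averaged precision ≈ 0.854 and recall = 0.85 give micro-F1 ≈ 0.852.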
    public String getSummary() {
        StringBuilder sb = new StringBuilder();
        sb.append("accuracy: ").append(getAccuracy());
        sb.append("\nmis-classification: ").append(getMisclassificationRate());
        sb.append("\nmacro f1-score: ").append(getMacroF1Score());
        sb.append("\nmicro f1-score: ").append(getMicroF1Score());
        return sb.toString();
    }

    public void report() {
        System.out.println(getSummary());
    }
}
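A minimal usage sketch (not part of the library source): the EvaluatorDemo class, label arrays, and loop below are illustrative assumptions; in practice the predicted labels would come from a trained classifier run over a held-out test set.

import com.github.chen0040.data.evaluators.ClassifierEvaluator;

public class EvaluatorDemo {
    public static void main(String[] args) {
        ClassifierEvaluator evaluator = new ClassifierEvaluator();

        // Hypothetical (actual, predicted) pairs for a two-class problem.
        String[] actual    = {"spam", "spam", "ham", "ham", "spam"};
        String[] predicted = {"spam", "ham",  "ham", "ham", "spam"};

        // Feed each pair into the evaluator to build up the confusion matrix.
        for (int i = 0; i < actual.length; ++i) {
            evaluator.evaluate(actual[i], predicted[i]);
        }

        System.out.println(evaluator.getSummary());
        System.out.println("precision by class: " + evaluator.getPrecisionByClass());
        System.out.println("recall by class: " + evaluator.getRecallByClass());
    }
}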