All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.expleague.ml.loss.multiclass.util.ConfusionMatrix Maven / Gradle / Ivy
package com.expleague.ml.loss.multiclass.util;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.util.table.TableBuilder;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.commons.util.ArrayTools;
/**
 * Confusion matrix for multiclass classification.
 *
 * Keeps a square table of (expected class, predicted class) counts plus per-class
 * false-positive and false-negative counters, and derives per-class, macro- and
 * micro-averaged precision / recall / F1 as well as Cohen's kappa from them.
 *
 * A prediction equal to {@code -1} or to the number of classes is treated as a
 * rejected ("error class") prediction: it counts as a false negative for the expected
 * class but is not stored in the table, so it is not included in
 * {@link #getNumExamples()} or {@link #getCount(int)}.
 *
 * User: amosov-f
 * User: qdeee
 * Date: 13.08.14
 * Time: 11:19
 */
public class ConfusionMatrix {
  /** Prediction value meaning "no class was predicted". */
  private static final int NO_PREDICTION = -1;

  /** counts[expected][actual] = number of registered examples with that outcome. */
  private final int[][] counts;
  /** fp[c] = examples predicted as c whose expected class was different. */
  private final int[] fp;
  /** fn[c] = examples of expected class c that were predicted as something else. */
  private final int[] fn;

  /**
   * Creates an empty matrix.
   *
   * @param numClasses number of classes; valid labels are {@code 0 .. numClasses - 1}
   */
  public ConfusionMatrix(final int numClasses) {
    counts = new int[numClasses][numClasses];
    fp = new int[numClasses];
    fn = new int[numClasses];
  }

  //TODO[qdeee]: add validation for bad target(same values, missed class)
  /**
   * Builds the matrix from parallel sequences of expected and predicted labels.
   * The number of classes is taken from {@code MCTools.countClasses(target)}.
   *
   * @param target    expected labels
   * @param predicted predicted labels, read at the same indices as {@code target}
   */
  public ConfusionMatrix(final IntSeq target, final IntSeq predicted) {
    this(MCTools.countClasses(target));
    for (int i = 0; i < target.length(); i++) {
      register(target.arr[i], predicted.arr[i]);
    }
  }

  /**
   * Registers a single classified example.
   *
   * @param expected true class label
   * @param actual   predicted class label, or a rejected-prediction label
   *                 ({@code -1} or the number of classes)
   * @throws ArrayIndexOutOfBoundsException if a label is outside the valid range
   */
  public void add(final int expected, final int actual) {
    register(expected, actual);
  }

  /**
   * Shared update used by both the sequence constructor and {@link #add(int, int)} so
   * that the two entry points cannot drift apart.
   */
  private void register(final int expected, final int actual) {
    if (isRejected(actual)) {
      // Rejected prediction: only the false negatives of the expected class change.
      fn[expected]++;
      return;
    }
    // Touch the table first: an out-of-range label fails here, before any counter changes.
    counts[expected][actual]++;
    if (expected != actual) {
      fn[expected]++;
      fp[actual]++;
    }
  }

  /** Tells whether {@code actual} is one of the two "no real class" labels. */
  private boolean isRejected(final int actual) {
    return actual == NO_PREDICTION || actual == counts.length;
  }

  /**
   * @param c class label
   * @return tp / (tp + fp) for class {@code c}, or 0 when the denominator is 0
   */
  public double getPrecision(final int c) {
    return ratio(tp(c), fp(c));
  }

  /**
   * @param c class label
   * @return number of examples stored in the table row of expected class {@code c}
   */
  public int getCount(final int c) {
    return getNumExamples(c);
  }

  /**
   * @param c class label
   * @return tp / (tp + fn) for class {@code c}, or 0 when the denominator is 0
   */
  public double getRecall(final int c) {
    return ratio(tp(c), fn(c));
  }

  /**
   * @param clazz class label
   * @return harmonic mean of the precision and recall of {@code clazz}, 0 if both are 0
   */
  public double getF1Measure(final int clazz) {
    return harmonicMean(getPrecision(clazz), getRecall(clazz));
  }

  /** @return unweighted mean of the per-class precisions */
  public double getMacroPrecision() {
    double sum = 0;
    for (int c = 0; c < counts.length; c++) {
      sum += getPrecision(c);
    }
    return sum / counts.length;
  }

  /** @return unweighted mean of the per-class recalls */
  public double getMacroRecall() {
    double sum = 0;
    for (int c = 0; c < counts.length; c++) {
      sum += getRecall(c);
    }
    return sum / counts.length;
  }

  /**
   * @return harmonic mean of macro precision and macro recall (not the mean of the
   *         per-class F1 values), 0 if both are 0
   */
  public double getMacroF1Measure() {
    return harmonicMean(getMacroPrecision(), getMacroRecall());
  }

  /** @return sum of tp over all classes divided by the sum of (tp + fp), or 0 if empty */
  public double getMicroPrecision() {
    long tps = 0;
    long fps = 0;
    for (int c = 0; c < counts.length; c++) {
      tps += tp(c);
      fps += fp(c);
    }
    return ratio(tps, fps);
  }

  /** @return sum of tp over all classes divided by the sum of (tp + fn), or 0 if empty */
  public double getMicroRecall() {
    long tps = 0;
    long fns = 0;
    for (int c = 0; c < counts.length; c++) {
      tps += tp(c);
      fns += fn(c);
    }
    return ratio(tps, fns);
  }

  /** @return harmonic mean of micro precision and micro recall, 0 if both are 0 */
  public double getMicroF1Measure() {
    return harmonicMean(getMicroPrecision(), getMicroRecall());
  }

  /**
   * Computes Cohen's kappa, (observed - chance) / (1 - chance), over the table.
   *
   * @return kappa; 1 when the table is empty or when chance agreement is already 1
   */
  public double getCohenKappa() {
    final int n = counts.length;
    // Totals are kept in long and products are formed in double: the int products
    // used previously overflow once the table holds more than ~46k examples.
    final long[] rowTotals = new long[n];
    final long[] columnTotals = new long[n];
    long total = 0;
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < n; j++) {
        final int cell = counts[i][j];
        rowTotals[i] += cell;
        columnTotals[j] += cell;
        total += cell;
      }
    }
    if (total == 0) {
      // Nothing registered: keep the value this method has always returned here.
      return 1;
    }

    double observed = 0;
    double chance = 0;
    for (int i = 0; i < n; i++) {
      chance += (double) rowTotals[i] * columnTotals[i];
      observed += counts[i][i];
    }
    chance /= (double) total * total;
    observed /= total;

    if (chance < 1) {
      return (observed - chance) / (1 - chance);
    }
    return 1;
  }

  /** @return true positives of {@code clazz}, i.e. the diagonal cell of the table */
  public int tp(final int clazz) {
    return counts[clazz][clazz];
  }

  /** @return false positives recorded for {@code clazz} */
  public int fp(final int clazz) {
    return fp[clazz];
  }

  /** @return false negatives recorded for {@code clazz} */
  public int fn(final int clazz) {
    return fn[clazz];
  }

  /**
   * Adds every counter of {@code confusionMatrix} to this matrix.
   * Both matrices are expected to have the same number of classes; this is not checked.
   *
   * @param confusionMatrix matrix whose counts are merged into this one
   */
  public void add(final ConfusionMatrix confusionMatrix) {
    for (int i = 0; i < counts.length; i++) {
      final int[] row = counts[i];
      final int[] otherRow = confusionMatrix.counts[i];
      for (int j = 0; j < row.length; j++) {
        row[j] += otherRow[j];
      }
      fp[i] += confusionMatrix.fp[i];
      fn[i] += confusionMatrix.fn[i];
    }
  }

  /** @return total number of examples stored in the table */
  public int getNumExamples() {
    int total = 0;
    for (final int[] row : counts) {
      total += ArrayTools.sum(row);
    }
    return total;
  }

  /** @return number of examples stored in the table row of expected class {@code clazz} */
  public int getNumExamples(final int clazz) {
    return ArrayTools.sum(counts[clazz]);
  }

  /** @return number of classes the matrix was created for */
  public int getNumClasses() {
    return counts.length;
  }

  /** @return a table with the micro- and macro-averaged metrics under a summary header */
  public String toSummaryString() {
    final TableBuilder tableBuilder = new TableBuilder("Metric", "Value");
    final String result = tableBuilder
        .addRow("Micro precision: ", getMicroPrecision())
        .addRow("Micro recall: ", getMicroRecall())
        .addRow("Micro F1-measure: ", getMicroF1Measure())
        .addRow("Macro precision: ", getMacroPrecision())
        .addRow("Macro recall: ", getMacroRecall())
        .addRow("Macro F1-measure: ", getMacroF1Measure())
        .build();
    return "=== Summary ===\n" + result;
  }

  /** @return a table with count, precision, recall and F1 for every class */
  public String toClassDetailsString() {
    final TableBuilder tableBuilder = new TableBuilder("class", "count", "precision", "recall", "f1-measure");
    for (int c = 0; c < counts.length; c++) {
      tableBuilder.addRow(String.valueOf(c), getCount(c), getPrecision(c), getRecall(c), getF1Measure(c));
    }
    return "=== Detailed Accuracy By Class ===\n" + tableBuilder.build();
  }

  /** @return the six averaged metrics as a single brace-wrapped line */
  public String oneLineReport() {
    final String f = "%s = %.6f,\t";
    return "{" +
        String.format(f, "mP", getMicroPrecision()) +
        String.format(f, "mR", getMicroRecall()) +
        String.format(f, "mF", getMicroF1Measure()) +
        String.format(f, "MP", getMacroPrecision()) +
        String.format(f, "MR", getMacroRecall()) +
        String.format(f, "MF", getMacroF1Measure()) + "}";
  }

  /** @return the six averaged metrics separated by {@code |} */
  public String debug() {
    return String.format("%.6f|%.6f|%.6f|%.6f|%.6f|%.6f|",
        getMicroPrecision(), getMicroRecall(), getMicroF1Measure(),
        getMacroPrecision(), getMacroRecall(), getMacroF1Measure());
  }

  /** @return the raw table, one row per expected class, cells tab-separated */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder("=== Confusion Matrix ===\n");
    for (final int[] row : counts) {
      for (final int cell : row) {
        sb.append(String.format("%6d\t", cell));
      }
      sb.append("\n");
    }
    return sb.toString();
  }

  /** hits / (hits + misses) as a double, or 0 when both are 0. */
  private static double ratio(final long hits, final long misses) {
    final long denominator = hits + misses;
    return denominator > 0 ? hits / (double) denominator : 0;
  }

  /** Harmonic mean of precision and recall, or 0 when their sum is 0. */
  private static double harmonicMean(final double p, final double r) {
    if (p + r == 0) {
      return 0;
    }
    return 2 * p * r / (p + r);
  }
}