// Source listing: com.expleague.ml.loss.multiclass.util.MultilabelThresholdPrecisionMatrix (Maven / Gradle / Ivy)
package com.expleague.ml.loss.multiclass.util;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.util.table.TableBuilder;
/**
 * Renders a text table of per-class precision and recall over a grid of
 * probability thresholds.
 *
 * <p>Each score is mapped to a probability with the logistic function
 * {@code 1 / (1 + exp(-score))}. An example counts as predicted for a class when
 * that probability is greater than or equal to the threshold, and as a true
 * member of the class when the corresponding target value is greater than zero.
 *
 * <p>Both matrices are indexed as {@code (example, classId)}; the number of
 * examples and classes is taken from {@code targets}.
 * NOTE(review): {@code scores} is assumed to have at least the same dimensions as
 * {@code targets}; this is not checked here.
 *
 * <p>Created by irlab on 25.06.2015.
 */
public class MultilabelThresholdPrecisionMatrix {
  /** Added to the denominators so that an empty count yields 0 rather than a division by zero. */
  private static final double EPSILON = 1E-12;

  private final Mx scores;
  private final Mx targets;
  private final int probThresholdBuckets;
  private final String name;

  /**
   * @param scores               raw (pre-sigmoid) scores, read as {@code scores.get(example, classId)}
   * @param targets              target values, read as {@code targets.get(example, classId)};
   *                             a value greater than zero marks the example as belonging to the class
   * @param probThresholdBuckets number of equal steps the threshold range [0, 1] is split into;
   *                             the table gets {@code probThresholdBuckets + 1} rows
   * @param name                 title placed on the first line of the rendered output
   */
  public MultilabelThresholdPrecisionMatrix(final Mx scores, final Mx targets, final int probThresholdBuckets, final String name) {
    this.scores = scores;
    this.targets = targets;
    this.probThresholdBuckets = probThresholdBuckets;
    this.name = name;
  }

  /**
   * Builds the threshold/precision/recall table.
   *
   * <p>The header is {@code "threshold"} followed by a precision and a recall column for
   * every class. One row is produced for each threshold
   * {@code i / probThresholdBuckets}, {@code i = 0..probThresholdBuckets}.
   * If {@code probThresholdBuckets} is 0 the single threshold is {@code 0.0 / 0 = NaN}, which no
   * probability satisfies, so that row reports zero precision and recall. A negative
   * value produces no rows at all.
   *
   * @return the name, a line break, and the table built by {@link TableBuilder}
   */
  public String toThresholdPrecisionMatrix() {
    final int classCount = targets.columns();
    final int exampleCount = targets.rows();
    // A negative bucket count means "no rows", exactly as a loop bounded by it would behave.
    final int thresholdCount = Math.max(0, probThresholdBuckets + 1);

    final TableBuilder tableBuilder = new TableBuilder();
    final String[] header = new String[classCount * 2];
    for (int i = 0; i < classCount; i++) {
      header[i * 2] = "class " + i + " precision";
      header[i * 2 + 1] = "class " + i + " recall";
    }
    tableBuilder.setHeader("threshold", header);

    final double[] thresholds = new double[thresholdCount];
    for (int thresholdNum = 0; thresholdNum < thresholdCount; thresholdNum++) {
      thresholds[thresholdNum] = ((double) thresholdNum) / probThresholdBuckets;
    }

    // One distinct array per table row, filled column-pair by column-pair below.
    final double[][] tableRows = new double[thresholdCount][classCount * 2];
    final double[] probs = new double[exampleCount];
    final boolean[] relevant = new boolean[exampleCount];

    for (int classId = 0; classId < classCount; classId++) {
      // Probabilities and relevance do not depend on the threshold, so compute them
      // once per class instead of once per (class, threshold) pair.
      int relevantCount = 0;
      for (int example = 0; example < exampleCount; example++) {
        probs[example] = 1. / (1. + Math.exp(-scores.get(example, classId)));
        relevant[example] = targets.get(example, classId) > 0;
        if (relevant[example]) {
          relevantCount++;
        }
      }

      for (int thresholdNum = 0; thresholdNum < thresholdCount; thresholdNum++) {
        final double threshold = thresholds[thresholdNum];
        int foundCount = 0;
        int foundRelevantCount = 0;
        for (int example = 0; example < exampleCount; example++) {
          if (probs[example] >= threshold) {
            foundCount++;
            if (relevant[example]) {
              foundRelevantCount++;
            }
          }
        }
        tableRows[thresholdNum][classId * 2] = foundRelevantCount / (foundCount + EPSILON);
        tableRows[thresholdNum][classId * 2 + 1] = foundRelevantCount / (relevantCount + EPSILON);
      }
    }

    for (int thresholdNum = 0; thresholdNum < thresholdCount; thresholdNum++) {
      tableBuilder.addRow("" + thresholds[thresholdNum], tableRows[thresholdNum]);
    }
    final String table = tableBuilder.build();
    return name + "\n" + table;
  }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy