All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.loss.multiclass.util.MultilabelConfusionMatrix Maven / Gradle / Ivy

package com.expleague.ml.loss.multiclass.util;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.util.table.TableBuilder;
import com.expleague.commons.util.ArrayTools;

/**
 * Holds one {@link ConfusionMatrix} per column of a multilabel target matrix and renders
 * per-label metric tables from them.
 *
 * <p>Column {@code j} of the targets is paired with column {@code j} of the predictions; each
 * report prints one table column per such pair.
 *
 * User: qdeee
 * Date: 24.03.15
 */
public class MultilabelConfusionMatrix {
  /** Confusion matrix for each target column, indexed by column number. */
  private final ConfusionMatrix[] matrixes;

  /**
   * Builds a confusion matrix for every column of {@code targets}.
   *
   * <p>NOTE(review): only {@code targets.columns()} is consulted, so {@code predicted} is
   * presumably expected to have the same shape as {@code targets} — confirm with callers.
   *
   * @param targets   matrix whose columns hold the expected values, one column per label
   * @param predicted matrix whose columns hold the predicted values, one column per label
   */
  public MultilabelConfusionMatrix(final Mx targets, final Mx predicted) {
    matrixes = new ConfusionMatrix[targets.columns()];
    for (int j = 0; j < matrixes.length; j++) {
      matrixes[j] = new ConfusionMatrix(
          VecTools.toIntSeq(targets.col(j)),
          VecTools.toIntSeq(predicted.col(j))
      );
    }
  }

  /**
   * Renders a table of macro precision, macro recall and macro F1-measure, with one column
   * per label.
   *
   * <p>Only the macro metrics are emitted; no micro metric rows are produced.
   *
   * @return the text {@code "=== Summary ===\n"} followed by the built table
   */
  public String toSummaryString() {
    final int labelCount = matrixes.length;
    final double[] macroPrecision = new double[labelCount];
    final double[] macroRecall = new double[labelCount];
    final double[] macroFScore = new double[labelCount];

    for (int i = 0; i < labelCount; i++) {
      final ConfusionMatrix matrix = matrixes[i];
      macroPrecision[i] = matrix.getMacroPrecision();
      macroRecall[i] = matrix.getMacroRecall();
      macroFScore[i] = matrix.getMacroF1Measure();
    }

    final String table = new TableBuilder()
        .setHeader("Metric\\Label", ArrayTools.sequence(0, labelCount))
        .addRow("Macro precision: ",  macroPrecision)
        .addRow("Macro recall: ",     macroRecall)
        .addRow("Macro F1-measure: ", macroFScore)
        .build();
    return "=== Summary ===\n" + table;
  }

  /**
   * Renders a table of precision, recall and F1-measure for class {@code 0} and class
   * {@code 1}, with one column per label.
   *
   * <p>NOTE(review): only classes 0 and 1 are queried, so each label is presumably binary —
   * confirm against {@code ConfusionMatrix}.
   *
   * @return the text {@code "=== Detailed Accuracy By Class ===\n"} followed by the built table
   */
  public String toClassDetailsString() {
    final int labelCount = matrixes.length;
    final double[] precision0 = new double[labelCount];
    final double[] precision1 = new double[labelCount];
    final double[] recall0 = new double[labelCount];
    final double[] recall1 = new double[labelCount];
    final double[] fScore0 = new double[labelCount];
    final double[] fScore1 = new double[labelCount];

    for (int i = 0; i < labelCount; i++) {
      final ConfusionMatrix matrix = matrixes[i];
      precision0[i] = matrix.getPrecision(0);
      precision1[i] = matrix.getPrecision(1);
      recall0[i] = matrix.getRecall(0);
      recall1[i] = matrix.getRecall(1);
      fScore0[i] = matrix.getF1Measure(0);
      fScore1[i] = matrix.getF1Measure(1);
    }

    final String table = new TableBuilder()
        .setHeader("metric\\label", ArrayTools.sequence(0, labelCount))
        .addRow("[0] precision", precision0)
        .addRow("[0] recall",    recall0)
        .addRow("[0] f1-score",  fScore0)
        .addRow("[1] precision", precision1)
        .addRow("[1] recall",    recall1)
        .addRow("[1] f1-score",  fScore1)
        .build();
    return "=== Detailed Accuracy By Class ===\n" + table;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy