All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.cli.output.printers.ResultsPrinter Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.cli.output.printers;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.data.tools.MultiLabelTools;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.loss.multiclass.util.ConfusionMatrix;
import com.expleague.ml.loss.multiclass.util.MultilabelConfusionMatrix;
import com.expleague.ml.loss.multiclass.util.MultilabelExampleTableOutput;
import com.expleague.ml.loss.multiclass.util.MultilabelThresholdPrecisionMatrix;
import com.expleague.ml.loss.multilabel.MultiLabelExactMatch;
import com.expleague.ml.models.MultiClassModel;
import com.expleague.ml.models.multiclass.MCModel;
import com.expleague.ml.models.multilabel.MultiLabelBinarizedModel;
import com.expleague.ml.models.multilabel.MultiLabelModel;

import java.util.function.Function;

/**
 * Console reporting helpers that print evaluation results of trained models to {@code System.out}.
 *
 * User: qdeee
 * Date: 04.09.14
 */
@SuppressWarnings("UseOfSystemOutOrSystemErr")
public class ResultsPrinter {
  /**
   * Prints one line of the form {@code "Learn: <loss> Test: <metric1> <metric2> ..."}.
   *
   * @param computable model applied to each pool's data via {@link DataTools#calcAll}
   * @param learn      pool on which {@code loss} is reported
   * @param test       pool on which each of {@code metrics} is reported
   * @param loss       function evaluated on the model's outputs for {@code learn}
   * @param metrics    functions evaluated, in order, on the model's outputs for {@code test}
   */
  public static void printResults(final Function computable, final Pool learn, final Pool test, final Func loss, final Func[] metrics) {
    System.out.print("Learn: " + loss.value(DataTools.calcAll(computable, learn.vecData())) + " Test:");
    if (metrics.length > 0) {
      // The model's outputs on the test pool do not depend on the metric, so compute them once
      // instead of re-running the model over the whole pool for every metric.
      // NOTE(review): assumes calcAll is side-effect free and metrics do not mutate their argument.
      final Vec testOutputs = DataTools.calcAll(computable, test.vecData());
      for (final Func metric : metrics) {
        System.out.print(" " + metric.value(testOutputs));
      }
    }
    System.out.println();
  }

  /**
   * Prints confusion-matrix summaries for a multiclass model on the learn and test pools.
   *
   * <p>Two kinds of input are supported. An {@link Ensemble} whose last element is a {@link FuncJoin}
   * is merged with {@link MCTools#joinBoostingResult} and wrapped in a {@link MultiClassModel}. An
   * {@link MCModel} is used as-is. For any other function nothing is printed.
   *
   * @param function trained model
   * @param learn    learn pool; labels come from its {@link BlockwiseMLLLogit} target
   * @param test     test pool; labels come from its {@link BlockwiseMLLLogit} target
   */
  public static void printMulticlassResults(final Function function, final Pool learn, final Pool test) {
    final MCModel mcModel = toMulticlassModel(function);
    if (mcModel == null) {
      return;
    }

    printConfusionMatrix("LEARN:", multiclassConfusion(mcModel, learn));
    printConfusionMatrix("TEST:", multiclassConfusion(mcModel, test));
    System.out.println();
  }

  /**
   * Prints multilabel confusion-matrix summaries on the learn and test pools.
   *
   * <p>If the extracted model is a {@link MultiLabelBinarizedModel}, additional score diagnostics
   * from its internal model are printed afterwards.
   *
   * @param function trained model; it is cast to {@link Trans}, so any other type throws
   *                 {@link ClassCastException}
   * @param learn    learn pool; targets come from its {@link MultiLabelExactMatch} multi-target
   * @param test     test pool; targets come from its {@link MultiLabelExactMatch} multi-target
   */
  public static void printMultilabelResult(final Function function, final Pool learn, final Pool test) {
    final MultiLabelModel mlModel = MultiLabelTools.extractMultiLabelModel((Trans) function);

    final Mx learnTargets = multilabelTargets(learn);
    printConfusionMatrix("[LEARN]", multilabelConfusion(mlModel, learn, learnTargets));

    final Mx testTargets = multilabelTargets(test);
    printConfusionMatrix("[TEST]", multilabelConfusion(mlModel, test, testTargets));

    if (mlModel instanceof MultiLabelBinarizedModel) {
      final Trans scorer = ((MultiLabelBinarizedModel) mlModel).getInternModel();
      printBinarizedDiagnostics(scorer, learn, learnTargets, test, testTargets);
    }
  }

  /** Resolves a multiclass view of {@code function}, or returns {@code null} if it is not supported. */
  private static MCModel toMulticlassModel(final Function function) {
    if (function instanceof Ensemble && ((Ensemble) function).last() instanceof FuncJoin) {
      final FuncJoin funcJoin = MCTools.joinBoostingResult((Ensemble) function);
      return new MultiClassModel(funcJoin);
    }
    if (function instanceof MCModel) {
      return (MCModel) function;
    }
    return null;
  }

  /** Builds the confusion matrix of {@code model}'s best-class predictions against {@code pool}'s labels. */
  private static ConfusionMatrix multiclassConfusion(final MCModel model, final Pool pool) {
    final IntSeq target = pool.target(BlockwiseMLLLogit.class).labels();
    final Vec predicted = model.bestClassAll(pool.vecData().data());
    return new ConfusionMatrix(target, VecTools.toIntSeq(predicted));
  }

  /** Reads the multilabel target matrix of {@code pool}. */
  private static Mx multilabelTargets(final Pool pool) {
    return pool.multiTarget(MultiLabelExactMatch.class).getTargets();
  }

  /** Builds the multilabel confusion matrix of {@code model}'s label predictions against {@code targets}. */
  private static MultilabelConfusionMatrix multilabelConfusion(final MultiLabelModel model, final Pool pool, final Mx targets) {
    final Mx predicted = model.predictLabelsAll(pool.vecData().data());
    return new MultilabelConfusionMatrix(targets, predicted);
  }

  /** Prints a header line followed by the summary and per-class details of a multiclass matrix. */
  private static void printConfusionMatrix(final String header, final ConfusionMatrix matrix) {
    System.out.println(header);
    System.out.println(matrix.toSummaryString());
    System.out.println(matrix.toClassDetailsString());
  }

  /** Prints a header line followed by the summary and per-class details of a multilabel matrix. */
  private static void printConfusionMatrix(final String header, final MultilabelConfusionMatrix matrix) {
    System.out.println(header);
    System.out.println(matrix.toSummaryString());
    System.out.println(matrix.toClassDetailsString());
  }

  /**
   * Prints threshold/precision tables and per-example score tables computed from {@code scorer}'s raw outputs.
   *
   * <p>Output order is: TEST curve, LEARN curve, TEST examples, LEARN examples. The literal {@code 100}
   * passed to {@link MultilabelThresholdPrecisionMatrix} is presumably the number of threshold steps
   * (TODO confirm).
   */
  private static void printBinarizedDiagnostics(final Trans scorer,
                                                final Pool learn, final Mx learnTargets,
                                                final Pool test, final Mx testTargets) {
    final Mx testScores = scorer.transAll(test.vecData().data());
    System.out.println(new MultilabelThresholdPrecisionMatrix(testScores, testTargets, 100, "=== Precision/recall curve on TEST ===").toThresholdPrecisionMatrix());

    final Mx learnScores = scorer.transAll(learn.vecData().data());
    System.out.println(new MultilabelThresholdPrecisionMatrix(learnScores, learnTargets, 100, "=== Precision/recall curve on LEARN ===").toThresholdPrecisionMatrix());

    System.out.println(new MultilabelExampleTableOutput(testScores, testTargets, test, "=== Scores for examples on TEST ===\n").toExampleTableMatrix());

    System.out.println(new MultilabelExampleTableOutput(learnScores, learnTargets, learn, "=== Scores for examples on LEARN ===\n").toExampleTableMatrix());
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy