All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.cli.output.printers.MultiLabelLogitProgressPrinter Maven / Gradle / Ivy

package com.expleague.ml.cli.output.printers;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.loss.blockwise.BlockwiseMultiLabelLogit;
import com.expleague.ml.loss.multilabel.MultiLabelMicroFScore;
import com.expleague.ml.models.multiclass.MCModel;
import com.expleague.ml.models.multilabel.MultiLabelModel;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.ml.ProgressHandler;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.multilabel.MultiLabelExactMatch;
import com.expleague.ml.loss.multilabel.MultiLabelHammingLoss;
import com.expleague.ml.loss.multilabel.MultiLabelMacroFScore;

import java.util.ArrayList;
import java.util.List;

import static com.expleague.commons.math.vectors.VecTools.append;
import static com.expleague.commons.math.vectors.VecTools.scale;

/**
 * Progress handler that periodically prints training progress for multi-label logit models.
 *
 * <p>Every {@code itersForOut} calls to {@link #accept(Trans)} it prints one line of the form
 * {@code <iteration> <learn logit> <test logit> { <learn metrics> } {<test metrics> }}, where
 * the metrics are, in order: exact match, micro F-score, macro F-score and Hamming loss.
 *
 * <p>When the partial model is a boosting {@link Ensemble} whose newest member is a
 * {@link FuncJoin}, the step-scaled outputs of that member are accumulated into cached score
 * matrices, so the whole ensemble never has to be re-evaluated.
 *
 * <p>The cached scores and the iteration counter are mutated without synchronization.
 *
 * User: qdeee
 * Date: 03.04.15
 */
public class MultiLabelLogitProgressPrinter implements ProgressHandler {
  private final VecDataSet learn;
  private final VecDataSet test;

  private final BlockwiseMultiLabelLogit learnLogit;
  private final BlockwiseMultiLabelLogit testLogit;

  // Accumulated boosting scores: rows = pool size, columns = logit block size.
  // Only updated on the boosting path of accept().
  private final Mx learnValues;
  private final Mx testValues;

  private final List<Func> learnMetrics = new ArrayList<>();
  private final List<Func> testMetrics = new ArrayList<>();

  private final int itersForOut;
  private int iteration = 0;

  /**
   * Creates a printer that reports every 10 iterations.
   *
   * @param learn pool used for training
   * @param test pool used for evaluation
   */
  public MultiLabelLogitProgressPrinter(final Pool learn, final Pool test) {
    this(learn, test, 10);
  }

  /**
   * Creates a printer that reports every {@code itersForOut} iterations.
   *
   * @param learn pool used for training
   * @param test pool used for evaluation
   * @param itersForOut reporting period in calls to {@link #accept(Trans)}; must be positive
   * @throws IllegalArgumentException if {@code itersForOut} is not positive
   */
  public MultiLabelLogitProgressPrinter(final Pool learn, final Pool test, final int itersForOut) {
    if (itersForOut <= 0) {
      // A zero period would otherwise fail later with "/ by zero" inside accept().
      throw new IllegalArgumentException("itersForOut must be positive, got " + itersForOut);
    }
    this.learn = learn.vecData();
    this.test = test.vecData();

    this.learnLogit = learn.target(BlockwiseMultiLabelLogit.class);
    this.testLogit = test.target(BlockwiseMultiLabelLogit.class);
    this.learnValues = new VecBasedMx(learn.size(), learnLogit.blockSize());
    this.testValues = new VecBasedMx(test.size(), testLogit.blockSize());

    registerMetrics(learn, learnMetrics);
    registerMetrics(test, testMetrics);

    this.itersForOut = itersForOut;
  }

  /**
   * Consumes the current partial model. On the boosting path, updates the cached scores.
   * Every {@code itersForOut}-th call, prints a progress line.
   *
   * <p>Models that are neither a boosting ensemble of {@link FuncJoin}s nor a
   * {@link MultiLabelModel} still advance the iteration counter but produce no output.
   *
   * @param partial the model built so far
   */
  @Override
  public void accept(final Trans partial) {
    final boolean boosting = isBoostingProcess(partial);
    if (boosting) {
      // NOTE(review): assumes accept() is called exactly once per newly added ensemble member;
      // calling it twice with the same ensemble would add that member's scores twice.
      cacheLastMember((Ensemble) partial);
    }

    iteration++;
    if (iteration % itersForOut != 0) {
      return;
    }

    final Mx learnPredicted;
    final Mx testPredicted;
    if (boosting) {
      // Copy first so the binarization does not clobber the accumulated scores.
      learnPredicted = VecTools.toBinary(VecTools.copy(learnValues));
      testPredicted = VecTools.toBinary(VecTools.copy(testValues));
    } else if (partial instanceof MultiLabelModel) {
      // Check the type we actually cast to; the old MCModel check could throw
      // ClassCastException, and it skipped multi-label models that are not MCModels.
      final MultiLabelModel multiLabelModel = (MultiLabelModel) partial;
      learnPredicted = multiLabelModel.predictLabelsAll(learn.data());
      testPredicted = multiLabelModel.predictLabelsAll(test.data());
    } else {
      return;
    }

    printProgressLine(learnPredicted, testPredicted);
  }

  /** Adds the step-scaled outputs of the ensemble's newest member to the cached score matrices. */
  private void cacheLastMember(final Ensemble ensemble) {
    final double step = ensemble.wlast();
    final FuncJoin model = (FuncJoin) ensemble.last();
    append(learnValues, scale(model.transAll(learn.data()), step));
    append(testValues, scale(model.transAll(test.data()), step));
  }

  /**
   * Builds the full progress line and prints it in a single call, so a line is never
   * half-written.
   */
  private void printProgressLine(final Mx learnPredicted, final Mx testPredicted) {
    final StringBuilder line = new StringBuilder();
    // NOTE(review): the logit columns always read the cached boosting scores. On the
    // MultiLabelModel path those caches are never updated by this class. TODO confirm intent.
    line.append(iteration)
        .append(' ').append(learnLogit.value(learnValues))
        .append(' ').append(testLogit.value(testValues))
        .append(" { ");
    appendMetricValues(line, learnMetrics, learnPredicted);
    line.append("} {");
    appendMetricValues(line, testMetrics, testPredicted);
    line.append("}\n");
    System.out.print(line);
  }

  /** Appends each metric's value on {@code predicted}, each followed by a single space. */
  private static void appendMetricValues(
      final StringBuilder line, final List<Func> metrics, final Mx predicted) {
    for (final Func metric : metrics) {
      line.append(metric.value(predicted)).append(' ');
    }
  }

  /** Registers the reported metrics for {@code pool}, in their print order. */
  private static void registerMetrics(final Pool pool, final List<Func> metrics) {
    metrics.add(pool.target(MultiLabelExactMatch.class));
    metrics.add(pool.target(MultiLabelMicroFScore.class));
    metrics.add(pool.target(MultiLabelMacroFScore.class));
    metrics.add(pool.target(MultiLabelHammingLoss.class));
  }

  /** Returns true when {@code partial} is an ensemble whose newest member is a {@link FuncJoin}. */
  private static boolean isBoostingProcess(final Trans partial) {
    return partial instanceof Ensemble && ((Ensemble) partial).last() instanceof FuncJoin;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy