All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.multiclass.MultiClassOneVsRestSeq Maven / Gradle / Ivy

package com.expleague.ml.methods.multiclass;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.loss.multiclass.ClassicMulticlassLoss;
import com.expleague.ml.methods.SeqOptimization;
import com.expleague.ml.models.multiclass.JoinedBinClassModelSeq;

import java.util.function.Function;

public class MultiClassOneVsRestSeq implements SeqOptimization {
  private final SeqOptimization learner;

  public MultiClassOneVsRestSeq(final SeqOptimization learner) {
    this.learner = learner;
  }

  @Override
  public Function, Vec> fit(DataSet> learn,
                                   final ClassicMulticlassLoss multiclassLoss) {
    final IntSeq labels = multiclassLoss.labels();
    final int countClasses = MCTools.countClasses(labels);

    //noinspection unchecked
    final Function, Vec>[] models = new Function[countClasses];
    for (int c = 0; c < countClasses; c++) {
      final Vec oneVsRestTarget = MCTools.extractClassForBinary(labels, c);
      final LLLogit llLogit = new LLLogit(oneVsRestTarget, learn.parent());
      final Function, Vec> model = learner.fit(learn, llLogit);

      models[c] = model;
    }
    return new JoinedBinClassModelSeq<>(models);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy