All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.multiclass.MultiClassOneVsRest Maven / Gradle / Ivy

package com.expleague.ml.methods.multiclass;

import com.expleague.commons.func.Computable;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.loss.multiclass.ClassicMulticlassLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multiclass.JoinedBinClassModel;

/**
 * User: qdeee
 * Date: 22.01.15
 */
public class MultiClassOneVsRest implements VecOptimization {
  private final VecOptimization learner;

  public MultiClassOneVsRest(final VecOptimization learner) {
    this.learner = learner;
  }

  @Override
  public Trans fit(final VecDataSet learn, final ClassicMulticlassLoss multiclassLoss) {
    final IntSeq labels = multiclassLoss.labels();
    final int countClasses = MCTools.countClasses(labels);

    final Func[] models = new Func[countClasses];
    for (int c = 0; c < countClasses; c++) {
      final Vec oneVsRestTarget = MCTools.extractClassForBinary(labels, c);
      final LLLogit llLogit = new LLLogit(oneVsRestTarget, learn.parent());
      final Trans model = learner.fit(learn, llLogit);

      models[c] = extractFunc(model);
    }

    return new JoinedBinClassModel(models);
  }

  public static Func extractFunc(final Trans model) {
    if (model instanceof Ensemble) {
      final Ensemble ensemble = (Ensemble) model;
      if (ensemble.last() instanceof Func) {
        return new FuncEnsemble<>(
            ArrayTools.map(ensemble.models, Func.class, new Computable() {
              @Override
              public Func compute(final Trans argument) {
                return (Func) argument;
              }
            }),
            ensemble.weights);
      } else {
        throw new IllegalArgumentException("Ensemble doesn't contain a Func");
      }
    } else if (model instanceof Func) {
      return (Func) model;
    } else {
      throw new IllegalArgumentException("Model doesn't contain a Func");
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy