All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.data.tools.MultiLabelTools Maven / Gradle / Ivy

package com.expleague.ml.data.tools;

import com.expleague.ml.func.FuncEnsemble;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.math.Func;
import com.expleague.ml.cli.output.printers.MultiLabelLogitProgressPrinter;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.models.multilabel.MultiLabelBinarizedModel;
import com.expleague.ml.models.multilabel.MultiLabelModel;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Static helpers for multi-label models: converting a trained model into a
 * {@link MultiLabelModel} and replaying a binarized model through a progress printer.
 *
 * User: qdeee
 * Date: 20.03.15
 */
public final class MultiLabelTools {
  private MultiLabelTools() {
    // Utility class: static methods only, not meant to be instantiated.
  }

  /**
   * Converts a trained model into a {@link MultiLabelModel}.
   * <p>
   * A model that already is a {@link MultiLabelModel} is returned unchanged. An
   * {@link Ensemble} whose last element is a {@link FuncJoin} is passed to
   * {@code MCTools.joinBoostingResult} and the result is wrapped into a
   * {@link MultiLabelBinarizedModel}.
   *
   * @param model the trained model to convert
   * @return the multi-label view of {@code model}
   * @throws UnsupportedOperationException if {@code model} matches neither of the supported shapes
   *                                       (this includes {@code null})
   */
  public static MultiLabelModel extractMultiLabelModel(final Function model) {
    if (model instanceof MultiLabelModel) {
      return (MultiLabelModel) model;
    }
    if (model instanceof Ensemble && ((Ensemble) model).last() instanceof FuncJoin) {
      final FuncJoin joined = MCTools.joinBoostingResult((Ensemble) model);
      return new MultiLabelBinarizedModel(joined);
    }
    throw new UnsupportedOperationException("Can't extract multi-label model");
  }

  /**
   * Replays a {@link MultiLabelBinarizedModel} iteration by iteration through a
   * {@link MultiLabelLogitProgressPrinter}.
   * <p>
   * For every iteration {@code t} the {@code t}-th weak model of each per-label ensemble is
   * combined into one {@link FuncJoin}, appended to the list of weak models, and an
   * {@link Ensemble} over the list built so far is handed to the printer. If the intern model
   * holds no per-label ensembles, or the first ensemble is empty, there is nothing to replay
   * and the method returns without calling the printer.
   *
   * @param learn  learning pool passed to the progress printer
   * @param test   test pool passed to the progress printer
   * @param model  model to report on; must be a {@link MultiLabelBinarizedModel}
   * @param period period argument passed to the progress printer
   * @throws UnsupportedOperationException if {@code model} is not a {@link MultiLabelBinarizedModel}
   */
  public static void makeOVRReport(final Pool learn, final Pool test, final Function model, final int period) {
    if (!(model instanceof MultiLabelBinarizedModel)) {
      throw new UnsupportedOperationException("Can't extract intern FuncJoin model");
    }
    final FuncJoin internModel = ((MultiLabelBinarizedModel) model).getInternModel();

    final MultiLabelLogitProgressPrinter progressPrinter = new MultiLabelLogitProgressPrinter(learn, test, period);
    final FuncEnsemble[] perLabelModels = ArrayTools.map(internModel.dirs, FuncEnsemble.class, argument -> (FuncEnsemble) argument);
    final int labelsCount = perLabelModels.length;
    if (labelsCount == 0) {
      // No per-label ensembles: indexing perLabelModels[0] below would throw.
      return;
    }

    final int ensembleSize = perLabelModels[0].size();
    if (ensembleSize == 0) {
      // No iterations to replay; also avoids reading weights.get(0) of an empty ensemble.
      return;
    }

    // NOTE(review): the size and the first weight of the first label's ensemble are used for
    // every label - this assumes all per-label ensembles have the same size and a uniform
    // step; confirm against the code that builds them.
    final double step = perLabelModels[0].weights.get(0);
    final List<FuncJoin> weakModels = new ArrayList<>();
    for (int t = 0; t < ensembleSize; t++) {
      final Func[] functions = new Func[labelsCount];
      for (int c = 0; c < labelsCount; c++) {
        functions[c] = perLabelModels[c].models[t];
      }
      weakModels.add(new FuncJoin(functions));
      // The same growing list is passed on every iteration, so each reported ensemble
      // contains the weak models of iterations 0..t.
      progressPrinter.accept(new Ensemble<>(weakModels, step));
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy