All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.multilabel.MultiLabelSubsetsMulticlass Maven / Gradle / Ivy

package com.expleague.ml.methods.multilabel;

import com.expleague.commons.func.Action;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.func.WeakListenerHolder;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxBuilder;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.MxByRowsBuilder;
import com.expleague.commons.math.vectors.impl.vectors.VecBuilder;
import com.expleague.commons.math.Trans;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.loss.multilabel.ClassicMultiLabelLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.MultiClassModel;
import com.expleague.ml.models.multiclass.MCModel;
import com.expleague.ml.models.multilabel.MultiLabelSubsetsModel;
import gnu.trove.list.TIntList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TObjectIntProcedure;

import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;

/**
 * User: qdeee
 * Date: 23.03.15
 */
public class MultiLabelSubsetsMulticlass extends WeakListenerHolderImpl implements VecOptimization {
  private final VecOptimization weak;
  private final int minExamplesCount;

  public MultiLabelSubsetsMulticlass(final VecOptimization weak, final int minExamplesCount) {
    this.weak = weak;
    this.minExamplesCount = minExamplesCount;
  }

  @Override
  public MultiLabelSubsetsModel fit(final VecDataSet learn, final ClassicMultiLabelLoss multiLabelLoss) {
    final Mx targets = multiLabelLoss.getTargets();

    //build two mappings: uniq_labels(vec) -> classNum(int) and classNum(int) -> uniq_labels(vec)
    final TObjectIntMap vec2class = new TObjectIntHashMap<>();
    final Vec newTarget = new ArrayVec(targets.rows());
    for (int i = 0; i < targets.rows(); i++) {
      final Vec row = targets.row(i);
      final int classNumber = vec2class.adjustOrPutValue(row, 0, vec2class.size());
      newTarget.set(i, classNumber);
    }
    final Vec[] class2vec = new Vec[vec2class.size()];
    vec2class.forEachEntry(new TObjectIntProcedure() {
      @Override
      public boolean execute(final Vec labels, final int classNumber) {
        class2vec[classNumber] = labels;
        return true;
      }
    });


    //filter rare labels combinations
    final VecBuilder targetBuilder = new VecBuilder();
    final MxBuilder mxBuilder = new MxByRowsBuilder();
    final List filteredClass2Vec = new ArrayList<>();
    final TIntObjectMap classesIdxs = MCTools.splitClassesIdxs(VecTools.toIntSeq(newTarget));
    for (int clazz = 0, normalizedClass = 0; clazz < classesIdxs.size(); clazz++) {
      final TIntList indexes = classesIdxs.get(clazz);
      if (indexes.size() > minExamplesCount) {
        for (int i = 0; i < indexes.size(); i++) {
          targetBuilder.append(normalizedClass);
          mxBuilder.add(learn.at(i));
        }
        filteredClass2Vec.add(class2vec[clazz]);
        normalizedClass++;
      }
    }


    //fit model
    final VecDataSet filteredDs = new VecDataSetImpl(mxBuilder.build(), learn);
    final BlockwiseMLLLogit mllLogit = new BlockwiseMLLLogit(targetBuilder.build(), learn);

    //dirty hack for proxy listener
    List tmp = new ArrayList<>();
    final Vec[] filteredClass2VecArr = filteredClass2Vec.toArray(new Vec[filteredClass2Vec.size()]);
    if (weak instanceof WeakListenerHolder) {
      for (final WeakReference> listener : listeners) {
        final Action multiLabelAction = listener.get();
        final WeakListenerHolder weakListenerHolder = (WeakListenerHolder) weak;
        final Action weakAction = new Action() {
          @Override
          public void invoke(final Ensemble ensemble) {
            final FuncJoin join = MCTools.joinBoostingResult(ensemble);
            multiLabelAction.invoke(new MultiLabelSubsetsModel(new MultiClassModel(join), filteredClass2VecArr));
          }
        };
        tmp.add(weakAction);
        weakListenerHolder.addListener(weakAction);
      }
    }
    final Trans fitted = weak.fit(filteredDs, mllLogit);
    return new MultiLabelSubsetsModel(createMCModel(fitted), filteredClass2VecArr);
  }

  private static MCModel createMCModel(final Trans fitted) {
    if (fitted instanceof Ensemble && ((Ensemble) fitted).last() instanceof FuncJoin) {
      final FuncJoin funcJoin = MCTools.joinBoostingResult((Ensemble) fitted);
      return new MultiClassModel(funcJoin);
    } else if (fitted instanceof MCModel) {
      return  (MCModel) fitted;
    } else {
      throw new IllegalStateException("Can't convert fitted model to MCModel");
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy