All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.multilabel.MultiLabelConflictMulticlass Maven / Gradle / Ivy

package com.expleague.ml.methods.multilabel;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.VecIterator;
import com.expleague.commons.math.vectors.impl.mx.MxByRowsBuilder;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.models.multiclass.MCModel;
import com.expleague.commons.math.vectors.MxBuilder;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.VecBuilder;
import com.expleague.commons.util.Pair;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.loss.multilabel.ClassicMultiLabelLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multilabel.ConflictThresholdMultiLabelModel;

/**
 * Multi-label learner that delegates to a multiclass learner.
 *
 * <p>The training set is rewritten before it is handed to the wrapped optimizer: every
 * instance is emitted once per positive label, each copy carrying that label's index as a
 * single multiclass target (the copies of one instance therefore "conflict" with each other).
 * The multiclass model returned by the wrapped optimizer is wrapped into a
 * {@link ConflictThresholdMultiLabelModel} together with the configured threshold and the
 * all-zeroes flag.
 *
 * User: qdeee
 * Date: 22.03.15
 */
public class MultiLabelConflictMulticlass implements VecOptimization<ClassicMultiLabelLoss> {
  private final VecOptimization<BlockwiseMLLLogit> weakMultiClass;
  private final double threshold;
  private final boolean allZeroesClassEnabled;

  /**
   * @param weakMultiClass        multiclass learner trained on the rewritten data set; the result of its
   *                              {@code fit} is cast to {@link MCModel}, so it must produce one
   * @param threshold             value passed through unchanged to {@link ConflictThresholdMultiLabelModel}
   * @param allZeroesClassEnabled if {@code true}, instances without any positive label are kept and assigned
   *                              an extra class whose index equals the number of labels; if {@code false},
   *                              such instances are left out of the training set
   */
  public MultiLabelConflictMulticlass(final VecOptimization<BlockwiseMLLLogit> weakMultiClass, final double threshold, final boolean allZeroesClassEnabled) {
    this.weakMultiClass = weakMultiClass;
    this.threshold = threshold;
    this.allZeroesClassEnabled = allZeroesClassEnabled;
  }

  /**
   * Rewrites the multi-label problem into a multiclass one, trains the wrapped optimizer on it
   * and wraps the resulting model.
   *
   * @param learn                 data set whose {@code data()} matrix holds one feature row per instance
   * @param classicMultiLabelLoss loss whose {@code getTargets()} matrix holds one label row per instance
   * @return multi-label model built on top of the trained multiclass model
   * @throws ClassCastException if the wrapped optimizer does not return an {@link MCModel}
   */
  @Override
  public ConflictThresholdMultiLabelModel fit(final VecDataSet learn, final ClassicMultiLabelLoss classicMultiLabelLoss) {
    final Mx sourceFeatures = learn.data();
    final Mx sourceTargets = classicMultiLabelLoss.getTargets();
    final Pair<Vec, Mx> conflictData = createConflictData(sourceTargets, sourceFeatures, allZeroesClassEnabled);

    final Mx featuresWithDuplicate = conflictData.getSecond();
    final Vec conflictTarget = conflictData.getFirst();

    // Both the derived data set and the derived loss are created without an owner/parent (null).
    final VecDataSet ds = new VecDataSetImpl(featuresWithDuplicate, null);
    final BlockwiseMLLLogit mllLogit = new BlockwiseMLLLogit(conflictTarget, null);

    final MCModel mcModel = (MCModel) weakMultiClass.fit(ds, mllLogit);
    return new ConflictThresholdMultiLabelModel(mcModel, threshold, allZeroesClassEnabled);
  }

  /**
   * Builds the multiclass training set out of multi-label data.
   *
   * <p>For every row {@code i} of {@code features} the entries of {@code targets.row(i)} reported by
   * {@code nonZeroes()} are scanned; each entry with a value greater than zero contributes one
   * output row consisting of the instance features and the entry's index as class label. An
   * instance that contributed nothing is either added once with class label
   * {@code targets.row(i).dim()} (when {@code allZeroesClassEnabled} is set) or omitted.
   *
   * @param targets               label matrix, read row by row in step with {@code features}
   * @param features              feature matrix; its row count drives the iteration
   * @param allZeroesClassEnabled whether instances without positive labels get the extra class
   * @return pair of (class label per output row, feature matrix with duplicated rows)
   */
  private static Pair<Vec, Mx> createConflictData(final Mx targets, final Mx features, final boolean allZeroesClassEnabled) {
    final MxBuilder mxBuilder = new MxByRowsBuilder();
    final VecBuilder vecBuilder = new VecBuilder();
    for (int i = 0; i < features.rows(); i++) {
      final Vec instanceTargets = targets.row(i);
      final Vec instanceFeatures = features.row(i);

      final VecIterator targetIter = instanceTargets.nonZeroes();
      boolean allZeroesTarget = true;
      while (targetIter.advance()) {
        final int targetIndex = targetIter.index();
        final double targetValue = targetIter.value();

        // Only strictly positive entries count as labels; negative non-zero entries are ignored.
        if (targetValue > 0) {
          allZeroesTarget = false;
          vecBuilder.append(targetIndex);
          mxBuilder.add(instanceFeatures);
        }
      }
      if (allZeroesTarget && allZeroesClassEnabled) {
        // The extra "no labels" class takes the first index past the real labels.
        vecBuilder.append(instanceTargets.dim());
        mxBuilder.add(instanceFeatures);
      }
    }

    // Bind to interface-typed locals first so that Pair.create infers exactly Pair<Vec, Mx>.
    final Vec conflictTarget = vecBuilder.build();
    final Mx featuresWithDuplicate = mxBuilder.build();
    return Pair.create(conflictTarget, featuresWithDuplicate);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy