All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.multiclass.spoc.SPOCMethodClassic Maven / Gradle / Ivy

package com.expleague.ml.methods.multiclass.spoc;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.idxtrans.RowsPermutation;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.IndexTransVec;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multiclass.MulticlassCodingMatrixModel;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.TIntList;
import gnu.trove.list.linked.TDoubleLinkedList;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.map.TIntObjectMap;

/**
 * Multiclass learner that trains one binary classifier per column of a coding matrix.
 *
 * <p>Rows of the coding matrix are addressed by class index and columns by binary problem.
 * For a given column, every class whose entry exceeds {@link #MX_IGNORE_THRESHOLD} in absolute
 * value contributes its samples to that column's binary problem, labelled with the sign of the
 * entry; classes with smaller entries are left out of that problem.
 *
 * User: qdeee
 * Date: 07.05.14
 */
public class SPOCMethodClassic extends VecOptimization.Stub<BlockwiseMLLLogit> {
  /** Coding-matrix entries whose absolute value does not exceed this are treated as "ignore". */
  protected static final double MX_IGNORE_THRESHOLD = 0.1;

  /**
   * Learner used to fit each binary problem.
   * NOTE(review): declared raw here; it is always called with an {@link LLLogit} loss, so it is
   * presumably a {@code VecOptimization<LLLogit>} - confirm before tightening, since that would
   * change the constructor signature.
   */
  protected final VecOptimization weak;

  /** Private, normalized copy of the coding matrix supplied to the constructor. */
  protected final Mx codeMatrix;

  /**
   * Creates the method from a coding matrix and a binary learner.
   *
   * @param codeMatrix coding matrix; it is copied, and the copy is passed to
   *                   {@code CMLHelper.normalizeMx} together with {@link #MX_IGNORE_THRESHOLD},
   *                   so the caller's matrix is not modified by this constructor
   * @param weak       learner invoked once per coding-matrix column
   */
  public SPOCMethodClassic(final Mx codeMatrix, final VecOptimization weak) {
    this.weak = weak;
    this.codeMatrix = VecTools.copy(codeMatrix);
    CMLHelper.normalizeMx(this.codeMatrix, MX_IGNORE_THRESHOLD);
  }

  /**
   * Fits one binary classifier per coding-matrix column and wraps them into a multiclass model.
   *
   * @param learn   training set; rows are selected from it by sample index
   * @param llLogit multiclass loss whose {@code labels()} are split into per-class sample indices
   * @return the model produced by {@link #createModel}
   */
  @Override
  public MulticlassCodingMatrixModel fit(final VecDataSet learn, final BlockwiseMLLLogit llLogit) {
    final TIntObjectMap<TIntList> indexes = MCTools.splitClassesIdxs(llLogit.labels());
    final int k = codeMatrix.rows();
    final int l = codeMatrix.columns();
    final Func[] binClassifiers = new Func[l];
    for (int j = 0; j < l; j++) {
      // Sample indices and their +1/-1 targets are appended in lock-step so that
      // position p in learnIdxs always corresponds to position p in target.
      final TIntList learnIdxs = new TIntLinkedList();
      final TDoubleList target = new TDoubleLinkedList();
      for (int i = 0; i < k; i++) {
        final double code = codeMatrix.get(i, j);
        if (Math.abs(code) > MX_IGNORE_THRESHOLD) {
          final TIntList classIdxs = indexes.get(i);
          if (classIdxs == null) {
            // The map has no entry for this class (no samples carry this label),
            // so it contributes nothing to the binary problem.
            continue;
          }
          final double label = Math.signum(code);
          final int count = classIdxs.size();
          for (int n = 0; n < count; n++) {
            target.add(label);
          }
          learnIdxs.addAll(classIdxs);
        }
      }

      // View over the selected rows of the original data, in learnIdxs order.
      final VecDataSet dataSet = new VecDataSetImpl(
          new VecBasedMx(
              learn.xdim(),
              new IndexTransVec(
                  learn.data(),
                  new RowsPermutation(learnIdxs.toArray(), learn.xdim())
              )
          ), learn);
      final LLLogit loss = new LLLogit(new ArrayVec(target.toArray()), learn);
      // The weak learner's result is used as a scalar function; the cast fails at runtime
      // if the learner returns anything that is not a Func.
      binClassifiers[j] = (Func) weak.fit(dataSet, loss);
    }
    return createModel(binClassifiers, learn, llLogit);
  }

  /**
   * Builds the resulting model from the trained binary classifiers.
   *
   * <p>This implementation uses only {@code binClass}; the remaining parameters are not read
   * here (presumably they exist for overriding subclasses - confirm against subclasses).
   *
   * @param binClass trained binary classifiers, one per coding-matrix column
   * @param learnDS  training set that was passed to {@link #fit}
   * @param llLogit  loss that was passed to {@link #fit}
   * @return a model over the normalized coding matrix, the classifiers and
   *         {@link #MX_IGNORE_THRESHOLD}
   */
  protected MulticlassCodingMatrixModel createModel(final Func[] binClass, final VecDataSet learnDS, final BlockwiseMLLLogit llLogit) {
    return new MulticlassCodingMatrixModel(codeMatrix, binClass, MX_IGNORE_THRESHOLD);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy