All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.multiclass.spoc.ECOCCombo Maven / Gradle / Ivy

package com.expleague.ml.methods.multiclass.spoc;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.IndexTransVec;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.idxtrans.RowsPermutation;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.Combinatorics;
import com.expleague.commons.util.Pair;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.data.tools.MCTools;
import com.expleague.ml.loss.LLLogit;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multiclass.MulticlassCodingMatrixModel;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.TIntList;
import gnu.trove.list.linked.TDoubleLinkedList;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.map.TIntObjectMap;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.*;

/**
 * User: qdeee
 * Date: 14.11.14
 */
public class ECOCCombo extends WeakListenerHolderImpl implements VecOptimization {
  private static final int UNITS = Runtime.getRuntime().availableProcessors();
  private static final double MX_IGNORE_THRESHOLD = 0.1;

  private final ExecutorService executor;

  private final int k;
  private final int l;
  private final double lambdaC;
  private final double lambdaR;
  private final double lambda1;
  private final Mx S;

  private final VecOptimization weak;

  public ECOCCombo(
      final int k,
      final int l,
      final double lambdaC,
      final double lambdaR,
      final double lambda1,
      final @NotNull Mx s,
      final @NotNull VecOptimization weak
  ) {
    this.executor = Executors.newFixedThreadPool(UNITS);
    this.k = k;
    this.l = l;
    this.lambdaC = lambdaC;
    this.lambdaR = lambdaR;
    this.lambda1 = lambda1;
    this.S = s;
    this.weak = weak;
  }

  protected double calcLoss(final Mx B, final Mx S) {
    double result = 0;
    final Mx mult = MxTools.multiply(B, MxTools.transpose(B));
    result -= MxTools.trace(MxTools.multiply(mult, S));
    result += lambdaR * VecTools.sum(mult);
    result += lambdaC * VecTools.sum2(B);
    result += lambda1 * VecTools.l1(B);
    return result;
  }

  @Override
  public MulticlassCodingMatrixModel fit(final VecDataSet learn, final BlockwiseMLLLogit mllLogit) {
    final long permutationsForOneProcessor = (long)(Math.pow(2, k) + UNITS - 1) / UNITS;
    final Mx mxB = new VecBasedMx(k, l);
    final List classifiers = new ArrayList<>(l);
    final TIntObjectMap indexes = MCTools.splitClassesIdxs(mllLogit.labels());

    for (int j = 0; j < l; j++) {
      //find column
      final List>> tasks = new LinkedList>>();
      for (int u = 0; u < UNITS; u++) {
        final Mx mxBCopy = VecTools.copy(mxB.sub(0, 0, k, j + 1));
        final long start = u * permutationsForOneProcessor;
        tasks.add(new ColumnSearch(S, mxBCopy, start, permutationsForOneProcessor));
      }
      try {
        final List>> futures = executor.invokeAll(tasks);
        double totalMinLoss = Double.MAX_VALUE;
        int[] totalBestPerm = null;
        for (final Future> future : futures) {
          final Pair pair = future.get();
          final Double loss = pair.first;
          final int[] perm = pair.second;

          if (loss < totalMinLoss) {
            totalMinLoss = loss;
            totalBestPerm = perm;
          }
        }
        for (int i = 0; i < totalBestPerm.length; i++) {
          mxB.set(i, j, 2 * totalBestPerm[i] - 1);
        }
      } catch (InterruptedException | ExecutionException e) {
        e.printStackTrace(); //who cares?
      }

      //fit column classifier
      final TIntList learnIdxs = new TIntLinkedList();
      final TDoubleList target = new TDoubleLinkedList();
      for (int i = 0; i < k; i++) {
        final double code = mxB.get(i, j);
        if (Math.abs(code) > MX_IGNORE_THRESHOLD) {
          final TIntList classIdxs = indexes.get(i);
          target.fill(target.size(), target.size() + classIdxs.size(), Math.signum(code));
          learnIdxs.addAll(classIdxs);
        }
      }

      final VecDataSet dataSet = new VecDataSetImpl(
          new VecBasedMx(
              learn.xdim(),
              new IndexTransVec(
                  learn.data(),
                  new RowsPermutation(learnIdxs.toArray(), learn.xdim())
              )
          ), learn
      );
      final LLLogit loss = new LLLogit(new ArrayVec(target.toArray()), learn);
      classifiers.add((Func) weak.fit(dataSet, loss));

      invoke(
          new MulticlassCodingMatrixModel(
            mxB.sub(0, 0, k, j + 1),
            classifiers.toArray(new Func[j + 1]),
            MX_IGNORE_THRESHOLD)
      );
    }
    executor.shutdown();

    return new MulticlassCodingMatrixModel(mxB, classifiers.toArray(new Func[classifiers.size()]), MX_IGNORE_THRESHOLD);
  }

  private class ColumnSearch implements Callable> {
    final Mx mxB;
    final Mx S;
    final long start;
    final long count;

    private ColumnSearch(final Mx S, final Mx mxB, final long start, final long count) {
      this.mxB = mxB;
      this.start = start;
      this.count = count;
      this.S = S;
    }

    @Override
    public Pair call() throws Exception {
      final Combinatorics.Enumeration generator = new Combinatorics.PartialPermutations(2, mxB.rows());
      generator.skipN(start);

      double minLoss = Double.MAX_VALUE;
      int[] bestPerm = null;
      int pos = 0;
      while (pos++ < count && generator.hasNext()) {
        final int[] perm = generator.next();
        for (int i = 0; i < mxB.rows(); i++) {
          mxB.set(i, mxB.columns() - 1, 2 * perm[i] - 1);  //0 -> -1, 1 -> 1
        }
        if (CMLHelper.checkConstraints(mxB) && CMLHelper.checkColumnsIndependence(mxB)) {
          final double loss = calcLoss(mxB, S);
          if (loss < minLoss) {
            minLoss = loss;
            bestPerm = perm;
          }
        }
      }
      if (bestPerm == null) {
        throw new IllegalStateException("Not found appreciate column #" + (mxB.columns() - 1));
      }
      return Pair.create(minLoss, bestPerm);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy