All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.multiclass.spoc.impl.CodingMatrixLearningGreedyParallels Maven / Gradle / Ivy

package com.expleague.ml.methods.multiclass.spoc.impl;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.Combinatorics;
import com.expleague.commons.util.Pair;
import com.expleague.ml.methods.multiclass.spoc.CMLHelper;

import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.*;

/**
 * User: qdeee
 * Date: 05.06.14
 */
public class CodingMatrixLearningGreedyParallels extends CodingMatrixLearningGreedy {
  private final ThreadPoolExecutor executor;
  private final int units;

  public CodingMatrixLearningGreedyParallels(final int k, final int l, final double lambdaC, final double lambdaR, final double lambda1) {
    super(k, l, lambdaC, lambdaR, lambda1);
    units = Runtime.getRuntime().availableProcessors();
    executor = new ThreadPoolExecutor(units, units, 5, TimeUnit.DAYS, new LinkedBlockingDeque());
  }

  private class ColumnSearch implements Callable> {
    final Mx mxB;
    final Mx S;
    final long start;
    final long count;

    private ColumnSearch(final Mx S, final Mx mxB, final long start, final long count) {
      this.mxB = mxB;
      this.start = start;
      this.count = count;
      this.S = S;
    }

    @Override
    public Pair call() throws Exception {
      final Combinatorics.Enumeration generator = new Combinatorics.PartialPermutations(2, mxB.rows());
      generator.skipN(start);

      double minLoss = Double.MAX_VALUE;
      int[] bestPerm = null;
      int pos = 0;
      while (pos++ < count && generator.hasNext()) {
        final int[] perm = generator.next();
        for (int i = 0; i < mxB.rows(); i++) {
          mxB.set(i, mxB.columns() - 1, 2 * perm[i] - 1);  //0 -> -1, 1 -> 1
        }
        if (CMLHelper.checkConstraints(mxB) && CMLHelper.checkColumnsIndependence(mxB)) {
          final double loss = calcLoss(mxB, S);
          if (loss < minLoss) {
            minLoss = loss;
            bestPerm = perm;
          }
        }
      }
      if (bestPerm == null) {
        throw new IllegalStateException("Not found appreciate column #" + (mxB.columns() - 1));
      }
      return Pair.create(minLoss, bestPerm);
    }
  }

  @Override
  public Mx findMatrixB(final Mx S) {
    final long partition = (long)(Math.pow(2, k) + units - 1) / units;
    final Mx mxB = new VecBasedMx(k, l);
    for (int j = 0; j < l; j++) {
      final List>> tasks = new LinkedList>>();
      for (int u = 0; u < units; u++) {
        final Mx mxBCopy = VecTools.copy(mxB.sub(0, 0, k, j + 1));
        final long start = u * partition;
        tasks.add(new ColumnSearch(S, mxBCopy, start, partition));
      }
      try {
        final List>> futures = executor.invokeAll(tasks);
        double totalMinLoss = Double.MAX_VALUE;
        int[] totalBestPerm = null;
        for (final Future> future : futures) {
          final Pair pair = future.get();
          final Double loss = pair.first;
          final int[] perm = pair.second;

          if (loss < totalMinLoss) {
            totalMinLoss = loss;
            totalBestPerm = perm;
          }
        }
        for (int i = 0; i < totalBestPerm.length; i++) {
          mxB.set(i, j, 2 * totalBestPerm[i] - 1);
        }
      } catch (InterruptedException e) {
        e.printStackTrace(); //who cares?
      } catch (ExecutionException e) {

      }
      System.out.println("Column " + j + " is over!");
    }
    return mxB;
  }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy