All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.greedyMergeOptimization.GreedyMergePick Maven / Gradle / Ivy

package com.expleague.ml.methods.greedyMergeOptimization;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.util.BestHolder;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.methods.greedyRegion.cnfMergeOptimization.CherryOptimizationSubset;

import java.text.NumberFormat;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Greedy merge search over a pool of candidate models.
 *
 * <p>Repeatedly takes the model with the highest {@code loss.score}, evaluates merging it with
 * every remaining model in parallel on {@link #exec}, and keeps only the merge with the largest
 * positive gain. Stops once a single model is left in the pool.
 *
 * Created by noxoomo on 30/11/14.
 */
public class GreedyMergePick {
  /** Shared executor on which candidate merges are evaluated. */
  static final ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Greedy merge pick thread", -1);
  private final MergeOptimization merger;

  public GreedyMergePick(final MergeOptimization merger) {
    this.merger = merger;
  }

  /**
   * Greedily merges the given models until only one remains and returns it.
   *
   * <p>Models are kept ordered by {@code loss.score} (ties broken by {@code index()}). Each step
   * removes the highest-ordered model and, if some merge of it with a remaining model has a
   * positive gain, inserts the best such merge back into the pool.
   *
   * @param startModels initial candidates; must not be empty
   * @param loss orders models via {@code score} and supplies the {@code target} values used to
   *     compute merge gain
   * @return the single model left once the pool has shrunk to one element
   * @throws IllegalArgumentException if {@code startModels} is empty
   */
  public Model pick(final List startModels, final RegularizedLoss loss) {
    if (startModels.isEmpty())
      throw new IllegalArgumentException("Models list must be not empty");

    final Comparator<Model> comparator = new Comparator<Model>() {
      @Override
      public int compare(final Model left, final Model right) {
        final int cmp = Double.compare(loss.score(left), loss.score(right));
        return cmp != 0 ? cmp : Integer.compare(left.index(), right.index());
      }
    };

    final TreeSet<Model> models = new TreeSet<>(comparator);
    models.addAll(startModels);
    while (models.size() > 1) {
      mergeStep(loss, models);
    }
    return models.first();
  }

  /**
   * Performs one greedy step on {@code models}.
   *
   * <p>Removes the highest-ordered model and tries to merge it with every remaining model on
   * {@link #exec}. If the best candidate has a positive gain, that merged model is added back.
   * Always waits for every submitted task to finish before reading the result.
   *
   * @param loss supplies the {@code target} values used to compute gain
   * @param models the current pool; the caller guarantees it holds more than one model
   */
  private void mergeStep(final RegularizedLoss loss, final TreeSet<Model> models) {
    // Same as taking descendingIterator().next() and removing it: the last element in comparator order.
    final Model current = models.pollLast();
    final double currentTarget = loss.target(current);

    final CountDownLatch latch = new CountDownLatch(models.size());
    // NOTE(review): update() is invoked concurrently from pool threads; this relies on
    // BestHolder.update being thread-safe -- confirm.
    final BestHolder<Model> bestHolder = new BestHolder<>();
    for (final Model model : models) {
      exec.submit(new Runnable() {
        @Override
        public void run() {
          try {
            final Model merged = merger.merge(current, model);
            // Only merges whose power() exceeds that of both inputs are candidates.
            if (merged.power() > model.power() && merged.power() > current.power()) {
              final double mergedTarget = loss.target(merged);
              final double modelTarget = loss.target(model);
              // merged.power() times the inputs' combined target per unit of combined power(),
              // minus the merged model's own target.
              final double gain =
                  merged.power() * ((modelTarget + currentTarget) / (model.power() + current.power()))
                      - mergedTarget;
              bestHolder.update(merged, gain);
            }
          }
          catch (Throwable th) {
            // Best effort: a failing merge is reported and skipped; the latch is still released.
            th.printStackTrace();
          }
          finally {
            latch.countDown();
          }
        }
      });
    }

    // Must not read bestHolder while any task could still be updating it.
    awaitUninterruptibly(latch);

    if (bestHolder.getScore() > 0) {
      models.add(bestHolder.getValue());
    }
  }

  /**
   * Blocks until {@code latch} reaches zero, ignoring interrupts while waiting, then re-asserts
   * the calling thread's interrupt flag if an interrupt was received.
   */
  private static void awaitUninterruptibly(final CountDownLatch latch) {
    boolean interrupted = false;
    while (true) {
      try {
        latch.await();
        break;
      }
      catch (final InterruptedException e) {
        interrupted = true;
      }
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }
}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy