All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.greedyRegion.cnfMergeOptimization.GreedyMergedRegion Maven / Gradle / Ivy

package com.expleague.ml.methods.greedyRegion.cnfMergeOptimization;

import com.expleague.commons.math.MathTools;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.BFGrid;
import com.expleague.ml.impl.BFRowImpl;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.CNF;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.Binarize;
import com.expleague.ml.methods.greedyMergeOptimization.GreedyMergePick;
import com.expleague.ml.methods.greedyMergeOptimization.RegularizedLoss;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;

import static java.lang.Math.log;

/**
 * User: noxoomo
 * Date: 30/11/14
 * Time: 15:31
 */
public class GreedyMergedRegion> extends VecOptimization.Stub {
  public static final double CARDINALITY_FACTOR = 2;
  protected final BFGrid grid;
  private final double lambda;

  public GreedyMergedRegion(final BFGrid grid, final double lambda) {
    this.lambda = lambda;
    this.grid = grid;
  }

  public GreedyMergedRegion(final BFGrid grid) {
    this(grid, 2);
  }

  @Override
  public CNF fit(final VecDataSet learn, final Loss loss) {
    final List clauses = new ArrayList<>(10);
    final CherryOptimizationSubsetMerger merger = new CherryOptimizationSubsetMerger(loss.statsFactory());
    final GreedyMergePick pick = new GreedyMergePick<>(merger);
    int[] points = loss instanceof WeightedLoss ? ((WeightedLoss) loss).points(): ArrayTools.sequence(0, learn.length());
    final BinarizedDataSet bds = learn.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
    CherryOptimizationSubset last = new CherryOptimizationSubset(bds, loss.statsFactory(), new CNF.Clause(grid), points, 0);
    final double totalPower = last.power();
    final RegularizedLoss regLoss = new RegularizedLoss() {
      @Override
      public double target(CherryOptimizationSubset subset) {
        return -loss.score(subset.stat);
      }

      @Override
      public double regularization(CherryOptimizationSubset subset) {
        final double cardinalityDiscount = CARDINALITY_FACTOR / (subset.cardinality() + CARDINALITY_FACTOR - 1);
        final double regularization = (1 + lambda * log(subset.power() + 1));
        return cardinalityDiscount * regularization;
      }

      @Override
      public double score(CherryOptimizationSubset subset) {
        final double score = loss.score(subset.stat);
        return score * regularization(subset);
      }
    };

    double score = regLoss.score(last);
    while (true) {
      final List models = init(bds, points, loss, last.cardinality());
      final CherryOptimizationSubset best = pick.pick(models, regLoss);

      System.out.print("\tClause " + clauses.size() + " score: " + regLoss.score(best) + " target: " + regLoss.target(best) + best.clause.toString());
      if (score - regLoss.score(best) < MathTools.EPSILON)
        break;
      System.out.println(" accepted");

      clauses.add(best.clause);
      points = best.inside();
      score = regLoss.score(best);
      last = best;
    }

    System.out.println(" rejected");

    System.out.println("Region weight: " + last.power() + " score: " + score + " target: " + regLoss.target(last));
    return new CNF(clauses.toArray(new CNF.Clause[clauses.size()]), loss.bestIncrement(last.stat), grid);
  }


  static ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Init CNF thread", -1);

  private List init(final BinarizedDataSet bds, final int[] points, final Loss loss, final double cardinality) {
    int binsTotal = 0;
    for (int feature = 0; feature < grid.rows(); ++feature)
      binsTotal += grid.row(feature).size() > 1 ? grid.row(feature).size() + 1 : 0;

    final List result = new ArrayList<>(binsTotal);
    final CountDownLatch latch = new CountDownLatch(binsTotal);
    for (int feature = 0; feature < grid.rows(); ++feature) {
      if (grid.row(feature).size() <= 1)
        continue;
      for (int bin = 0; bin <= grid.row(feature).size(); ++bin) {
        final BFGrid.Row row = grid.row(feature);
        final BitSet used = new BitSet(row.size() + 1);
        used.set(bin);
        exec.submit(() -> {
          final CNF.Condition[] conditions = new CNF.Condition[1];
          conditions[0] = new CNF.Condition(row, used);
          final CNF.Clause clause = new CNF.Clause(grid, conditions);
          final CherryOptimizationSubset subset = new CherryOptimizationSubset(bds, loss.statsFactory(), clause, points, cardinality);
          synchronized (result) {
            result.add(subset);
          }
          latch.countDown();
        });
      }
    }
    try {
      latch.await();
    } catch (InterruptedException e) {
      //skip
    }
    return result;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy