// com.expleague.ml.methods.greedyRegion.cnfMergeOptimization.GreedyMergedRegion (Maven / Gradle / Ivy)
package com.expleague.ml.methods.greedyRegion.cnfMergeOptimization;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.BFGrid;
import com.expleague.ml.Binarize;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.impl.BFRowImpl;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.greedyMergeOptimization.GreedyMergePick;
import com.expleague.ml.methods.greedyMergeOptimization.RegularizedLoss;
import com.expleague.ml.models.CNF;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

import static java.lang.Math.log;
/**
* User: noxoomo
* Date: 30/11/14
* Time: 15:31
*/
public class GreedyMergedRegion> extends VecOptimization.Stub {
public static final double CARDINALITY_FACTOR = 2;
protected final BFGrid grid;
private final double lambda;
public GreedyMergedRegion(final BFGrid grid, final double lambda) {
this.lambda = lambda;
this.grid = grid;
}
public GreedyMergedRegion(final BFGrid grid) {
this(grid, 2);
}
@Override
public CNF fit(final VecDataSet learn, final Loss loss) {
final List clauses = new ArrayList<>(10);
final CherryOptimizationSubsetMerger merger = new CherryOptimizationSubsetMerger(loss.statsFactory());
final GreedyMergePick pick = new GreedyMergePick<>(merger);
int[] points = loss instanceof WeightedLoss ? ((WeightedLoss) loss).points(): ArrayTools.sequence(0, learn.length());
final BinarizedDataSet bds = learn.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
CherryOptimizationSubset last = new CherryOptimizationSubset(bds, loss.statsFactory(), new CNF.Clause(grid), points, 0);
final double totalPower = last.power();
final RegularizedLoss regLoss = new RegularizedLoss() {
@Override
public double target(CherryOptimizationSubset subset) {
return -loss.score(subset.stat);
}
@Override
public double regularization(CherryOptimizationSubset subset) {
final double cardinalityDiscount = CARDINALITY_FACTOR / (subset.cardinality() + CARDINALITY_FACTOR - 1);
final double regularization = (1 + lambda * log(subset.power() + 1));
return cardinalityDiscount * regularization;
}
@Override
public double score(CherryOptimizationSubset subset) {
final double score = loss.score(subset.stat);
return score * regularization(subset);
}
};
double score = regLoss.score(last);
while (true) {
final List models = init(bds, points, loss, last.cardinality());
final CherryOptimizationSubset best = pick.pick(models, regLoss);
System.out.print("\tClause " + clauses.size() + " score: " + regLoss.score(best) + " target: " + regLoss.target(best) + best.clause.toString());
if (score - regLoss.score(best) < MathTools.EPSILON)
break;
System.out.println(" accepted");
clauses.add(best.clause);
points = best.inside();
score = regLoss.score(best);
last = best;
}
System.out.println(" rejected");
System.out.println("Region weight: " + last.power() + " score: " + score + " target: " + regLoss.target(last));
return new CNF(clauses.toArray(new CNF.Clause[clauses.size()]), loss.bestIncrement(last.stat), grid);
}
static ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Init CNF thread", -1);
private List init(final BinarizedDataSet bds, final int[] points, final Loss loss, final double cardinality) {
int binsTotal = 0;
for (int feature = 0; feature < grid.rows(); ++feature)
binsTotal += grid.row(feature).size() > 1 ? grid.row(feature).size() + 1 : 0;
final List result = new ArrayList<>(binsTotal);
final CountDownLatch latch = new CountDownLatch(binsTotal);
for (int feature = 0; feature < grid.rows(); ++feature) {
if (grid.row(feature).size() <= 1)
continue;
for (int bin = 0; bin <= grid.row(feature).size(); ++bin) {
final BFGrid.Row row = grid.row(feature);
final BitSet used = new BitSet(row.size() + 1);
used.set(bin);
exec.submit(() -> {
final CNF.Condition[] conditions = new CNF.Condition[1];
conditions[0] = new CNF.Condition(row, used);
final CNF.Clause clause = new CNF.Clause(grid, conditions);
final CherryOptimizationSubset subset = new CherryOptimizationSubset(bds, loss.statsFactory(), clause, points, cardinality);
synchronized (result) {
result.add(subset);
}
latch.countDown();
});
}
}
try {
latch.await();
} catch (InterruptedException e) {
//skip
}
return result;
}
}