com.expleague.ml.cli.modes.impl.InterpretModel Maven / Gradle / Ivy
package com.expleague.ml.cli.modes.impl;
import com.expleague.commons.math.MathTools;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.BFGrid;
import com.expleague.ml.impl.BFRowImpl;
import com.expleague.ml.impl.BinaryFeatureImpl;
import com.expleague.ml.meta.PoolFeatureMeta;
import com.expleague.commons.io.StreamTools;
import com.expleague.commons.math.Trans;
import com.expleague.commons.util.ArrayTools;
import com.expleague.ml.Binarize;
import com.expleague.ml.cli.builders.methods.grid.GridBuilder;
import com.expleague.ml.cli.modes.AbstractMode;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.io.ModelsSerializationRepository;
import com.expleague.ml.models.ModelTools;
import com.expleague.ml.models.ObliviousTree;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.MissingArgumentException;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.function.Function;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import static com.expleague.ml.cli.JMLLCLI.*;
/**
* User: solar
* Date: 16.05.17
*/
public class InterpretModel extends AbstractMode {
public void run(final CommandLine command) throws MissingArgumentException, IOException {
if (!command.hasOption(MODEL_OPTION))
throw new MissingArgumentException("Please provide 'MODEL_OPTION'");
if (!command.hasOption(GRID_OPTION))
throw new MissingArgumentException("Please provide 'GRID_OPTION'");
if (!command.hasOption(LEARN_OPTION))
throw new MissingArgumentException("Please provide 'LEARN_OPTION'");
final Pool> pool;
if (command.hasOption(JSON_FORMAT))
pool = DataTools.loadFromFile(command.getOptionValue(LEARN_OPTION));
else
pool = DataTools.loadFromFeaturesTxt(command.getOptionValue(LEARN_OPTION));
final BFGrid grid = BFGrid.CONVERTER.convertFrom(StreamTools.readFile(new File(command.getOptionValue(GRID_OPTION))));
boolean splits = false;
int topSplits = 100;
boolean histogram = false;
boolean mhistogram = false;
boolean linear = false;
final TIntArrayList histogramPath = new TIntArrayList();
final TIntArrayList mhistogramPath = new TIntArrayList();
if (command.hasOption(INTERPRET_MODE_OPTION)) {
final String value = command.getOptionValue(INTERPRET_MODE_OPTION);
final String[] split = value.split("/,/");
for (final String opt: split) {
if (opt.startsWith("splits")) {
splits = true;
if (opt.length() > "splits()".length())
topSplits = Integer.parseInt(opt.substring("splits(".length(), opt.length() - 1));
}
else if (opt.startsWith("histogram")) {
histogram = true;
if (opt.length() > "histogram()".length()) {
final String features = opt.substring("histogram(".length(), opt.length() - 1);
for (final String feature : features.split("/,/")) {
for (int f = 0; f < grid.rows(); f++) {
final BFGrid.Row row = grid.row(f);
final String fname = pool.features()[row.findex()].id();
if (feature.startsWith(fname)) {
final int bin = Integer.parseInt(feature.substring(fname.length() + 1, feature.length() - 1));
histogramPath.add(row.bf(bin).index());
break;
}
}
}
}
}
else if (opt.startsWith("mhistogram")) {
mhistogram = true;
if (opt.length() > "mhistogram()".length()) {
final String features = opt.substring("histogram(".length(), opt.length() - 1);
for (final String feature : features.split("/,/")) {
for (int f = 0; f < grid.rows(); f++) {
final BFGrid.Row row = grid.row(f);
final String fname = pool.features()[row.findex()].id();
if (feature.startsWith(fname)) {
final int bin = Integer.parseInt(feature.substring(fname.length() + 1, feature.length() - 1));
mhistogramPath.add(row.bf(bin).index());
break;
}
}
}
}
}
else if (opt.equals("linear")) {
linear = true;
}
}
}
final ModelsSerializationRepository serializationRepository;
final GridBuilder gridBuilder = new GridBuilder();
gridBuilder.setGrid(grid);
serializationRepository = new ModelsSerializationRepository(gridBuilder.create());
try {
final Function model = DataTools.readModel(command.getOptionValue(MODEL_OPTION), serializationRepository);
if (!(model instanceof Ensemble))
throw new IllegalArgumentException("Provided model is not ensemble");
final Ensemble ensemble = (Ensemble) model;
if (ensemble.size() == 0 )
throw new IllegalArgumentException("Provided ensemble is empty");
final ArrayList trees = new ArrayList<>();
for(final Trans component: ensemble.models) {
if (!(component instanceof ObliviousTree))
throw new IllegalArgumentException("This component type is not supported: " + component.getClass());
trees.add((ObliviousTree) component);
}
final Ensemble otEnsamble = new Ensemble<>(trees.toArray(new ObliviousTree[trees.size()]), ensemble.weights);
@SuppressWarnings("unchecked")
final ModelTools.CompiledOTEnsemble compile = ModelTools.compile(otEnsamble);
final List entries = new ArrayList<>(compile.getEntries());
TObjectIntMap entryCount = new TObjectIntHashMap<>();
{
final VecDataSet vds = pool.vecData();
final BinarizedDataSet bds = vds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
for (ModelTools.CompiledOTEnsemble.Entry entry : entries) {
int weight = 0;
for (int i = 0; i < vds.length(); i++) {
final int[] bfIndices = entry.getBfIndices();
final int length = bfIndices.length;
boolean fit = true;
for (int j = 0; j < length; j++) {
if (!grid.bf(bfIndices[j]).value(i, bds))
fit = false;
}
if (fit)
weight++;
}
entryCount.put(entry, weight);
}
}
entries.sort((a, b) -> Double.compare(Math.abs(b.getValue() * entryCount.get(b)), Math.abs(a.getValue() * entryCount.get(a))));
final int[] vfeatures;
{
final TIntHashSet valuableFeaturesSet = new TIntHashSet();
entries.stream().flatMapToInt(s -> Arrays.stream(s.getBfIndices())).forEach(valuableFeaturesSet::add);
vfeatures = valuableFeaturesSet.toArray();
}
if (splits)
topSplits(pool, grid, entries, vfeatures, topSplits);
if (histogram)
histograms(pool, grid, entries, histogramPath);
if (mhistogram)
mhistograms(pool, grid, entries, mhistogramPath);
if (linear || !(splits || histogram || mhistogram))
linearComponents(pool, grid, entries, entryCount);
}
catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
private void linearComponents(Pool> pool, BFGrid grid, List entries, TObjectIntMap entryCount) {
for (final ModelTools.CompiledOTEnsemble.Entry entry : entries) {
final StringBuilder builder = new StringBuilder();
builder.append(entryCount.get(entry));
builder.append("\t");
builder.append(entry.getValue());
final int[] bfIndices = entry.getBfIndices();
builder.append("\t");
for (int i = 0; i < bfIndices.length; i++) {
if (i > 0)
builder.append(", ");
final BFGrid.Feature binaryFeature = grid.bf(bfIndices[i]);
builder.append(pool.features()[binaryFeature.findex()].id()).append(" > ").append(ftoa(binaryFeature.condition()));
}
System.out.println(builder.toString());
}
}
private void histograms(Pool> pool, BFGrid grid, List entries, TIntArrayList histogramPath) {
final VecDataSet vds = pool.vecData();
final BinarizedDataSet bds = vds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
for (int i = 0; i < grid.rows(); i++) {
final BFGrid.Row row = grid.row(i);
final PoolFeatureMeta meta = pool.features()[row.findex()];
System.out.print(meta.id());
double total = 0;
final int[] path = histogramPath.toArray();
for (int bin = 0; bin < row.size(); bin++) {
final BFGrid.Feature binaryFeature = row.bf(bin);
final List vfEntries =
entries.parallelStream()
.filter(e -> ArrayTools.supset(e.getBfIndices(), path))
.filter(e -> ArrayTools.indexOf(binaryFeature.index(), e.getBfIndices()) >= 0)
.collect(Collectors.toList());
final double weight = expectedWeight(grid, pool.vecData(), bds, vfEntries);
total += weight;
if (Math.abs(weight) > MathTools.EPSILON)
System.out.print(String.format("\t%d:%.3g:%.4g", bin, row.condition(bin), total));
}
System.out.println();
}
}
private void mhistograms(Pool> pool, BFGrid grid, List entries, TIntArrayList histogramPath) {
final VecDataSet vds = pool.vecData();
final BinarizedDataSet bds = vds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
for (int i = 0; i < grid.rows(); i++) {
final BFGrid.Row row = grid.row(i);
final PoolFeatureMeta meta = pool.features()[row.findex()];
System.out.print(meta.id());
final int[] path = histogramPath.toArray();
for (int bin = 0; bin < row.size(); bin++) {
final BFGrid.Feature binaryFeature = row.bf(bin);
final List vfEntries =
entries.parallelStream()
.filter(e -> ArrayTools.supset(e.getBfIndices(), path))
.filter(e -> ArrayTools.indexOf(binaryFeature.index(), e.getBfIndices()) >= 0)
.collect(Collectors.toList());
final double weight = maxWeight(grid, pool.vecData(), bds, vfEntries);
if (Math.abs(weight) > MathTools.EPSILON)
System.out.print(String.format("\t%d:%.3g:%.4g", bin, row.condition(bin), weight));
}
System.out.println();
}
}
private void topSplits(Pool> pool, BFGrid grid, List entries, int[] vfeatures, int topSplits) {
final VecDataSet vds = pool.vecData();
final BinarizedDataSet bds = vds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
final TObjectDoubleHashMap weights = new TObjectDoubleHashMap<>();
final List splitQueue = new ArrayList<>();
final List split = new ArrayList<>();
Arrays.stream(vfeatures).mapToObj(vf -> new int[]{vf}).forEach(task -> {
weights.put(task, 100500);
splitQueue.add(task);
});
for (int i = 0; i < topSplits + vfeatures.length; i++) {
splitQueue.sort(Comparator.comparingDouble(a -> Math.abs(weights.get(a))));
final int[] vfset = splitQueue.remove(splitQueue.size() - 1);
final List vfEntries =
entries.parallelStream()
.filter(e -> ArrayTools.supset(e.getBfIndices(), vfset))
.collect(Collectors.toList());
final double value = expectedWeight(grid, vds, bds, vfEntries);
split.add(vfset);
weights.put(vfset, value);
vfEntries.stream().flatMapToInt(vfe -> {
final TIntArrayList variants = new TIntArrayList(vfset.length);
for (final int index : vfe.getBfIndices()) {
if (ArrayTools.indexOf(index, vfset) >= 0)
continue;
variants.add(index);
}
return Arrays.stream(variants.toArray());
}).sorted().filter(new IntPredicate() {
int prev = -1;
@Override
public boolean test(int value) {
boolean result = value != prev;
prev = value;
return result;
}
}).forEach(idx -> {
final int[] task = new int[vfset.length + 1];
System.arraycopy(vfset, 0, task, 0, vfset.length);
task[vfset.length] = idx;
weights.put(task, value);
splitQueue.add(task);
});
}
split.sort((a, b)->Double.compare(Math.abs(weights.get(b)), Math.abs(weights.get(a))));
for (int[] bfIndices : split) {
final StringBuilder builder = new StringBuilder();
builder.append(ftoa(weights.get(bfIndices)));
builder.append("\t");
for (int i = 0; i < bfIndices.length; i++) {
if (i > 0)
builder.append(", ");
final BFGrid.Feature binaryFeature = grid.bf(bfIndices[i]);
builder.append(pool.features()[binaryFeature.findex()].id()).append(" > ").append(ftoa(binaryFeature.condition()));
}
System.out.println(builder.toString());
}
}
private String ftoa(double v) {
return String.format(Locale.ENGLISH, "%.2f", v);
}
private double expectedWeight(BFGrid grid, VecDataSet vds, BinarizedDataSet bds, List vfEntries) {
double total = 0;
final int power = vds.length();
for (int j = 0; j < power; j++) {
final int finalJ = j;
final double value = vfEntries.stream()
.filter(entry -> {
final int[] bfIndices = entry.getBfIndices();
final int length = bfIndices.length;
for (int i = 0; i < length; i++) {
if (!grid.bf(bfIndices[i]).value(finalJ, bds))
return false;
}
return true;
})
.mapToDouble(ModelTools.CompiledOTEnsemble.Entry::getValue)
.sum();
total += value;
}
total /= power;
return total;
}
private double maxWeight(BFGrid grid, VecDataSet vds, BinarizedDataSet bds, List vfEntries) {
double total = 0;
final int power = vds.length();
for (int j = 0; j < power; j++) {
final int finalJ = j;
final double value = vfEntries.stream()
.filter(entry -> {
final int[] bfIndices = entry.getBfIndices();
final int length = bfIndices.length;
for (int i = 0; i < length; i++) {
if (!grid.bf(bfIndices[i]).value(finalJ, bds))
return false;
}
return true;
})
.mapToDouble(ModelTools.CompiledOTEnsemble.Entry::getValue)
.sum();
total = Math.max(value, total);
}
return total;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy