import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.IntSeq;
import java.io.PrintStream;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.IntStream;
/**
 * Jointly trains a neural network ({@code ConvNet}) and a gradient-boosted multiclass ensemble
 * built from {@code GreedyProbLinearRegion} weak learners that operates on the network's outputs.
 *
 * <p>{@link #fit} alternates between fitting the ensemble on normalized network outputs and
 * running SGD on the network weights along the gradient propagated back through the ensemble.
 * Hyper-parameters and training diagnostics are written to the {@code debug} stream.
 */
public class NeuralTreesOptimization implements Optimization {
// Network architecture; fit() allocates and trains a weight vector of size network.wdim().
private final NetworkBuilder.Network network;
// Number of outer rounds (ensemble fit followed by an SGD phase) performed by fit().
private int numIterations;
// Number of samples drawn (independently, with replacement) per round to build the boosting datasets.
private int nSampleBuildTree;
// Randomness for sample selection, SGD mini-batches, bootstrap and StochasticALS.
private final FastRandom rng;
// SGD steps per outer round.
private int sgdIterations;
// Number of boosting iterations passed to GradientBoosting.
private int numTrees;
// Samples per SGD mini-batch.
private final int batchSize;
// Destination for the hyper-parameter dump and training diagnostics.
private final PrintStream debug;
// Initial per-coordinate SGD multiplier and global SGD step; fit() applies both.
private double sgdStep;
// Step size passed to GradientBoosting.
private double boostingStep;
// Hold-out data and its loss, set by setTest(); in the visible code they are used for
// loss/accuracy reporting. fit() dereferences them, so setTest() appears to be required first.
private VecDataSet test;
private BlockwiseMLLLogit testLoss;
/**
 * Creates the optimizer and prints its hyper-parameters to {@code debug}.
 *
 * @param numIterations number of outer rounds (ensemble fit + SGD phase) performed by {@code fit}
 * @param nSampleBuildTree number of samples drawn per round to build the boosting datasets
 * @param sgdIterations SGD steps per outer round
 * @param batchSize samples per SGD mini-batch
 * @param sgdStep initial per-coordinate step multiplier and global SGD step
 * @param numTrees number of boosting iterations passed to GradientBoosting
 * @param boostingStep step size passed to GradientBoosting
 * @param nn network architecture whose weights are trained by {@code fit}
 * @param rng source of randomness for sampling, mini-batches, bootstrap and ALS
 * @param debug sink for the parameter dump and progress output; must be non-null
 */
public NeuralTreesOptimization(int numIterations, int nSampleBuildTree, int sgdIterations, int batchSize,
    double sgdStep, int numTrees, double boostingStep,
    NetworkBuilder.Network nn, FastRandom rng, PrintStream debug) {
  this.numIterations = numIterations;
  this.nSampleBuildTree = nSampleBuildTree;
  this.sgdIterations = sgdIterations;
  this.batchSize = batchSize;
  this.numTrees = numTrees;
  this.boostingStep = boostingStep;
  this.rng = rng;
  this.debug = debug;
  this.sgdStep = sgdStep;
  // The architecture is the only use of `nn`; the final field must be assigned here.
  this.network = nn;
  debug.printf("parameters:\n" +
      " numIterations = %d;\n" +
      " nSampleBuildTree = %d;\n" +
      " sgdIterations = %d;\n" +
      " batchSize = %d;\n" +
      " numTrees = %d;\n" +
      " boostingStep = %f;\n" +
      " sgdStep = %f;\n", numIterations, nSampleBuildTree, sgdIterations,
      batchSize, numTrees, boostingStep, sgdStep);
}
public void setTest(VecDataSet test, IntSeq labels) {
this.test = test;
testLoss = new BlockwiseMLLLogit(labels, test);
/**
 * Trains the network weights together with a boosted ensemble on top of the network's outputs.
 *
 * <p>Each of the {@code numIterations} outer rounds does the following:
 * <ol>
 *   <li>Builds a {@code HighLevelDataset} from {@code nSampleBuildTree} randomly sampled learn
 *       (and test) points.</li>
 *   <li>Fits a boosted ensemble on that dataset.</li>
 *   <li>Runs {@code sgdIterations} mini-batch steps that move the network weights along the
 *       gradient propagated back through the ensemble, using per-coordinate step multipliers.</li>
 * </ol>
 * Afterwards the ensemble is refit on features of every learn sample, and a function composing
 * network and ensemble is returned.
 *
 * <p>NOTE(review): several expressions in this method have been truncated (the {@code nn.apply(}
 * calls, the {@code baseVec} initializer, the {@code transAll(} arguments), so it does not compile
 * as shown. Restore them from version control.
 *
 * @param learn training samples fed to the network
 * @param loss multiclass loss over {@code learn}; also indexed by sample index in ensembleGradient
 * @return function mapping a raw input to the ensemble's output on the network's features
 */
public Function fit(VecDataSet learn, BlockwiseMLLLogit loss) {
final Vec weights = new ArrayVec(network.wdim());
DataNormalizer normalizer;
// Initial pass: run the network with the freshly allocated weights over every learn sample and
// fit the feature normalizer on the result. That normalizer is reused for all later rounds.
ConvNet nn = new ConvNet(network, weights);
Mx highData = new VecBasedMx(learn.length(), network.ydim());
for (int i = 0; i < learn.length(); i++) {
VecTools.assign(highData.row(i), nn.apply(;
normalizer = new DataNormalizer(highData, 0, 3.);
for (int iter = 0; iter < numIterations; iter++) {
// NOTE(review): redeclares the method-level `nn` declared above. Java rejects a local that
// shadows another local in an enclosing scope, so one of them needs a different name.
final ConvNet nn = new ConvNet(network, weights);
final HighLevelDataset highLearn = HighLevelDataset.sampleFromDataset(learn, normalizer, loss, nn, nSampleBuildTree, rng);
final HighLevelDataset highTest = HighLevelDataset.sampleFromDataset(test, normalizer, testLoss, nn, nSampleBuildTree, rng);
final Ensemble ensemble = fitBoosting(highLearn, highTest);
// L holds one step multiplier per network weight, initialised to sgdStep.
Vec L = new ArrayVec(nn.wdim());
Vec grad = new ArrayVec(nn.wdim());
Vec step = new ArrayVec(nn.wdim());
VecTools.fill(L, this.sgdStep);
// prevGrad starts as all zeros; it is only used for the "Grad alignment" diagnostic.
final Vec prevGrad = new ArrayVec(nn.wdim());
final Vec partial = new ArrayVec(nn.wdim());
for (int sgdIter = 0; sgdIter < sgdIterations; sgdIter++) {
VecTools.fill(grad, 0);
// Accumulate the mini-batch gradient over batchSize samples drawn with replacement.
for (int i = 0; i < batchSize; i++) {
final int sampleIdx = rng.nextInt(learn.length());
Vec apply = nn.apply(;
normalizer.transTo(apply, apply);
// NOTE(review): treeGrad is the gradient w.r.t. the *normalized* features, but it is handed to
// nn.gradientTo as if it were w.r.t. the raw network output. The normalizer's per-feature factor
// (newDisp / sqrt(variance)) appears to be missing from the chain rule. Confirm.
final Vec treeGrad = ensembleGradient(nn, ensemble, loss, apply, sampleIdx);
final Vec baseVec =;
nn.gradientTo(baseVec, new TargetByTreeOut(treeGrad), partial);
VecTools.append(grad, partial);
// // FIXME
// final int lastLayerWStart = nn.wdim() - 500 * 80;
// for (int i = 0; i < lastLayerWStart; i++) {
// grad.set(i, 0);
// }
// Update: weights += sgdStep * (L .* grad). This assumes VecTools.scale(Vec, Vec) is element-wise.
// Because L starts at sgdStep, the first effective rate is sgdStep^2.
// NOTE(review): the gradient is *added*, which is ascent unless BlockwiseMLLLogit.gradient
// already points in the improving direction. Confirm the sign convention.
VecTools.assign(step, grad);
VecTools.scale(step, L);
VecTools.incscale(weights, step, sgdStep);
// Grow each multiplier by a factor of 1/0.99 per step, capped at 1/|grad_i|. A zero gradient
// component gives an infinite cap, so that multiplier simply keeps growing.
for (int i = 0; i < L.dim(); i++) {
L.set(i, Math.min(L.get(i) / 0.99, 1 / Math.abs(grad.get(i))));
// Every 20 SGD steps (except the first), report ensemble loss/accuracy on the sampled datasets.
if (sgdIter % 20 == 0 && sgdIter != 0) {
final Mx resultTrain = ensemble.transAll(;
final Mx resultTest = ensemble.transAll(;
final double lTrain = highLearn.loss().value(resultTrain);
final double lTest = highTest.loss().value(resultTest);
final double accTest = accuracy(highTest.loss(), resultTest);
debug.println("sgd [" + (sgdIter) + "], loss(train): " + lTrain +
" loss(test): " + lTest + " acc(test): " + accTest);
// Cosine between the current and the stored gradient. It can be NaN while prevGrad is still zeros.
debug.println("Grad alignment: " + VecTools.cosine(prevGrad, grad));
VecTools.assign(prevGrad, grad);
// Final refit: features of every learn/test sample under the trained weights, same normalizer.
// NOTE(review): `nn` is declared again here, in the same scope as the first declaration. This will not compile.
final ConvNet nn = new ConvNet(network, weights);
final HighLevelDataset allLearn = HighLevelDataset.createFromDs(learn, normalizer, loss, nn);
final HighLevelDataset allTest = HighLevelDataset.createFromDs(test, normalizer, testLoss, nn);
final Ensemble ensemble = fitBoosting(allLearn, allTest);
// NOTE(review): the ensemble was trained on normalized features, but the returned function passes
// the raw network output to it. It likely needs normalizer.transTo(result, result) first.
// With the raw Function return type, `argument` is also typed Object.
return argument -> {
Vec result = nn.apply(argument, weights);
return ensemble.trans(result);
private Vec ensembleGradient(ConvNet nn, Ensemble ensemble, BlockwiseMLLLogit loss, Vec x, int blockId) {
final Vec ensembleGrad = new ArrayVec(nn.ydim());
final Vec lossGrad = new ArrayVec(loss.blockSize());
final Vec treeOut = ensemble.trans(x);
loss.gradient(treeOut, lossGrad, blockId);
final Vec currentWeights = new ArrayVec(loss.blockSize());
final Vec grad = new ArrayVec(nn.ydim());
for (int i = 0; i < ensemble.models.length; i++) {
VecTools.fill(grad, 0.);
final ScaledVectorFunc model = (ScaledVectorFunc) ensemble.models[i];
VecTools.assign(currentWeights, model.weights);
VecTools.scale(currentWeights, lossGrad);
((ProbRegion) model.function).gradientTo(x, grad);
// {
// final Vec realGrad = new ArrayVec(nn.ydim());
// double value = model.function.value(x);
// for (int j = 0; j < grad.dim(); j++) {
// x.adjust(j, epsilon);
// double valuePrime = model.function.value(x);
// realGrad.set(j, (valuePrime - value) / epsilon);
// x.adjust(j, -epsilon);
// }
// }
VecTools.scale(grad, ensemble.weights.get(i) * VecTools.sum(currentWeights));
VecTools.append(ensembleGrad, grad);
return ensembleGrad;
/**
 * Fits a gradient-boosted multiclass ensemble on the high-level features of {@code learn}.
 *
 * <p>Weak learner: {@code GreedyProbLinearRegion} over a 32-bin median grid built from
 * {@code learn.vec()}. It is wrapped in bootstrap sampling and combined via
 * {@code GradFacMulticlass} with {@code StochasticALS}. Boosting runs for {@code numTrees}
 * iterations with step {@code boostingStep}, using L2Reg as the local loss.
 *
 * <p>NOTE(review): several expressions here have been truncated (the generic type at the weak-learner
 * declaration, the {@code VecBasedMx}/{@code transAll} arguments, and the right-hand side of the
 * final {@code ensemble} assignment, presumably a {@code boosting.fit(...)} call), so the block
 * does not compile as shown.
 *
 * @param learn features and loss to fit on
 * @param test features and loss used only for progress reporting
 * @return the fitted ensemble
 */
private Ensemble fitBoosting(HighLevelDataset learn, HighLevelDataset test) {
final BFGrid grid = GridTools.medianGrid(learn.vec(), 32);
final GreedyProbLinearRegion> weak = new GreedyProbLinearRegion<>(grid, 7);
final BootstrapOptimization bootstrap = new BootstrapOptimization(weak, rng);
final GradientBoosting boosting = new GradientBoosting<>(new GradFacMulticlass(
bootstrap, new StochasticALS(rng, 1000.), L2Reg.class, false), L2Reg.class, numTrees, boostingStep);
// Progress listener. It keeps running train/test predictions by adding each new member's
// output scaled by its weight, so the full ensemble need not be re-evaluated every iteration.
// NOTE(review): `index` is never incremented in the visible code, so `index % 100 == 0` is always
// true. The listener is also never registered with `boosting` here; both may have been lost.
final Consumer counter = new ProgressHandler() {
int index = 0;
Mx currentLearn = new VecBasedMx(, learn.loss.blockSize());
Mx currentTest = new VecBasedMx(, learn.loss.blockSize());;
public void accept(Trans partial) {
Mx resultTrain = null;
Mx resultTest = null;
if (partial instanceof Ensemble) {
Ensemble ensemble = (Ensemble) partial;
Trans last = ensemble.last();
VecTools.incscale(currentLearn, last.transAll(, ensemble.wlast());
resultTrain = currentLearn;
VecTools.incscale(currentTest, last.transAll(, ensemble.wlast());
resultTest = currentTest;
if (index % 100 == 0) {
// Fall back to a full evaluation when `partial` is not an Ensemble.
if (resultTest == null || resultTrain == null){
resultTrain = partial.transAll(;
resultTest = partial.transAll(;
final double lTrain = learn.loss().value(resultTrain);
final double lTest = test.loss().value(resultTest);
final double accTest = accuracy(test.loss(), resultTest);
debug.println("boost [" + (index) + "], loss(train): " + lTrain +
" loss(test): " + lTest + " acc(test): " + accTest);
final Ensemble ensemble =, learn.loss());
// Report the final training loss of the fitted ensemble.
final Vec result = ensemble.transAll(;
final double curLossValue = learn.loss().value(result);
debug.println("ensemble loss: " + curLossValue);
return ensemble;
private static double accuracy(BlockwiseMLLLogit loss, Mx results) {
final Vec predict = new ArrayVec(results.rows());
IntStream.range(0, results.rows()).parallel().forEach(i -> {
final Vec prob = loss.prob(results.row(i), new ArrayVec(loss.blockSize() + 1));
predict.set(i, VecTools.argmax(prob));
final IntSeq labels = loss.labels();
int acc = 0;
for (int i = 0; i < predict.dim(); i++) {
acc += predict.get(i) == labels.intAt(i) ? 1 : 0;
return ((double) acc) / results.rows();
/**
 * Network-output ("high-level") features for a selection of rows from a base dataset.
 *
 * <p>It holds the normalized feature matrix, the normalizer applied to it, the network that
 * produced it, and the loss over the selected samples.
 *
 * <p>NOTE(review): the {@code nn.apply(...)} calls here have been truncated and {@code baseVecById}
 * has lost its body. Restore them from version control.
 */
private static class HighLevelDataset {
private final ConvNet nn;
// Dataset the rows were taken from.
private final VecDataSet base;
private final BlockwiseMLLLogit loss;
// sampleIdxs[i] is the index in `base` of row i of highData.
private final int[] sampleIdxs;
private DataNormalizer normalizer;
// Normalized network outputs, one row per selected sample.
private Mx highData;
private HighLevelDataset(Mx highData, DataNormalizer normalizer, ConvNet nn, VecDataSet base, BlockwiseMLLLogit loss, int[] sampleIdxs) {
this.highData = highData;
this.normalizer = normalizer;
this.nn = nn;
this.base = base;
this.loss = loss;
this.sampleIdxs = sampleIdxs;
// Same as the six-argument overload, but fits a fresh normalizer on the sampled features.
static HighLevelDataset sampleFromDataset(VecDataSet ds, BlockwiseMLLLogit loss, ConvNet nn, int numSamples, FastRandom rng) {
return sampleFromDataset(ds, null, loss, nn, numSamples, rng);
// Draws numSamples indices of `ds` independently (with replacement) and computes their network
// features. It normalizes them with `normalizer`, or with a new one fitted on these features
// when it is null. It then builds a loss whose targets are the labels of the drawn samples.
static HighLevelDataset sampleFromDataset(VecDataSet ds, DataNormalizer normalizer, BlockwiseMLLLogit loss, ConvNet nn, int numSamples, FastRandom rng) {
Mx highData = new VecBasedMx(numSamples, nn.ydim());
final int[] sampleIdx = new int[numSamples];
for (int i = 0; i < numSamples; i++) {
sampleIdx[i] = rng.nextInt(ds.length());
final Vec result = nn.apply([i]));
VecTools.assign(highData.row(i), result);
if (normalizer == null)
normalizer = new DataNormalizer(highData, 0, 3.);
// NOTE(review): meaning of the boolean argument to transAll is not visible here. Confirm.
highData = normalizer.transAll(highData, true);
final Vec target = new ArrayVec(numSamples);
IntStream.range(0, numSamples).forEach(idx -> target.set(idx, loss.label(sampleIdx[idx])));
// NOTE(review): `target` has numSamples entries, but the loss is bound to the full dataset `ds`.
// Confirm that BlockwiseMLLLogit tolerates this size mismatch.
final BlockwiseMLLLogit newLoss = new BlockwiseMLLLogit(target, ds);
return new HighLevelDataset(highData, normalizer, nn, ds, newLoss, sampleIdx);
// Uses every row of `ds`, in order, and the caller's loss unchanged. If `normalizer` is null, a
// new one is fitted on the computed features.
static HighLevelDataset createFromDs(VecDataSet ds, DataNormalizer normalizer, BlockwiseMLLLogit loss, ConvNet nn) {
Mx highData = new VecBasedMx(ds.length(), nn.ydim());
for (int i = 0; i < ds.length(); i++) {
final Vec result = nn.apply(;
VecTools.assign(highData.row(i), result);
if (normalizer == null) {
normalizer = new DataNormalizer(highData, 0., 3.);
highData = normalizer.transAll(highData, true);
return new HighLevelDataset(highData, normalizer, nn, ds, loss, IntStream.range(0, ds.length()).toArray());
static HighLevelDataset createFromDs(VecDataSet ds, BlockwiseMLLLogit loss, ConvNet nn) {
return createFromDs(ds, null, loss, nn);
public Mx data() {
return highData;
public void setNormalizer(DataNormalizer normalizer) {
this.normalizer = normalizer;
public DataNormalizer getNormalizer() {
return normalizer;
public BlockwiseMLLLogit loss() {
return loss;
// The current feature matrix wrapped as a dataset whose parent is `base`.
public VecDataSet vec() {
return new VecDataSetImpl(highData, base);
public Vec baseVecById(int id) {
// Recomputes every row with the network's current weights and re-applies the existing
// normalizer. Its statistics are final, so they are not refitted.
public void update() {
for (int i = 0; i < highData.rows(); i++) {
final Vec result = nn.apply([i]));
VecTools.assign(highData.row(i), result);
highData = normalizer.transAll(highData, true);
// Returns the stored (possibly stale) normalized row; no network evaluation.
public Vec getCached(int sampleIdx) {
return highData.row(sampleIdx);
// Evaluates the network afresh for row `idx` and normalizes the result in place.
public Vec get(int idx) {
final Vec apply = nn.apply([idx]));
normalizer.transTo(apply, apply);
return apply;
private static class DataNormalizer extends Trans.Stub {
private final Vec mean;
private final Vec disp;
private final double newMean;
private final double newDisp;
DataNormalizer(Mx data, double newMean, double newDisp) {
this.newMean = newMean;
this.newDisp = newDisp;
final int featuresDim = data.columns();
mean = VecTools.fill(new ArrayVec(featuresDim), 0.);
disp = VecTools.fill(new ArrayVec(featuresDim), 0.);
for (int i = 0; i < data.rows(); i++) {
final Vec row = data.row(i);
VecTools.append(mean, row);
appendSqr(disp, row, 1.);
VecTools.scale(mean, 1. / data.rows());
VecTools.scale(disp, 1. / data.rows());
appendSqr(disp, mean, -1.);
for (int i = 0; i < featuresDim; i++) {
double v = disp.get(i) == 0. ? 1. : disp.get(i);
disp.set(i, v);
public Vec transTo(Vec x, Vec to) {
if (x.dim() != mean.dim()) {
throw new IllegalArgumentException();
for (int i = 0; i < x.dim(); i++) {
final double v = (x.get(i) - mean.get(i)) / Math.sqrt(disp.get(i)) * newDisp + newMean;
to.set(i, v);
return to;
private void appendSqr(Vec to, Vec who, double alpha) {
for (int i = 0; i < to.dim(); i++) {
final double v = who.get(i);
to.adjust(i, alpha * v * v);
public int xdim() {
return mean.dim();
public int ydim() {
return xdim();
private class TargetByTreeOut extends FuncC1.Stub {
private final Vec treesGradient;
TargetByTreeOut(Vec gradient) {
this.treesGradient = gradient;
public double value(Vec x) {
throw new UnsupportedOperationException();
public Vec gradientTo(Vec x, Vec to) {
VecTools.assign(to, treesGradient);
return to;
public int dim() {
return treesGradient.dim();