com.expleague.ml.methods.nn.NeuralTreesOptimization
package com.expleague.ml.methods.nn;
import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.BFGrid;
import com.expleague.ml.BlockedTargetFunc;
import com.expleague.ml.GridTools;
import com.expleague.ml.ProgressHandler;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.BootstrapOptimization;
import com.expleague.ml.methods.GradientBoosting;
import com.expleague.ml.methods.Optimization;
import com.expleague.ml.methods.trees.GreedyObliviousTree;
import com.expleague.ml.models.nn.ConvNet;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
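/**
 * Joint optimization of a convolutional network and a gradient-boosted
 * ensemble of oblivious trees built on the network's output features.
 * Each iteration samples nSampleBuildTree items, fits a boosted ensemble
 * on the sampled high-level features, and then backpropagates the loss
 * gradient, taken at the ensemble output, into the network weights with a
 * plain SGD step. The ensemble of the last iteration is kept and composed
 * with the network in the returned function.
 */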
public class NeuralTreesOptimization<Loss extends BlockedTargetFunc> implements Optimization<Loss, VecDataSet, Vec> {
  private Vec x;
  private final int numIterations;
  private final int nSampleBuildTree;
  private final ConvNet nn;
  private final FastRandom rng;
  private Loss loss;
  private final TreesLoss curLoss = new TreesLoss();
  private final int[] sampleIdxs;

  public NeuralTreesOptimization(int numIterations, int nSampleBuildTree, ConvNet nn, FastRandom rng) {
    this.numIterations = numIterations;
    this.nSampleBuildTree = nSampleBuildTree;
    this.nn = nn;
    this.rng = rng;
    sampleIdxs = new int[nSampleBuildTree];
  }
  @Override
  public Function<Vec, Vec> fit(VecDataSet learn, Loss loss) {
    this.loss = loss;
    final Mx highFeatures = new VecBasedMx(nSampleBuildTree, nn.ydim());
    final Ensemble[] ensemble = {null};
    for (int iter = 0; iter < numIterations; iter++) {
      // Sample a subset of the training set and run it through the network
      // to obtain the high-level features the trees are built on.
      for (int i = 0; i < nSampleBuildTree; i++) {
        sampleIdxs[i] = rng.nextInt(learn.length());
        final Vec result = nn.apply(learn.data().row(sampleIdxs[i]));
        VecTools.assign(highFeatures.row(i), result);
      }
      final VecDataSetImpl highLearn = new VecDataSetImpl(highFeatures, null);
      final BFGrid grid = GridTools.medianGrid(highLearn, 32);
      final GreedyObliviousTree<WeightedLoss<L2>> weak = new GreedyObliviousTree<>(grid, 6);
      final GradientBoosting<Loss> boosting = new GradientBoosting<>(
          new BootstrapOptimization<>(weak, rng), L2.class, 1000, 0.005);
      final Consumer<Trans> counter = new ProgressHandler() {
        int index = 0;

        @Override
        public void accept(Trans partial) {
          index++;
          if (index % 100 == 0)
            System.out.println("boost [" + index + "]");
        }
      };
      boosting.addListener(counter);
      final Ensemble curEnsemble = boosting.fit(highLearn, (Loss) curLoss);
      final Vec result = curEnsemble.transAll(highFeatures).vec();
      final double value = curLoss.value(result);
      System.out.println("[" + iter + "], loss: " + value);

      // Push the loss gradient, taken at the tree-ensemble output, back
      // through the network: one plain SGD step per sampled item.
      final Vec grad = new ArrayVec(nn.wdim());
      x = nn.weights();
      final double sgdStep = 1e-3;
      for (int i = 0; i < nSampleBuildTree; i++) {
        final Vec treeOut = curEnsemble.trans(highLearn.data().row(i));
        nn.gradientTo(learn.data().row(sampleIdxs[i]), x, new TargetByTreeOut(sampleIdxs[i], treeOut), grad);
        VecTools.append(x, VecTools.scale(grad, -sgdStep));
      }
      if (iter == numIterations - 1) {
        ensemble[0] = curEnsemble;
      }
    }
    return argument -> {
      final Vec result = nn.apply(argument, x);
      return ensemble[0].trans(result);
    };
  }
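  /*
   * Backpropagation target for the network: delegates to the gradient of the
   * corresponding block of the outer loss, evaluated at the tree-ensemble
   * output instead of the raw network output.
   */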
  private class TargetByTreeOut extends FuncC1.Stub {
    private final int blockIdx;
    private final Vec treeOut;

    TargetByTreeOut(int blockIdx, Vec treeOut) {
      this.blockIdx = blockIdx;
      this.treeOut = treeOut;
    }

    @Override
    public double value(Vec x) {
      throw new UnsupportedOperationException();
    }

    @Override
    public Vec gradientTo(Vec x, Vec to) {
      return ((FuncC1) loss.block(blockIdx)).gradientTo(treeOut, to);
    }

    @Override
    public int dim() {
      return 0;
    }
  }
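  /*
   * View of the outer blocked loss restricted to the sampled subset:
   * block i here is block sampleIdxs[i] of the original target.
   */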
  private class TreesLoss extends Trans.Stub implements BlockedTargetFunc {
    @Override
    public FuncC1 block(int index) {
      return (FuncC1) loss.block(sampleIdxs[index]);
    }

    @Override
    public int blocksCount() {
      return nSampleBuildTree;
    }

    @Override
    public int xdim() {
      return nSampleBuildTree;
    }

    @Override
    public Trans gradient() {
      final List<FuncC1> collect = IntStream.range(0, sampleIdxs.length)
          .mapToObj(this::block)
          .collect(Collectors.toList());
      return new Trans.Stub() {
        @Override
        public Vec transTo(Vec argument, Vec to) {
          // Collect the first coordinate of each block's gradient.
          final Vec next = new ArrayVec(to.dim());
          for (int i = 0; i < to.dim(); i++) {
            to.set(i, collect.get(i).gradientTo(argument, next).get(0));
          }
          return to;
        }

        @Override
        public int xdim() {
          return TreesLoss.this.xdim();
        }

        @Override
        public int ydim() {
          return TreesLoss.this.xdim();
        }
      };
    }

    @Override
    public DataSet<Vec> owner() {
      throw new UnsupportedOperationException();
    }

    @Override
    public double value(Vec x) {
      return IntStream.range(0, sampleIdxs.length)
          .mapToObj(this::block)
          .map(func -> func.value(x))
          .collect(Collectors.averagingDouble(i -> i));
    }

    @Override
    public int dim() {
      throw new UnsupportedOperationException();
    }

    @Override
    public int ydim() {
      throw new UnsupportedOperationException();
    }
  }
}
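For orientation, a minimal usage sketch, assuming a ConvNet net, a VecDataSet train, and a blocked target built elsewhere with JMLL; the names net, train and target are placeholders, not part of this file, and the iteration and sample counts are arbitrary:

// Hypothetical driver, not part of the original source: net, train and
// target stand for objects constructed elsewhere with JMLL.
final FastRandom rng = new FastRandom();
final NeuralTreesOptimization<BlockedTargetFunc> optimization =
    new NeuralTreesOptimization<>(100, 1000, net, rng);
final Function<Vec, Vec> model = optimization.fit(train, target);
final Vec prediction = model.apply(train.data().row(0));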