com.expleague.ml.methods.multiclass.gradfac.FMCBoosting Maven / Gradle / Ivy
package com.expleague.ml.methods.multiclass.gradfac;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.RowsVecArrayMx;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.commons.util.logging.Interval;
import com.expleague.ml.BFGrid;
import com.expleague.ml.Binarize;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.factorization.Factorization;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.func.ScaledVectorFunc;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.SatL2;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.blockwise.BlockwiseMLLLogit;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.ObliviousTree;
import org.apache.commons.math3.util.FastMath;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import static java.lang.Math.exp;
/**
* Experts League
* Created by solar on 05.05.17.
*/
public class FMCBoosting extends WeakListenerHolderImpl implements VecOptimization {
protected final VecOptimization weak;
private final Class extends L2> factory;
private final Factorization factorize;
private final int iterationsCount;
private final double step;
private final boolean lazyCursor;
private final int ensembleSize;
private final boolean isGbdt;
private BinarizedDataSet bds = null;
private final FastRandom rfRnd = new FastRandom(13);
private VecDataSet valid;
private BlockwiseMLLLogit validTarget;
private int bestIterCount = 0;
private double bestAccuracy;
private int earlyStoppingRounds = 0;
public FMCBoosting(final Factorization factorize, final VecOptimization weak, final int iterationsCount, final double step, final boolean lazyCursor) {
this(factorize, weak, SatL2.class, iterationsCount, step, lazyCursor, 1, false);
}
public FMCBoosting(final Factorization factorize, final VecOptimization weak, final int iterationsCount, final double step) {
this(factorize, weak, SatL2.class, iterationsCount, step, false, 1, false);
}
public FMCBoosting(final Factorization factorize, final VecOptimization weak, final int iterationsCount, final double step, final int ensembleSize) {
this(factorize, weak, SatL2.class, iterationsCount, step, false, ensembleSize, false);
}
public FMCBoosting(final Factorization factorize, final VecOptimization weak, final Class extends L2> factory, final int iterationsCount, final double step, final int ensembleSize, final boolean isGbdt) {
this(factorize, weak, factory, iterationsCount, step, false, ensembleSize, isGbdt);
}
public FMCBoosting(final Factorization factorize, final VecOptimization weak, final Class extends L2> factory, final int iterationsCount, final double step) {
this(factorize, weak, factory, iterationsCount, step, false, 1, false);
}
public FMCBoosting(Factorization factorize, final VecOptimization weak, final Class extends L2> factory, final int iterationsCount, final double step, final boolean lazyCursor, final int ensembleSize, final boolean isGbdt) {
this.factorize = factorize;
this.weak = weak;
this.factory = factory;
this.iterationsCount = iterationsCount;
this.step = step;
this.lazyCursor = lazyCursor;
this.ensembleSize = ensembleSize;
this.isGbdt = isGbdt;
}
@Override
public Ensemble fit(final VecDataSet learn, final BlockwiseMLLLogit target) {
final Vec[] B = new Vec[iterationsCount * ensembleSize];
final List weakModels = new ArrayList<>(iterationsCount * ensembleSize);
final List ensamble = new ArrayList<>(iterationsCount * ensembleSize);
final Mx cursor;
if (lazyCursor) {
cursor = new RowsVecArrayMx(new LazyGradientCursor(learn, weakModels, B, target, bds));
} else {
cursor = new RowsVecArrayMx(new GradientCursor(learn, weakModels, B, target, bds));
}
VecBasedMx validScore = null;
if (valid != null) {
validScore = new VecBasedMx(valid.length(), target.classesCount() - 1);
}
for (int t = 0; t < iterationsCount; t++) {
System.out.println("Iteration " + (t + 1));
final Pair factorize = this.factorize.factorize(cursor);
// TODO: remove extra parameters
for (int i = 0; i < ensembleSize; ++i) {
B[t * ensembleSize + i] = factorize.second;
}
final L2 globalLoss = DataTools.newTarget(factory, factorize.first, learn);
Interval.start();
for (int i = 0; i < ensembleSize; ++i) {
final ObliviousTree weakModel = (ObliviousTree) weak.fit(learn, DataTools.bootstrap(globalLoss, rfRnd));
if (bds == null) {
bds = learn.cache().cache(Binarize.class, VecDataSet.class).binarize(weakModel.grid());
}
if (this.isGbdt) {
for (int j = 0; j < learn.length(); ++j) {
factorize.first.adjust(j, -weakModel.value(this.bds, j));
}
}
// System.out.println(String.format("Vector norm: %.3f", VecTools.norm(factorize.first)));
weakModels.add(weakModel);
ensamble.add(new ScaledVectorFunc(weakModel, factorize.second));
}
Interval.stopAndPrint("Fitting greedy oblivious tree");
// Update valid score
if (valid != null) {
Interval.start();
for (int i = 0; i < ensembleSize; ++i) {
final int index = ensamble.size() - ensembleSize + i;
final ScaledVectorFunc func = ensamble.get(index);
for (int j = 0; j < valid.length(); ++j) {
VecTools.append(validScore.row(j), VecTools.scale(func.trans(valid.at(j)), -step));
}
}
double matches = 0;
for (int i = 0; i < valid.length(); ++i) {
double[] score = validScore.row(i).toArray();
int clazz = ArrayTools.max(score);
clazz = score[clazz] > 0 ? clazz : target.classesCount() - 1;
matches += (clazz == validTarget.label(i) ? 1 : 0);
}
final double accuracy = matches / valid.length();
if (bestIterCount == 0 || accuracy > bestAccuracy) {
bestIterCount = t + 1;
bestAccuracy = accuracy;
}
Interval.stopAndPrint("Evaluate valid accuracy");
System.out.println(String.format("Valid accuracy: %.4f", accuracy));
if (earlyStoppingRounds > 0 && t + 1 - bestIterCount == earlyStoppingRounds) {
// Early stopping
System.out.println("Early stopping!");
break;
}
}
/*
// Debug calculations
Vec u_pred = new ArrayVec(learn.length());
for (int i = 0; i < learn.length(); ++i) {
u_pred.set(i, weakModel.apply(learn.data().row(i)).get(0));
}
final Vec diff = VecTools.subtract(u_pred, factorize.first);
final double mae = VecTools.norm1(diff) / diff.dim();
for (int i = 0; i < diff.dim(); ++i) {
diff.set(i, diff.get(i) / factorize.first.get(i));
}
final double mape = Math.round(100.0 * VecTools.norm1(diff) / diff.dim());
System.out.println("Tree MAE: " + mae);
System.out.println("Tree MAPE: " + mape + "%");*/
invoke(new Ensemble<>(ensamble, -step));
}
if (valid != null) {
System.out.println(String.format(String.format("Best iterations count: %d", bestIterCount)));
System.out.println(String.format(String.format("Best valid accuracy: %.4f", bestAccuracy)));
return new Ensemble<>(ensamble.subList(0, ensembleSize * bestIterCount), -step);
}
return new Ensemble<>(ensamble, -step);
}
public void setEarlyStopping(final VecDataSet valid, final BlockwiseMLLLogit validTarget, final int earlyStoppingRounds) {
this.valid = valid;
this.validTarget = validTarget;
this.earlyStoppingRounds = earlyStoppingRounds;
}
private BiConsumer getLastWeakLearner(final Vec b, final ObliviousTree weakModel, BlockwiseMLLLogit target) {
final int classesCount = target.classesCount();
return (i, vec) -> {
final int pointClass = target.label(i);
final double scale = -step * weakModel.value(bds, i);
double S = 1;
for (int c = 0; c < classesCount - 1; c++) {
final double e = exp(b.get(c) * scale);
final double v = vec.get(c);
if (c == pointClass) {
S += (v + 1) * (e - 1);
vec.set(c, (v + 1) * e);
} else {
S += v * (e - 1);
vec.set(c, v * e);
}
}
for (int c = 0; c < classesCount - 1; c++) {
if (c == pointClass) {
vec.set(c, -1 + vec.get(c) / S);
} else {
vec.set(c, vec.get(c) / S);
}
}
};
}
private class GradientCursor extends Seq.Stub {
private final Mx cursor;
private final VecDataSet learn;
private final List weakModels;
private final BlockwiseMLLLogit target;
private final Vec[] b;
private final int[][] leafIndex;
private final double[][][] buffer;
private BinarizedDataSet bds;
private int size = 0;
public GradientCursor(VecDataSet learn, List weakModels, Vec[] b, BlockwiseMLLLogit target, BinarizedDataSet bds) {
this.cursor = new VecBasedMx(learn.data().rows(), target.classesCount() - 1);
this.learn = learn;
this.weakModels = weakModels;
this.target = target;
this.b = b;
this.bds = bds;
this.leafIndex = new int[ensembleSize][learn.length()];
this.buffer = new double[ensembleSize][][];
initCursor();
}
private void initCursor() {
for (int i = 0; i < learn.data().rows(); i++) {
for (int j = 0; j < target.classesCount() - 1; j++) {
cursor.adjust(i, j, 1.0 / target.classesCount());
if (j == target.label(i)) {
cursor.adjust(i, j, -1);
}
}
}
}
private void updateBuffer() {
final int size = weakModels.size();
final Vec b = this.b[size - 1];
final double step = FMCBoosting.this.step;
for (int tree = 0; tree < ensembleSize; ++tree) {
ObliviousTree weakModel = (ObliviousTree) weakModels.get(size - ensembleSize + tree);
List features = weakModel.features();
for (int index = 0; index < learn.length(); ++index) {
int leaf = 0;
for (int depth = 0; depth < features.size(); depth++) {
leaf <<= 1;
if (features.get(depth).value(bds.bins(features.get(depth).findex())[index]))
leaf++;
}
leafIndex[tree][index] = leaf;
}
final double[] values = weakModel.values();
if (buffer[tree] == null) {
buffer[tree] = new double[values.length][target.classesCount() - 1];
}
for (int i = 0; i < values.length; ++i) {
for (int j = 0; j < target.classesCount() - 1; ++j) {
buffer[tree][i][j] = FastMath.exp(-step * b.get(j) * values[i]);
}
}
}
}
private void updateCursor() {
final int size = weakModels.size();
final int classesCount = target.classesCount();
if (bds == null) {
bds = learn.cache().cache(Binarize.class, VecDataSet.class).binarize(((ObliviousTree) weakModels.get(size - 1)).grid());
}
long timeStart = System.currentTimeMillis();
updateBuffer();
System.out.println("updateBuffer: " + (System.currentTimeMillis() - timeStart) + " (ms)");
timeStart = System.currentTimeMillis();
IntStream.range(0, cursor.rows()).parallel().forEach(i -> {
final Vec vec = cursor.row(i);
final int pointClass = target.label(i);
double S = 1;
for (int c = 0; c < classesCount - 1; c++) {
double e = 1;
for (int t = 0; t < ensembleSize; ++t) {
e *= buffer[t][leafIndex[t][i]][c];
}
final double v = vec.get(c);
if (c == pointClass) {
S += (v + 1) * (e - 1);
vec.set(c, (v + 1) * e);
} else {
S += v * (e - 1);
vec.set(c, v * e);
}
}
for (int c = 0; c < classesCount - 1; c++) {
if (c == pointClass) {
vec.set(c, -1 + vec.get(c) / S);
} else {
vec.set(c, vec.get(c) / S);
}
}
});
System.out.println("Cursor update: " + (System.currentTimeMillis() - timeStart) + " (ms)");
this.size = size;
}
@Override
public Vec at(final int i) {
if (weakModels.size() != size) {
updateCursor();
}
return cursor.row(i);
}
@Override
public Seq sub(int start, int end) {
throw new UnsupportedOperationException();
}
@Override
public Seq sub(int[] indices) {
throw new UnsupportedOperationException();
}
@Override
public int length() {
return target.dim() / target.blockSize();
}
@Override
public boolean isImmutable() {
return true;
}
@Override
public Class extends Vec> elementType() {
return Vec.class;
}
@SuppressWarnings("unchecked")
@Override
public Stream stream() {
return IntStream.range(0, length()).mapToObj(this::at);
}
}
private class LazyGradientCursor extends Seq.Stub {
private final VecDataSet learn;
private final List weakModels;
private final BlockwiseMLLLogit target;
private final Vec[] b;
private BinarizedDataSet bds;
public LazyGradientCursor(VecDataSet learn, List weakModels, Vec[] b, BlockwiseMLLLogit target, BinarizedDataSet bds) {
this.learn = learn;
this.weakModels = weakModels;
this.target = target;
this.b = b;
this.bds = bds;
}
@Override
public Vec at(final int i) {
final int classesCount = target.classesCount();
final Vec H_t = new ArrayVec(classesCount - 1);
final List weakModels = this.weakModels;
final int size = weakModels.size();
final double step = -FMCBoosting.this.step;
if (size > 0 && weakModels.get(0) instanceof ObliviousTree) {
final ObliviousTree obliviousTree = (ObliviousTree) weakModels.get(0);
if (bds == null)
bds = learn.cache().cache(Binarize.class, VecDataSet.class).binarize(obliviousTree.grid());
final BinarizedDataSet bds = this.bds;
for (int j = 0; j < size; j++) {
final ObliviousTree tree = (ObliviousTree) weakModels.get(j);
VecTools.incscale(H_t, b[j], tree.value(bds, i) * step);
}
} else {
final Vec vec = learn.at(i);
for (int j = 0; j < size; j++) {
VecTools.incscale(H_t, b[j], weakModels.get(j).value(vec) * step);
}
}
final Vec result = new ArrayVec(classesCount - 1);
double sum = 0;
for (int c = 0; c < classesCount - 1; c++) {
final double expX = exp(H_t.get(c));
sum += expX;
}
final int pointClass = target.label(i);
for (int c = 0; c < classesCount - 1; c++) {
if (pointClass == c)
result.adjust(c, -(1. + sum - exp(H_t.get(c))) / (1. + sum));
else
result.adjust(c, exp(H_t.get(c)) / (1. + sum));
}
return result;
}
@Override
public Seq sub(int start, int end) {
throw new UnsupportedOperationException();
}
@Override
public Seq sub(int[] indices) {
throw new UnsupportedOperationException();
}
@Override
public int length() {
return target.dim() / target.blockSize();
}
@Override
public boolean isImmutable() {
return true;
}
@Override
public Class extends Vec> elementType() {
return Vec.class;
}
@SuppressWarnings("unchecked")
@Override
public Stream stream() {
return IntStream.range(0, length()).mapToObj(this::at);
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy