// com.expleague.ml.methods.trees.GreedyExponentialObliviousTree Maven / Gradle / Ivy
package com.expleague.ml.methods.trees;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.BFGrid;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.greedyRegion.GreedyPolynomialExponentRegion;
import com.expleague.ml.models.ObliviousTree;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.ml.models.ExponentialObliviousTree;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.List;
/*Created with IntelliJ IDEA.
*User:towelenee
*Date:30.11.13
*Time:17:48
*Idea please stop making my code yellow
*/
/**
 * Trains an {@link ExponentialObliviousTree} in two stages: a {@link GreedyObliviousTree} of the
 * requested depth chooses the split features, then for every leaf one coefficient is solved for
 * each product of two entries of {@code (1, x_1, ..., x_depth)}, where {@code x_s} is the sample's
 * value of the feature used by the s-th split.
 *
 * <p>Leaf indices are built most-significant-bit first: split 0 contributes the highest of the
 * {@code depth} bits, and a bit is set when the split's {@code value(...)} is {@code true}.
 *
 * <p>NOTE(review): several generic type arguments (for example the element type of
 * {@code features} and the type parameter of {@code VecOptimization.Stub}) appear to be missing
 * from this listing; confirm against the original sources before compiling.
 */
public class GreedyExponentialObliviousTree extends VecOptimization.Stub {
  /** Value added to every diagonal entry of a leaf's system matrix before it is solved. */
  private static final double DIAGONAL_REGULARIZATION = 1e-1;

  /** Number of coefficients per leaf: unordered pairs over {@code depth + 1} entries. */
  private final int numberOfVariablesByLeaf;
  /** Total number of coefficients over all {@code 2^depth} leaves. */
  private final int numberOfVariables;
  /** Per leaf, accumulated products of two pair-monomials (filled by the precalculation). */
  private double[][][] quadraticMissCoefficient;
  /** Per leaf, accumulated {@code -2 * target * monomial} (filled by the precalculation). */
  private double[][] linearMissCoefficient;
  /** Scale applied in {@link #calcDistanseToRegion} and handed to the resulting model. */
  private final double distCoef;
  /** Depth of the base oblivious tree, i.e. the expected number of split features. */
  private final int depth;
  /** Learner that produces the base tree and therefore the split features. */
  private final GreedyObliviousTree got;
  /** Split features of the most recently fitted base tree; set at the start of {@link #fit}. */
  private List features;

  /**
   * @param grid     binarisation grid passed to the underlying {@link GreedyObliviousTree}
   * @param depth    depth of the base tree
   * @param distCoef distance scale stored for the resulting model
   */
  public GreedyExponentialObliviousTree(final BFGrid grid, final int depth, final double distCoef) {
    got = new GreedyObliviousTree(grid, depth);
    this.distCoef = distCoef;
    this.depth = depth;
    numberOfVariablesByLeaf = (depth + 1) * (depth + 2) / 2;
    numberOfVariables = (1 << depth) * numberOfVariablesByLeaf;
  }

  /**
   * Maps the unordered pair {@code (i, j)} inside leaf {@code mask} to a flat position. Leaves are
   * laid out one after another with {@code numberOfVariablesByLeaf} slots each; inside a leaf the
   * pair is stored at {@code max * (max + 1) / 2 + min}.
   *
   * @param mask leaf number used as the block offset (0 gives the position within a leaf)
   * @param i    first entry index
   * @param j    second entry index; the order of {@code i} and {@code j} does not matter
   * @return the flat position of the pair
   */
  public int getIndex(final int mask, int i, int j) {
    final int larger = Math.max(i, j);
    final int smaller = Math.min(i, j);
    return mask * numberOfVariablesByLeaf + larger * (larger + 1) / 2 + smaller;
  }

  /** Returns {@code x * x}. */
  double sqr(final double x) {
    return x * x;
  }

  /**
   * Sums, over every split on which {@code point} lies on the other side than leaf {@code index}
   * requires, the squared difference between the point's feature value and the split condition,
   * and scales the sum by {@code distCoef}.
   *
   * <p>The leaf index is decoded most-significant-bit first so that it agrees with the way
   * {@link #precalculateMissCoefficients} encodes it: split {@code i} owns bit
   * {@code features.size() - 1 - i}.
   *
   * @param index leaf index
   * @param point sample to measure
   * @return the scaled sum of squared distances to the violated splits
   */
  double calcDistanseToRegion(final int index, final Vec point) {
    final int featureCount = features.size();
    double squaredDistance = 0;
    for (int i = 0; i < featureCount; i++) {
      // Isolate the single bit of this split; comparing the whole shifted value with 1 would
      // also depend on every higher bit.
      final boolean requiredSide = ((index >> (featureCount - 1 - i)) & 1) == 1;
      if (features.get(i).value(point) != requiredSide) {
        squaredDistance += sqr(point.get(features.get(i).findex()) - features.get(i).condition());
      }
    }
    return distCoef * squaredDistance;
  }

  /**
   * Allocates and fills {@code quadraticMissCoefficient} and {@code linearMissCoefficient}.
   *
   * <p>For each sample the vector {@code (1, x_1, ..., x_depth)} is formed from the split feature
   * values, the sample's leaf is determined, and for every pair of entries the product is used as
   * a monomial {@code m}. The leaf's linear term receives {@code -2 * target * m} and its
   * quadratic term receives {@code m * m'} for every pair of monomials, which are the
   * coefficients of the expanded squared error of a model linear in those monomials.
   *
   * <p>Every sample currently has weight 1; the previously disabled alternative was
   * {@code Math.exp(-calcDistanseToRegion(index, row))}.
   *
   * @param ds   learning set; it is cast to {@link VecDataSet}
   * @param loss supplies the target value of each sample
   */
  void precalculateMissCoefficients(final DataSet ds, final L2 loss) {
    final int leafCount = 1 << depth;
    quadraticMissCoefficient = new double[leafCount][numberOfVariablesByLeaf][numberOfVariablesByLeaf];
    linearMissCoefficient = new double[leafCount][numberOfVariablesByLeaf];

    final Mx learn = ((VecDataSet) ds).data();
    final int featureCount = features.size();
    for (int i = 0; i < ds.length(); i++) {
      final double[] data = new double[depth + 1];
      data[0] = 1; // constant entry, so pairs with it yield the linear and free terms
      for (int s = 0; s < featureCount; s++) {
        data[s + 1] = learn.get(i, features.get(s).findex());
      }

      // Leaf of this sample: split 0 ends up in the most significant bit.
      final Vec row = learn.row(i);
      int index = 0;
      for (int j = 0; j < featureCount; j++) {
        index <<= 1;
        if (features.get(j).value(row)) {
          index++;
        }
      }

      final double f = loss.target.get(i);
      final double weight = 1;
      final double[] linear = linearMissCoefficient[index];
      final double[][] quadratic = quadraticMissCoefficient[index];
      for (int x = 0; x <= depth; x++) {
        for (int y = 0; y <= x; y++) {
          final int first = getIndex(0, x, y);
          linear[first] -= 2 * f * data[x] * data[y] * weight;
          for (int x1 = 0; x1 <= depth; x1++) {
            for (int y1 = 0; y1 <= x1; y1++) {
              quadratic[first][getIndex(0, x1, y1)] += data[x] * data[y] * data[x1] * data[y1] * weight;
            }
          }
        }
      }
    }
  }

  /**
   * Fits the base oblivious tree, solves the per-leaf coefficients and returns the resulting
   * model. The summed squared error of the base tree and of the resulting model on {@code ds} is
   * printed to standard output.
   *
   * @param ds   learning set
   * @param loss L2 loss carrying the targets
   * @return the fitted model
   * @throws IllegalStateException if the base tree does not have exactly {@code depth} splits;
   *                               the targets are written to {@code badloss.txt} first
   */
  @Override
  public ExponentialObliviousTree fit(final VecDataSet ds, final L2 loss) {
    final ObliviousTree base = got.fit(ds, loss);
    features = base.features();

    double baseMse = 0;
    for (int i = 0; i < ds.length(); i++) {
      baseMse += sqr(base.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("\nBase_MSE = " + baseMse);

    if (features.size() != depth) {
      System.out.println("Oblivious Tree bug");
      dumpTargets(loss, ds.length());
      // Fail the call instead of terminating the whole JVM from library code.
      throw new IllegalStateException(
          "Base oblivious tree has " + features.size() + " splits, expected " + depth);
    }

    precalculateMissCoefficients(ds, loss);

    final int leafCount = 1 << depth;
    final double[][] out = new double[leafCount][];
    for (int index = 0; index < leafCount; index++) {
      out[index] = solveLeaf(index);
    }

    final ExponentialObliviousTree ret = new ExponentialObliviousTree(features, out, distCoef);
    double mse = 0;
    for (int i = 0; i < ds.length(); i++) {
      mse += sqr(ret.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("MSE = " + mse);
    return ret;
  }

  /**
   * Solves the coefficients of one leaf from the precalculated terms: the system matrix is the
   * leaf's quadratic term with {@link #DIAGONAL_REGULARIZATION} added on the diagonal, and the
   * right-hand side is the negated linear term.
   *
   * <p>NOTE(review): expanding the squared error gives normal equations {@code Q c = -L / 2},
   * whereas the right-hand side used here is {@code -L}; confirm that
   * {@code solveLinearEquationUsingLQ} (not visible here) expects this scaling.
   */
  private double[] solveLeaf(final int index) {
    final int n = numberOfVariablesByLeaf;
    final Mx a = new VecBasedMx(n, n);
    final Vec b = new ArrayVec(n);
    for (int i = 0; i < n; i++) {
      b.set(i, -linearMissCoefficient[index][i]);
      for (int j = 0; j < n; j++) {
        a.set(i, j, quadraticMissCoefficient[index][i][j]);
      }
      a.adjust(i, i, DIAGONAL_REGULARIZATION);
    }
    final Vec value = GreedyPolynomialExponentRegion.solveLinearEquationUsingLQ(a, b);

    // getIndex(0, k, j) with j <= k equals k * (k + 1) / 2 + j, so the solution is copied
    // position by position.
    final double[] coefficients = new double[n];
    for (int v = 0; v < n; v++) {
      coefficients[v] = value.get(v);
    }
    return coefficients;
  }

  /** Writes the first {@code count} targets to {@code badloss.txt}, one per line, for diagnosis. */
  private void dumpTargets(final L2 loss, final int count) {
    try {
      final PrintWriter printWriter = new PrintWriter(new File("badloss.txt"));
      try {
        for (int i = 0; i < count; i++) {
          printWriter.println(loss.target.get(i));
        }
      } finally {
        printWriter.close();
      }
    } catch (FileNotFoundException e) {
      // The dump is best effort; the caller reports the actual failure.
      e.printStackTrace();
    }
  }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy