// com.expleague.ml.methods.linearRegressionExperiments.ValidationRidgeRegression Maven / Gradle / Ivy
package com.expleague.ml.methods.linearRegressionExperiments;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.func.Linear;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.commons.math.stat.StatTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import gnu.trove.list.array.TIntArrayList;
import static com.expleague.commons.math.MathTools.sqr;
/**
* Created by noxoomo on 10/06/15.
*/
/**
 * Ridge regression whose regularization coefficient (lambda) is selected on a random
 * hold-out validation split, then the model is refit on all points with the chosen lambda.
 * Scores are degrees-of-freedom adjusted mean squared errors; LOWER is better, and the
 * variance of the target serves as the score of the constant baseline predictor.
 */
public class ValidationRidgeRegression implements VecOptimization<L2> {
  /** Probability that a point is assigned to the validation part of the split. */
  final double validation;
  /** Seed source; a derived generator is used per fit so parallel fits stay independent. */
  final FastRandom rand;
  /** Smallest lambda tried; the search doubles it until the validation score degrades. */
  private final double minLambda = 1e-1;

  /**
   * @param validationPart fraction of points (in expectation) held out to validate lambda
   * @param rand           randomness source for the learn/validation split
   */
  public ValidationRidgeRegression(double validationPart, FastRandom rand) {
    this.validation = validationPart;
    this.rand = rand;
  }

  /**
   * Fits a linear model with ridge regularization, choosing lambda on a validation split.
   *
   * @param learn dataset whose rows are points and columns are features
   * @param l2    L2 loss holding the regression target
   * @return the fitted linear model, or an all-zero model when no lambda beats the
   *         constant baseline (target variance)
   */
  @Override
  public Linear fit(VecDataSet learn, L2 l2) {
    final FastRandom random = new FastRandom(rand.nextLong()); // independent stream for parallel fit
    final Mx data = learn.data();

    // Randomly split point indices into learn and validation parts.
    final TIntArrayList learnPoints = new TIntArrayList();
    final TIntArrayList validationPoints = new TIntArrayList();
    for (int i = 0; i < data.rows(); ++i) {
      if (random.nextDouble() < validation) {
        validationPoints.add(i);
      } else {
        learnPoints.add(i);
      }
    }

    final Vec target = l2.target();
    // Baseline score: what the constant predictor achieves.
    final double variance = StatTools.variance(target);

    // Gram matrix (X^T X) and feature-target products (X^T y) over the learn part only.
    // A fresh VecBasedMx/ArrayVec is zero-filled, so accumulation via adjust() is safe.
    final Mx cov = new VecBasedMx(data.columns(), data.columns());
    final Vec covTargetWithFeatures = new ArrayVec(data.columns());
    accumulate(cov, covTargetWithFeatures, data, target, learnPoints);

    final RidgeRegressionCache ridge = new RidgeRegressionCache(cov, covTargetWithFeatures);

    // Grow lambda geometrically while the validation score keeps improving.
    double bestScore = variance;
    double lambda = minLambda;
    double bestLambda = lambda;
    while (true) {
      final Linear model = ridge.fit(lambda);
      final double score = score(model, data, target, validationPoints);
      if (score > bestScore) {
        break; // score degraded: the previous lambda was the optimum
      }
      bestLambda = lambda;
      bestScore = score;
      lambda *= 2;
      if (lambda > 1) {
        // BUGFIX: the original returned the zero model here, discarding an already
        // validated improvement; stop the search and use bestLambda instead.
        break;
      }
    }

    // BUGFIX: the original tested `bestScore <= variance`, which is always true
    // (bestScore starts at `variance` and only decreases), so every fitted model was
    // thrown away. Reject only when no lambda improved on the constant baseline.
    if (bestScore >= variance) {
      return new Linear(new double[data.columns()]);
    }

    // Fold the validation part into the statistics and refit on all points.
    accumulate(cov, covTargetWithFeatures, data, target, validationPoints);
    final Linear result = ridge.fit(bestLambda);
    learnPoints.addAll(validationPoints);
    final double resultScore = score(result, data, target, learnPoints);
    if (resultScore > variance) {
      return new Linear(new double[data.columns()]); // refit degenerated: fall back to zero model
    }
    return result;
  }

  /**
   * Adds the contribution of the given points to the (symmetric) Gram matrix and to the
   * feature-target products. Shared by the initial learn-part pass and the later
   * validation-part top-up, which the original code duplicated inline.
   */
  private void accumulate(Mx cov, Vec covTargetWithFeatures, Mx data, Vec target, TIntArrayList points) {
    for (int i = 0; i < data.columns(); ++i) {
      final Vec feature = data.col(i);
      cov.adjust(i, i, multiply(feature, feature, points));
      covTargetWithFeatures.adjust(i, multiply(feature, target, points));
      for (int j = i + 1; j < data.columns(); ++j) {
        final double val = multiply(feature, data.col(j), points);
        cov.adjust(i, j, val);
        cov.adjust(j, i, val); // keep the matrix symmetric
      }
    }
  }

  /**
   * Degrees-of-freedom adjusted mean squared error of {@code model} over the given points
   * (sum of squared residuals divided by {@code points.size() - model.dim()}). Lower is better.
   *
   * @return the adjusted MSE, or {@code +Infinity} when there are too few points for the
   *         adjustment (the original divided by a non-positive number in that case,
   *         producing a spuriously "good" negative or infinite score)
   */
  private double score(Linear model, Mx data, Vec target, TIntArrayList points) {
    final int degreesOfFreedom = points.size() - model.dim();
    if (degreesOfFreedom <= 0) {
      return Double.POSITIVE_INFINITY;
    }
    double score = 0;
    for (int i = 0; i < points.size(); ++i) {
      final int point = points.get(i);
      score += sqr(model.value(data.row(point)) - target.get(point));
    }
    return score / degreesOfFreedom;
  }

  /** Dot product of {@code left} and {@code right} restricted to the given point indices. */
  private double multiply(Vec left, Vec right, TIntArrayList points) {
    double res = 0;
    for (int i = 0; i < points.size(); ++i) {
      final int ind = points.get(i);
      res += left.get(ind) * right.get(ind);
    }
    return res;
  }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy