package com.expleague.ml.methods.linearRegressionExperiments;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.func.Linear;
import gnu.trove.iterator.TDoubleIterator;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
/**
 * Multi-task empirical Bayes ridge regression fitted via the evidence approximation:
 * the prior precision alpha is shared across tasks, each task keeps its own noise
 * precision beta, and both are re-estimated by fixed-point iteration.
 *
 * Created by noxoomo on 10/06/15.
 */
public class EmpericalBayesRidgeRegression {
private final Mx[] datas;       // per-task design matrices
private final double tolerance = 1e-14;
private final Vec[] targets;    // per-task regression targets
private final int featuresCount;
private final EmpericalBayesRidgeRegressionCache[] cache;
private double alpha = 1e-12;   // shared prior precision of the weights
double diff = Double.POSITIVE_INFINITY; // change in alpha between successive iterations
private Vec[] means;            // posterior means of the weights, one per task
private static ThreadPoolExecutor exec = ThreadTools.createBGExecutor("bayesian linear model executor", -1);
public EmpericalBayesRidgeRegression(final Mx[] datas, final Vec[] targets) {
this.datas = datas;
this.targets = targets;
featuresCount = datas[0].columns();
for (int i = 1; i < datas.length; ++i) {
if (datas[i].columns() != featuresCount)
throw new IllegalArgumentException("tasks should use a common set of features");
}
means = new Vec[datas.length];
cache = new EmpericalBayesRidgeRegressionCache[targets.length];
final CountDownLatch latch = new CountDownLatch(targets.length);
for (int i = 0; i < targets.length; ++i) {
final int ind = i;
exec.submit(new Runnable() {
@Override
public void run() {
cache[ind] = new EmpericalBayesRidgeRegressionCache(datas[ind], targets[ind]);
latch.countDown();
}
});
}
try {
latch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("interrupted while building per-task caches", e);
}
}
Linear[] fit() {
// Evidence maximization: iterate the fixed-point updates for the shared alpha and
// the per-task betas until alpha converges or the iteration cap is reached.
int iter = 0;
while (diff > tolerance && iter++ < 30) {
fillMeans();
final double[] gammas = calcGammas();
// gamma pooled over tasks: the total effective number of well-determined parameters
final double gamma;
{
double res = 0;
for (double g : gammas)
res += g;
gamma = res;
}
final double newAlpha = calcNewAlpha(gamma);
final double[] newBetas = calcNewBetas(gammas);
diff = Math.abs(alpha - newAlpha);
for (int i = 0; i < cache.length; ++i) {
cache[i].update(newAlpha, newBetas[i]);
}
alpha = newAlpha;
}
fillMeans();
Linear[] result = new Linear[datas.length];
for (int i = 0; i < result.length; ++i) {
result[i] = new Linear(cache[i].getMean());
}
return result;
}
void fillMeans() {
for (int i = 0; i < cache.length; ++i)
means[i] = cache[i].getMean();
}
// gamma_i = sum_j lambda_j / (lambda_j + alpha) is the effective number of
// well-determined parameters of task i; lambda_j are the eigenvalues of beta_i * X_i^T X_i.
double[] calcGammas() {
double[] gamma = new double[cache.length];
for (int i = 0; i < cache.length; ++i) {
TDoubleIterator eigenValuesIterator = cache[i].getEigenvaluesIterator();
while (eigenValuesIterator.hasNext()) {
final double lambda = eigenValuesIterator.next();
gamma[i] += lambda / (lambda + alpha);
}
}
return gamma;
}
// Derivative of the pooled log evidence with respect to alpha; vanishes at a fixed point.
double der(double alpha) {
double der = 0;
for (int i = 0; i < cache.length; ++i) {
der += 0.5 * featuresCount / alpha;
der -= 0.5 * VecTools.multiply(means[i], means[i]);
TDoubleIterator eigenValuesIterator = cache[i].getEigenvaluesIterator();
while (eigenValuesIterator.hasNext()) {
final double lambda = eigenValuesIterator.next();
der -= 0.5 / (lambda + alpha);
}
}
return der;
}
// Diagnostic: per-task derivatives of the log evidence with respect to beta_i, summed.
double derBeta(double[] betas) {
double der = 0;
for (int i = 0; i < cache.length; ++i) {
der += 0.5 * datas[i].rows() / betas[i];
TDoubleIterator eigenValuesIterator = cache[i].getEigenvaluesIterator();
while (eigenValuesIterator.hasNext()) {
final double lambda = eigenValuesIterator.next();
der -= 0.5 * lambda / (lambda + alpha) / betas[i];
}
for (int j = 0; j < datas[i].rows(); ++j)
der -= 0.5 * (MathTools.sqr(targets[i].get(j) - VecTools.multiply(means[i], datas[i].row(j))));
}
return der;
}
// Fixed-point update alpha' = gamma / sum_i ||m_i||^2, pooling all tasks.
double calcNewAlpha(final double gamma) {
double denom = 0;
for (int i = 0; i < cache.length; ++i) {
final Vec mean = means[i];
denom += VecTools.multiply(mean, mean);
}
return gamma / denom;
}
double[] calcNewBetas(final double[] gammas) {
final double[] betas = new double[gammas.length];
final CountDownLatch latch = new CountDownLatch(cache.length);
for (int i = 0; i < betas.length; ++i) {
final int ind = i;
exec.submit(new Runnable() {
@Override
public void run() {
double error = 0;
final Vec mean = means[ind];
final Mx data = datas[ind];
final Vec target = targets[ind];
for (int i = 0; i < data.rows(); ++i) {
error += MathTools.sqr(target.get(i) - VecTools.multiply(mean, data.row(i)));
}
// The classical evidence update would be (data.rows() - gammas[ind]) / error;
// this variant uses the unbiased OLS-style estimate (rows - columns) / error instead.
betas[ind] = (data.rows() - data.columns()) / error;
latch.countDown();
}
});
}
try {
latch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("interrupted while updating noise precisions", e);
}
return betas;
}
}
// Model: t = Xw + \varepsilon, with w_i ~ N(0, 1.0 / alpha) and \varepsilon ~ N(0, 1.0 / beta).
// If alpha is known in advance, MAP estimation reduces to ordinary ridge regression.
// Per-task cache of sufficient statistics; a helper class, not a VecOptimization.
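//
// For reference, the standard evidence-approximation quantities this cache maintains
// (cf. Bishop, "Pattern Recognition and Machine Learning", section 3.5):
//
//   A      = alpha * I + beta * X^T X              (posterior precision of w)
//   m      = beta * A^{-1} * X^T t                 (posterior mean, see getMean())
//   gamma  = sum_j lambda_j / (lambda_j + alpha)   (lambda_j: eigenvalues of beta * X^T X)
//   alpha' = gamma / ||m||^2
//   beta'  = (N - gamma) / ||t - X m||^2           (the code substitutes N - p for N - gamma)
//
// fit() above pools gamma and ||m||^2 over tasks for the shared alpha, while each task
// keeps its own beta.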
class EmpericalBayesRidgeRegressionCache {
private double alpha = 1e-15;
private double beta;
private final Mx sigma;                  // Gram matrix X^T X
private final Vec eigenValues;           // eigenvalues of sigma
private Mx A;                            // cache for (alpha * I + beta * sigma)
private Mx invA;                         // cache for (alpha * I + beta * sigma)^{-1}
private final Vec covFeatureWithTarget;  // X^T t
public EmpericalBayesRidgeRegressionCache(Mx data, Vec target) {
// Precompute X^T X and X^T t once; these are all the statistics the updates need.
sigma = new VecBasedMx(data.columns(), data.columns());
covFeatureWithTarget = new ArrayVec(data.columns());
for (int i = 0; i < data.columns(); ++i) {
final Vec feature = data.col(i);
sigma.set(i, i, VecTools.multiply(feature, feature));
covFeatureWithTarget.set(i, VecTools.multiply(feature, target));
for (int j = i + 1; j < data.columns(); ++j) {
final double cov = VecTools.multiply(feature, data.col(j));
sigma.set(i, j, cov);
sigma.set(j, i, cov);
}
}
A = new VecBasedMx(sigma.columns(), sigma.columns());
// Eigendecompose sigma once; the eigenvalues of beta * sigma are then obtained for
// any beta by rescaling (see getEigenvaluesIterator()).
Mx eigenValuesMx = new VecBasedMx(sigma.columns(), sigma.columns());
Mx Q = new VecBasedMx(sigma.columns(), sigma.columns());
MxTools.eigenDecomposition(sigma, eigenValuesMx, Q);
this.eigenValues = new ArrayVec(sigma.columns());
for (int i = 0; i < sigma.columns(); ++i) {
this.eigenValues.set(i, eigenValuesMx.get(i, i));
}
beta = 1.0;
update(alpha, beta);
}
void update(double alpha, double beta) {
// Rebuild A = alpha * I + beta * sigma and cache its inverse.
this.alpha = alpha;
this.beta = beta;
for (int i = 0; i < sigma.columns(); ++i) {
A.set(i, i, alpha + beta * sigma.get(i, i));
for (int j = i + 1; j < sigma.columns(); ++j) {
final double val = beta * sigma.get(i, j);
A.set(i, j, val);
A.set(j, i, val);
}
}
invA = MxTools.inverse(A);
}
Vec getMean() {
// Posterior mean of the weights: m = beta * A^{-1} * X^T t.
Vec result = MxTools.multiply(invA, covFeatureWithTarget);
return VecTools.scale(result, beta);
}
// Iterates over the eigenvalues of beta * sigma, rescaled lazily from the cached
// eigenvalues of sigma.
TDoubleIterator getEigenvaluesIterator() {
return new TDoubleIterator() {
int i = 0;
@Override
public double next() {
return beta * eigenValues.get(i++);
}
@Override
public boolean hasNext() {
return i < eigenValues.dim();
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}
}
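
// A minimal usage sketch (hypothetical data; taskOneFeatures and friends are
// placeholders, not library objects). fit() is package-private, so the caller
// must live in the same package:
//
//   Mx[] datas = {taskOneFeatures, taskTwoFeatures};     // each N_i x p, shared p
//   Vec[] targets = {taskOneTarget, taskTwoTarget};      // each of length N_i
//   EmpericalBayesRidgeRegression regression = new EmpericalBayesRidgeRegression(datas, targets);
//   Linear[] models = regression.fit();                  // one linear predictor per task
//   double prediction = models[0].value(newFeatureVec);  // predict for the first task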