// com.expleague.ml.methods.rvm.RVMCache Maven / Gradle / Ivy
package com.expleague.ml.methods.rvm;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.func.BiasedLinear;
import com.expleague.commons.math.stat.StatTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.ArrayTools;
import gnu.trove.iterator.TIntIterator;
import org.apache.commons.math3.util.FastMath;
/**
* Created by noxoomo on 01/06/15.
*/
// Relevance vector machine model:
// target = Xw + \varepsilon, \varepsilon \sim N(0, \sigma^2), w_i \sim N(0, 1/\alpha_i)
// X is the set of basis functions.
// See Bishop, "Pattern Recognition and Machine Learning", or Tipping's articles for details.
class RVMCache {
// Design matrix: one row per training point, one column per feature (read as data.get(point, feature)).
final Mx data;
// Regression target; target.dim() is used as the number of training points.
final Vec target;
// Lazily cached dot products between features (plus the bias) and with the target.
final DotProductsCache featureProducts;
// Prior precision per weight; length data.columns() + 1, the last slot is the bias.
// Double.POSITIVE_INFINITY marks an index that is not in the active set.
final private double[] alpha; //precision prior for weights
// Per-index quantities recomputed by updateQuantities(int); consumed by update(int).
final private double[] q;
final private double[] s;
// theta[i] = qi * qi - si as computed by the last update(i); its sign decides add/re-estimate vs. remove.
final private double[] theta;
// |log(oldAlpha) - log(newAlpha)| from the last update(i); +infinity right after an add, 0 after a remove.
final private double[] diffs;
// Current estimate of the noise variance sigma^2.
double noiseVariance;
// Indices (features and bias) that currently take part in the model.
final ActiveIndicesSet activeIndices;
// Posterior covariance over the active weights, ordered like activeIndices; set by calcSigma().
private Mx sigma;
// Posterior mean over the active weights, ordered like activeIndices; set by calcMean().
private Vec mu;
// Model output for every training point; refreshed by updatePredictions().
private Vec predictions;
/**
 * Creates the cache in its initial state: every feature is inactive (infinite alpha) and only
 * the bias term, stored at index {@code data.columns()}, is active.
 *
 * @param data   design matrix, one column per feature
 * @param target regression target, one entry per row of {@code data}
 * @param random source of randomness handed to the active-index set
 */
RVMCache(final Mx data, final Vec target, final FastRandom random) {
  this.data = data;
  this.target = target;

  final int biasIndex = data.columns();
  final int slots = biasIndex + 1; // one slot per feature plus the trailing bias slot

  this.predictions = new ArrayVec(target.dim());
  this.alpha = new double[slots];
  this.s = new double[slots];
  this.q = new double[slots];
  this.theta = new double[slots];
  this.diffs = new double[slots];

  // Features start switched off.
  for (int feature = 0; feature < biasIndex; ++feature) {
    alpha[feature] = Double.POSITIVE_INFINITY;
  }

  // Initial noise guess: a quarter of the target variance.
  this.noiseVariance = StatTools.variance(target) / 4;
  final double pointsCount = target.dim();
  alpha[biasIndex] = pointsCount / (VecTools.sum2(target) / pointsCount - noiseVariance);

  this.featureProducts = new DotProductsCache(data, target);
  this.activeIndices = new ActiveIndicesSet(slots, random);
  this.activeIndices.addToActive(biasIndex);

  estimateVariance();
}
/**
 * Re-estimates {@code noiseVariance} from the current active set and then refreshes everything
 * that depends on it. The order of the calls matters: each step reads what the previous one wrote.
 */
void estimateVariance() {
// Posterior covariance and mean for the current alpha / noiseVariance.
calcSigma();
calcMean();
// Predictions are built from mu, so they must follow calcMean().
updatePredictions();
noiseVariance = calcNoiseVariance();
// Recomputes sigma and mu with the new noise variance, then s and q for every index.
updateQuantities();
}
/**
 * Computes a new noise variance estimate from the current predictions and posterior covariance.
 *
 * The denominator is N - sum_i gamma_i with gamma_i = 1 - alpha_i * Sigma_ii, i.e.
 * N - M + sum_i alpha_i * Sigma_ii, where N is the number of points, M the number of active
 * indices and Sigma_ii the posterior variance of the i-th active weight.
 *
 * @return residual norm between predictions and target divided by that denominator
 */
private double calcNoiseVariance() {
  double denum = target.dim() - activeIndices.size();
  final TIntIterator indices = activeIndices.activeIterator();
  for (int i = 0; indices.hasNext(); ++i) {
    final int index = indices.next();
    // sigma is ordered like the active set, so the i-th active index owns diagonal entry (i, i).
    // The single-argument get(i) would read the flat element i (first row), not the diagonal.
    denum += alpha[index] * sigma.get(i, i);
  }
  // NOTE(review): the re-estimation formula needs the *squared* L2 distance here;
  // confirm that VecTools.distanceL2 does not take a square root.
  return VecTools.distanceL2(predictions, target) / denum;
}
/**
 * Rebuilds {@code sigma} as the inverse of diag(alpha_active) + F^T F / noiseVariance, where F
 * holds the active columns (features and, if active, the bias). Rows and columns follow the
 * order of {@code activeIndices}.
 */
void calcSigma() {
  final int dim = activeIndices.size();
  final int[] active = activeIndices.activeIndices();
  final Mx precision = new VecBasedMx(dim, dim);

  for (int row = 0; row < active.length; ++row) {
    final int rowFeature = active[row];

    // Diagonal: prior precision plus the scaled squared norm of the feature.
    precision.adjust(row, row, alpha[rowFeature]);
    precision.adjust(row, row, featureProducts.featuresProduct(rowFeature, rowFeature) / noiseVariance);

    // Off-diagonal: fill both triangles at once, the matrix is symmetric.
    for (int col = row + 1; col < active.length; ++col) {
      final double scaledProduct = featureProducts.featuresProduct(rowFeature, active[col]) / noiseVariance;
      precision.adjust(row, col, scaledProduct);
      precision.adjust(col, row, scaledProduct);
    }
  }

  sigma = MxTools.inverse(precision);
}
/**
 * Rebuilds {@code mu} as sigma * (F^T target) / noiseVariance, where F^T target is taken from
 * the dot-product cache for every active index. Requires an up-to-date {@code sigma}.
 */
void calcMean() {
  final Vec targetProjections = new ArrayVec(activeIndices.size());
  final TIntIterator active = activeIndices.activeIterator();
  for (int slot = 0; active.hasNext(); ++slot) {
    targetProjections.adjust(slot, featureProducts.targetProducts(active.next()));
  }

  final Vec posteriorMean = MxTools.multiply(sigma, targetProjections);
  VecTools.scale(posteriorMean, 1.0 / noiseVariance);
  mu = posteriorMean;
}
/**
 * Recomputes the model output for every training point from the posterior mean: the sum over
 * active indices of weight * feature value, with the bias contributing its weight as is.
 */
void updatePredictions() {
  final int biasIndex = data.columns();
  final int pointsCount = target.dim();

  for (int point = 0; point < pointsCount; ++point) {
    double prediction = 0;
    int slot = 0;
    final TIntIterator active = activeIndices.activeIterator();
    while (active.hasNext()) {
      final int feature = active.next();
      final double weight = mu.get(slot++);
      if (feature < biasIndex) {
        prediction += weight * data.get(point, feature);
      } else {
        // The bias behaves like a feature that is constantly 1.
        prediction += weight;
      }
    }
    predictions.set(point, prediction);
  }
}
/**
 * Brings the posterior ({@code sigma}, {@code mu}) up to date and then recomputes {@code s} and
 * {@code q} for every index, active or not.
 */
void updateQuantities() {
  calcSigma();
  calcMean();

  final int slots = alpha.length;
  for (int feature = 0; feature < slots; ++feature) {
    updateQuantities(feature);
  }
  // A commented-out debugging cross-check used to live here. It computed the same values directly:
  // C = F * diag(1 / alpha) * F^T + noiseVariance * I, then
  // s[i] = F_i^T C^{-1} F_i and q[i] = F_i^T C^{-1} target,
  // and compared them against the incremental values with a tolerance of about 1e-3.
}
/**
 * Recomputes {@code s[feature]} and {@code q[feature]} from the current posterior:
 * <pre>
 *   s = f^T f / noiseVariance - b^T * sigma * b
 *   q = f^T target / noiseVariance - b^T * mu
 * </pre>
 * where f is the feature column and b holds the dot products of f with every active column,
 * each divided by noiseVariance.
 *
 * @param feature index of a feature column, or data.columns() for the bias
 */
private void updateQuantities(int feature) {
  final Vec scaledCrossProducts = new ArrayVec(activeIndices.size());
  final TIntIterator active = activeIndices.activeIterator();
  for (int slot = 0; active.hasNext(); ++slot) {
    final int activeFeature = active.next();
    scaledCrossProducts.set(slot, featureProducts.featuresProduct(feature, activeFeature) / noiseVariance);
  }

  final double scaledSelfProduct = featureProducts.featuresProduct(feature, feature) / noiseVariance;
  final double coveredByActive =
      VecTools.multiply(scaledCrossProducts, MxTools.multiply(sigma, scaledCrossProducts));
  s[feature] = scaledSelfProduct - coveredByActive;

  final double scaledTargetProduct = featureProducts.targetProducts(feature) / noiseVariance;
  q[feature] = scaledTargetProduct - VecTools.multiply(scaledCrossProducts, mu);
}
/** Outcome of a single call to update(int). */
enum Result {
// An active index had its alpha set to infinity and was taken out of the active set.
Remove,
// An inactive index received a finite alpha and was put into the active set.
Add,
// An already active index had its alpha re-estimated.
Updated,
// Neither alpha nor the active set changed.
Skipped
}
/**
 * Re-estimates the prior precision of a single index and adjusts the active set accordingly.
 *
 * With theta = qi^2 - si:
 * theta &gt; 0 gives alpha = si^2 / theta (the index is added if it was inactive);
 * theta &lt; 0 on an index with finite alpha switches it off;
 * anything else leaves the model untouched.
 *
 * Does not refresh sigma, mu, s or q; the caller is expected to do that after a change.
 *
 * @param i index of a feature, or data.columns() for the bias
 * @return what happened to the index
 */
public Result update(int i) {
  final double previousAlpha = alpha[i];
  final boolean wasInactive = Double.isInfinite(previousAlpha);

  final double si;
  final double qi;
  if (wasInactive) {
    // The index is not part of the model, so s and q are already free of its own contribution.
    si = s[i];
    qi = q[i];
  } else {
    // Strip the index's own contribution out of s and q.
    final double denominator = previousAlpha - s[i];
    si = previousAlpha * s[i] / denominator;
    qi = previousAlpha * q[i] / denominator;
  }

  final double relevance = qi * qi - si;
  theta[i] = relevance;

  if (relevance > 0) {
    alpha[i] = si * si / relevance;
    if (wasInactive) {
      activeIndices.addToActive(i);
      // Forces at least one more pass before stop() can report convergence.
      diffs[i] = Double.POSITIVE_INFINITY;
      return Result.Add;
    }
    diffs[i] = FastMath.abs(Math.log(previousAlpha) - Math.log(alpha[i]));
    return Result.Updated;
  }

  // True only for finite, non-NaN alpha.
  final boolean hasFiniteAlpha = Math.abs(previousAlpha) <= Double.MAX_VALUE;
  if (relevance < 0 && hasFiniteAlpha) {
    alpha[i] = Double.POSITIVE_INFINITY;
    diffs[i] = 0;
    activeIndices.removeFromActive(i);
    return Result.Remove;
  }

  return Result.Skipped;
}
/**
 * Convergence test. Iteration may stop only when, for every index, the last alpha change in
 * log scale is within {@code tolerance} and no inactive index has theta above 1e-3.
 *
 * @param tolerance largest accepted |log(oldAlpha) - log(newAlpha)|
 * @return true when the fit loop may terminate
 */
private boolean stop(double tolerance) {
  boolean converged = true;
  for (int i = 0; converged && i < diffs.length; ++i) {
    final boolean stillMoving = diffs[i] > tolerance;
    final boolean worthAdding = Double.isInfinite(alpha[i]) && theta[i] > 1e-3;
    converged = !(stillMoving || worthAdding);
  }
  return converged;
}
/**
 * Runs the fitting loop until {@link #stop(double)} reports convergence and returns the
 * resulting linear model.
 *
 * Each pass visits the indices produced by {@code activeIndices.indicesIterator()}, refreshes the
 * cached quantities after every update that changed something, and finishes with a noise
 * variance re-estimation.
 *
 * @param tolerance convergence threshold passed to the stopping test
 * @return linear model built from the posterior mean; inactive features get weight 0
 */
public BiasedLinear fit(double tolerance) {
  boolean converged;
  do {
    final TIntIterator candidates = activeIndices.indicesIterator();
    while (candidates.hasNext()) {
      final Result outcome = update(candidates.next());
      if (outcome != Result.Skipped) {
        updateQuantities();
      }
    }
    estimateVariance();
    converged = stop(tolerance);
  } while (!converged);

  final int biasIndex = data.columns();
  final double[] weights = new double[biasIndex];
  double bias = 0;

  // mu is ordered like the active set, so walk both in step.
  final TIntIterator active = activeIndices.activeIterator();
  for (int slot = 0; active.hasNext(); ++slot) {
    final int feature = active.next();
    final double value = mu.get(slot);
    if (feature == biasIndex) {
      bias = value;
    } else {
      weights[feature] = value;
    }
  }
  return new BiasedLinear(weights, bias);
}
/**
 * Lazy cache of dot products between the columns of {@code data}, an implicit all-ones bias
 * column stored at index {@code data.columns()}, and the target.
 *
 * Not thread safe.
 */
//TODO: merge with elastic net
//TODO: reduce x2 memory usage
static class DotProductsCache {
  final Vec target;
  final Mx data;
  // Row-major "already computed" flags for the entries of featureProducts.
  private final boolean[] isFeaturesProductCached;
  private final boolean[] isTargetCached;
  private final Mx featureProducts;
  private final Vec targetProducts;

  public DotProductsCache(final Mx data, final Vec target) {
    this.target = target;
    this.data = data;
    final int slots = data.columns() + 1; // features plus bias
    this.isFeaturesProductCached = new boolean[slots * slots];
    this.isTargetCached = new boolean[slots];
    this.featureProducts = new VecBasedMx(slots, slots);
    this.targetProducts = new ArrayVec(slots);
  }

  /**
   * Dot product of columns {@code i} and {@code j}; either may be the bias index.
   * The value is computed on first request and stored for both (i, j) and (j, i).
   */
  public double featuresProduct(int i, int j) {
    final int width = featureProducts.columns();
    if (isFeaturesProductCached[i * width + j]) {
      return featureProducts.get(i, j);
    }

    final double product = computeFeaturesProduct(i, j);
    featureProducts.set(i, j, product);
    featureProducts.set(j, i, product);
    isFeaturesProductCached[i * width + j] = true;
    isFeaturesProductCached[j * width + i] = true;
    return product;
  }

  // Uncached dot product of two columns, treating the bias as a column of ones.
  private double computeFeaturesProduct(int i, int j) {
    final int biasIndex = data.columns();
    if (i != biasIndex && j != biasIndex) {
      return VecTools.multiply(data.col(i), data.col(j));
    }
    if (i == j) {
      // bias with itself: one per row
      return data.rows();
    }
    // bias with a feature: the smaller index is the feature
    return VecTools.sum(data.col(Math.min(i, j)));
  }

  /**
   * Dot product of column {@code i} (or the bias) with the target, computed on first request.
   */
  public double targetProducts(int i) {
    if (isTargetCached[i]) {
      return targetProducts.get(i);
    }

    final double product = computeTargetProduct(i);
    targetProducts.set(i, product);
    isTargetCached[i] = true;
    return product;
  }

  // Uncached dot product of a column with the target; for the bias this is the target sum.
  private double computeTargetProduct(int i) {
    if (i == data.columns()) {
      return VecTools.sum(target);
    }
    return VecTools.multiply(data.col(i), target);
  }
}
/**
 * Splits the indices handed to the constructor into an active and an inactive part.
 *
 * {@code set} keeps all indices, with the active ones in positions [0, cursor);
 * {@code indicesMap[index]} is the position of {@code index} inside {@code set}.
 * Adding and removing swap one element across the cursor. Not thread safe.
 */
static class ActiveIndicesSet {
  // A sweep from indicesIterator() makes this many random draws per index.
  private static final int DRAWS_PER_INDEX = 10;

  private final int[] set;
  private final int[] indicesMap;
  private int cursor;
  private final FastRandom random;

  /**
   * @param size number of indices to manage; all start inactive
   * @param rand random source used by {@link #indicesIterator()}
   */
  public ActiveIndicesSet(int size, FastRandom rand) {
    // presumably ArrayTools.sequence(0, size) yields 0..size-1, i.e. the identity mapping — confirm
    set = ArrayTools.sequence(0, size);
    cursor = 0;
    indicesMap = ArrayTools.sequence(0, size);
    this.random = rand;
  }

  /**
   * Iterator over a random sweep of indices (active or not), drawn with replacement:
   * 10 * size draws in total.
   */
  public TIntIterator indicesIterator() {
    return new TIntIterator() {
      private int drawn = 0;
      private int last = -1; // index returned by the latest next(), -1 before the first call

      @Override
      public int next() {
        // return current++;
        ++drawn;
        last = random.nextInt(set.length);
        return last;
      }

      @Override
      public boolean hasNext() {
        return drawn < DRAWS_PER_INDEX * set.length;
      }

      /** Deactivates the index returned by the latest next() if it is currently active. */
      @Override
      public void remove() {
        if (last < 0) {
          throw new IllegalStateException("next() has not been called");
        }
        // Must act on the returned index, not on the draw counter.
        if (indicesMap[last] < cursor) {
          removeFromActive(last);
        }
      }
    };
  }

  /**
   * Moves {@code index} into the active part.
   *
   * @throws IllegalArgumentException if the index is out of range or already active
   */
  public void addToActive(int index) {
    // Range check first: indexing indicesMap with a bad index would throw before it is reached.
    if (index < 0 || index >= indicesMap.length) {
      throw new IllegalArgumentException("index out of range: " + index);
    }
    if (indicesMap[index] < cursor) {
      throw new IllegalArgumentException("already active index");
    }
    swapWithCursor(index);
    ++cursor;
  }

  /** Number of active indices. */
  public int size() {
    return cursor;
  }

  // Exchanges `index` with whatever currently sits at position `cursor`, keeping both arrays in sync.
  private void swapWithCursor(int index) {
    final int movedInd = set[cursor];
    final int movedTo = indicesMap[index];
    indicesMap[index] = cursor;
    set[cursor] = index;
    indicesMap[movedInd] = movedTo;
    set[movedTo] = movedInd;
  }

  /**
   * Moves {@code index} out of the active part.
   *
   * @throws IllegalArgumentException if the index is out of range or not active
   */
  public void removeFromActive(int index) {
    if (index < 0 || index >= indicesMap.length) {
      throw new IllegalArgumentException("index out of range: " + index);
    }
    // Removing an inactive index would push a different, active element out of the active part.
    if (indicesMap[index] >= cursor) {
      throw new IllegalArgumentException("index is not active: " + index);
    }
    --cursor;
    swapWithCursor(index);
  }

  /** Iterator over the active indices in their current storage order. */
  TIntIterator activeIterator() {
    return new TIntIterator() {
      private int current = 0;

      @Override
      public int next() {
        return set[current++];
      }

      @Override
      public boolean hasNext() {
        return current < cursor;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException("unsupported");
      }
    };
  }

  /** Snapshot of the active indices, in the same order as {@link #activeIterator()}. */
  int[] activeIndices() {
    final int[] result = new int[cursor];
    System.arraycopy(set, 0, result, 0, result.length);
    return result;
  }
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy