
com.expleague.ml.methods.linearRegressionExperiments.RidgeRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jmll Show documentation
Show all versions of jmll Show documentation
Various ML methods implemented by myself and my students
package com.expleague.ml.methods.linearRegressionExperiments;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.func.Linear;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.VecOptimization;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
/**
 * Ridge (L2-regularised) linear regression solved through the normal equations.
 *
 * <p>For a design matrix {@code X} whose columns are features and a target vector {@code y},
 * {@link #fit(Mx, Vec)} builds {@code XᵀX + alpha·I} and {@code Xᵀy} and returns
 * {@code (XᵀX + alpha·I)⁻¹ · Xᵀy}. Columns are used as-is: this class adds no centering and no
 * intercept column.
 *
 * <p>Created by noxoomo on 10/06/15.
 */
public class RidgeRegression implements VecOptimization {
  /** Regularisation strength; added to every diagonal element of the Gram matrix {@code XᵀX}. */
  final double alpha;

  /**
   * Shared background pool used to compute Gram-matrix entries in parallel, one task per feature
   * column. NOTE(review): the meaning of the {@code -1} argument is defined by
   * {@code ThreadTools.createBGExecutor} (presumably "default thread count") — confirm.
   */
  static final ThreadPoolExecutor exec =
      ThreadTools.createBGExecutor("Ridge dot-products thread", -1);

  /**
   * @param alpha regularisation strength added to the diagonal of {@code XᵀX}
   */
  public RidgeRegression(double alpha) {
    this.alpha = alpha;
  }

  /**
   * Dot product of two vectors of equal dimension.
   *
   * <p>The main loop is manually unrolled by four and adds the four products pairwise. The
   * remaining {@code dim % 4} elements are accumulated one at a time. The summation order
   * therefore differs from a naive left-to-right loop, which can change the last bits of the
   * floating-point result.
   *
   * @param left first operand
   * @param right second operand; must have the same dimension as {@code left} (checked only
   *     when assertions are enabled)
   * @return {@code sum_i left[i] * right[i]}
   */
  final double multiplyArrayVec(Vec left, Vec right) {
    assert (left.dim() == right.dim());
    final int size = (left.dim() / 4) * 4;
    double result = 0;
    for (int i = 0; i < size; i += 4) {
      final double la = left.get(i);
      final double lb = left.get(i + 1);
      final double lc = left.get(i + 2);
      final double ld = left.get(i + 3);
      final double ra = right.get(i);
      final double rb = right.get(i + 1);
      final double rc = right.get(i + 2);
      final double rd = right.get(i + 3);
      final double dpa = la * ra;
      final double dpb = lb * rb;
      final double dpc = lc * rc;
      final double dpd = ld * rd;
      result += (dpa + dpb) + (dpc + dpd);
    }
    // Tail: elements past the last full group of four.
    for (int i = size; i < left.dim(); ++i) {
      result += left.get(i) * right.get(i);
    }
    return result;
  }

  /**
   * Fits ridge-regression weights on {@code learn.data()} against {@code l2.target}.
   *
   * @param learn data set whose {@code data()} matrix has one column per feature
   * @param l2 loss providing the regression target
   * @return linear model with the fitted weights
   */
  @Override
  public Linear fit(VecDataSet learn, L2 l2) {
    final Vec target = l2.target;
    final Mx data = learn.data();
    return new Linear(fit(data, target));
  }

  /**
   * Computes ridge-regression weights {@code (XᵀX + alpha·I)⁻¹ · Xᵀy}.
   *
   * <p>The dot products are computed on {@link #exec}, one task per column. The method blocks
   * until all tasks finish.
   *
   * @param data design matrix {@code X}; each column is one feature and must have the same
   *     length as {@code target}
   * @param target regression target {@code y}
   * @return weight vector of dimension {@code data.columns()}
   * @throws IllegalStateException if the calling thread is interrupted while waiting (the
   *     interrupt flag is restored), or if a worker task fails with a checked exception
   * @throws RuntimeException or {@link Error} rethrown as-is when a worker task fails with one
   *     (e.g. an {@link AssertionError} on a dimension mismatch when assertions are enabled)
   */
  public final Vec fit(final Mx data, final Vec target) {
    final int features = data.columns();
    final Mx cov = new VecBasedMx(features, features);
    final Vec covTargetWithFeatures = new ArrayVec(features);
    final List<Future<?>> tasks = new ArrayList<>(features);
    for (int col = 0; col < features; ++col) {
      final int i = col;
      tasks.add(exec.submit((Runnable) () -> fillRow(data, target, i, cov, covTargetWithFeatures)));
    }
    // Future.get() both propagates task failures and establishes happens-before for the
    // worker-thread writes into cov / covTargetWithFeatures read below.
    awaitAll(tasks);
    final Mx invCov = MxTools.inverse(cov);
    return MxTools.multiply(invCov, covTargetWithFeatures);
  }

  /**
   * Fills the cells owned by column {@code i}. These are the regularised diagonal entry
   * {@code (i, i)}, the entries {@code (i, j)} and {@code (j, i)} for every {@code j > i}, and
   * entry {@code i} of {@code Xᵀy}. No two column indices write the same cell, so the tasks can
   * run concurrently.
   */
  private void fillRow(Mx data, Vec target, int i, Mx cov, Vec covTargetWithFeatures) {
    final Vec feature = data.col(i);
    cov.set(i, i, multiplyArrayVec(feature, feature));
    cov.adjust(i, i, alpha);
    covTargetWithFeatures.set(i, multiplyArrayVec(feature, target));
    for (int j = i + 1; j < data.columns(); ++j) {
      final double val = multiplyArrayVec(feature, data.col(j));
      cov.set(i, j, val);
      cov.set(j, i, val);
    }
  }

  /**
   * Waits for every task and propagates the first failure. On failure or interruption the
   * remaining tasks are cancelled, because their output would be discarded anyway.
   */
  private static void awaitAll(List<Future<?>> tasks) {
    try {
      for (Future<?> task : tasks) {
        task.get();
      }
    } catch (InterruptedException e) {
      cancelAll(tasks);
      Thread.currentThread().interrupt();
      throw new IllegalStateException(
          "Interrupted while computing ridge regression dot products", e);
    } catch (ExecutionException e) {
      cancelAll(tasks);
      final Throwable cause = e.getCause();
      if (cause instanceof RuntimeException) {
        throw (RuntimeException) cause;
      }
      if (cause instanceof Error) {
        throw (Error) cause;
      }
      throw new IllegalStateException("Ridge regression dot-product task failed", cause);
    }
  }

  /** Best-effort cancellation of tasks that have not started or finished yet. */
  private static void cancelAll(List<Future<?>> tasks) {
    for (Future<?> task : tasks) {
      task.cancel(false);
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy