All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.linearRegressionExperiments.RidgeRegression Maven / Gradle / Ivy

package com.expleague.ml.methods.linearRegressionExperiments;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.func.Linear;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.VecOptimization;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Created by noxoomo on 10/06/15.
 */
public class RidgeRegression implements VecOptimization {
  final double alpha;

  final double multiplyArrayVec(Vec left, Vec right) {
    assert (left.dim() == right.dim());
    int size = (left.dim() / 4) * 4;
    double result = 0;
    for (int i = 0; i < size; i += 4) {
      final double la = left.get(i);
      final double lb = left.get(i + 1);
      final double lc = left.get(i + 2);
      final double ld = left.get(i + 3);
      final double ra = right.get(i);
      final double rb = right.get(i + 1);
      final double rc = right.get(i + 2);
      final double rd = right.get(i + 3);

      final double dpa = la * ra;
      final double dpb = lb * rb;
      final double dpc = lc * rc;
      final double dpd = ld * rd;

      result += (dpa + dpb) + (dpc + dpd);
    }
    for (int i = size; i < left.dim(); ++i) {
      result += left.get(i) * right.get(i);
    }
    return result;
  }

  public RidgeRegression(double alpha) {
    this.alpha = alpha;
  }

  @Override
  public Linear fit(VecDataSet learn, L2 l2) {
    Vec target = l2.target;
    Mx data = learn.data();
    return new Linear(fit(data, target));
  }

  static final ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Ridge dot-products thread", -1);
  ;

  final public Vec fit(final Mx data, final Vec target) {
    final Mx cov = new VecBasedMx(data.columns(), data.columns());
    final Vec covTargetWithFeatures = new ArrayVec(data.columns());

    final CountDownLatch latch = new CountDownLatch(data.columns());

    for (int col = 0; col < data.columns(); ++col) {
      final int i = col;
      exec.submit((Runnable) () -> {
        final Vec feature = data.col(i);
        cov.set(i, i, multiplyArrayVec(feature, feature));
        cov.adjust(i, i, alpha);
        covTargetWithFeatures.set(i, multiplyArrayVec(feature, target));
        for (int j = i + 1; j < data.columns(); ++j) {
          final double val = multiplyArrayVec(feature, data.col(j));
          cov.set(i, j, val);
          cov.set(j, i, val);
        }
        latch.countDown();
      });
    }

    try {
      latch.await();
    } catch (InterruptedException e) {
      //
    }

    Mx invCov = MxTools.inverse(cov);
    return MxTools.multiply(invCov, covTargetWithFeatures);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy