All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.linearRegressionExperiments.ValidationRidgeRegression Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.linearRegressionExperiments;

import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.func.Linear;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.commons.math.stat.StatTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import gnu.trove.list.array.TIntArrayList;

import static com.expleague.commons.math.MathTools.sqr;

/**
 * Ridge regression whose regularization coefficient is selected on a random hold-out split.
 *
 * <p>Created by noxoomo on 10/06/15.
 */
public class ValidationRidgeRegression implements VecOptimization {
  /** Smallest regularization coefficient tried; candidates are obtained by repeated doubling. */
  private static final double MIN_LAMBDA = 1e-1;
  /**
   * Once the candidate coefficient exceeds this bound while the validation score is still not
   * getting worse, the search stops and the all-zero model is returned.
   */
  private static final double MAX_LAMBDA = 1.0;

  final double validation;
  final FastRandom rand;

  /**
   * @param validationPart probability with which each row is put into the validation split
   * @param rand generator used to seed a fresh generator on every {@link #fit} call
   */
  public ValidationRidgeRegression(double validationPart, FastRandom rand) {
    this.validation = validationPart;
    this.rand = rand;
  }

  /**
   * Fits a ridge-regularized linear model, choosing the regularization coefficient on a
   * hold-out part of {@code learn}.
   *
   * <p>Procedure:
   * <ol>
   *   <li>Rows are randomly split into a learn part and a validation part.</li>
   *   <li>Starting from {@code MIN_LAMBDA}, the coefficient is doubled as long as the validation
   *       score does not get worse.</li>
   *   <li>An all-zero model is returned in either of these cases:
   *       <ul>
   *         <li>no candidate scores strictly below the target variance;</li>
   *         <li>the score is still not worsening once the coefficient exceeds
   *             {@code MAX_LAMBDA}.</li>
   *       </ul></li>
   *   <li>Otherwise the model is refit on all rows with the best coefficient.</li>
   *   <li>The refit model is kept only if its score on all rows does not exceed the target
   *       variance.</li>
   * </ol>
   *
   * @param learn data set whose feature matrix is used
   * @param l2 loss whose target vector is regressed on
   * @return the fitted model, or an all-zero model when regularized regression does not help
   */
  @Override
  public Linear fit(VecDataSet learn, L2 l2) {
    // for parallel fit: every call uses its own generator, seeded from the shared one
    final FastRandom random = new FastRandom(rand.nextLong());

    final Mx data = learn.data();
    final int dim = data.columns();
    final TIntArrayList learnPoints = new TIntArrayList();
    final TIntArrayList validationPoints = new TIntArrayList();
    for (int i = 0; i < data.rows(); ++i) {
      if (random.nextDouble() < validation) {
        validationPoints.add(i);
      } else {
        learnPoints.add(i);
      }
    }

    final Vec target = l2.target();
    final double variance = StatTools.variance(target);

    // Sufficient statistics (X^T X and X^T y) over the learn split; fresh storage starts at zero.
    final Mx cov = new VecBasedMx(dim, dim);
    final Vec covTargetWithFeatures = new ArrayVec(dim);
    accumulate(data, target, learnPoints, cov, covTargetWithFeatures);

    final RidgeRegressionCache ridge = new RidgeRegressionCache(cov, covTargetWithFeatures);
    double bestScore = variance;
    double bestLambda = MIN_LAMBDA;
    double lambda = MIN_LAMBDA;
    while (true) {
      final double score = score(ridge.fit(lambda), data, target, validationPoints);
      if (score > bestScore) {
        break;
      }
      bestScore = score;
      bestLambda = lambda;
      lambda *= 2;
      if (lambda > MAX_LAMBDA) {
        // Stronger regularization keeps helping; the limit of that trend is the zero model.
        return new Linear(new double[dim]);
      }
    }
    // bestScore never exceeds variance; equality means no candidate improved on the baseline.
    if (bestScore >= variance) {
      return new Linear(new double[dim]);
    }

    // Add the validation rows to the statistics and refit on all rows with the chosen lambda.
    accumulate(data, target, validationPoints, cov, covTargetWithFeatures);
    // A fresh solver is built here: RidgeRegressionCache may precompute from cov in its
    // constructor (NOTE(review): confirm), in which case reusing `ridge` would ignore or mix in
    // stale state instead of using the updated statistics.
    final Linear result = new RidgeRegressionCache(cov, covTargetWithFeatures).fit(bestLambda);
    learnPoints.addAll(validationPoints);
    final double resultScore = score(result, data, target, learnPoints);
    if (resultScore > variance) {
      return new Linear(new double[dim]);
    }
    return result;
  }

  /**
   * Adds the contribution of the rows in {@code points} to the statistics.
   *
   * <p>Updates two accumulators in place:
   * <ul>
   *   <li>{@code cov} receives {@code sum x_i x_j}, filled symmetrically;</li>
   *   <li>{@code covTargetWithFeatures} receives {@code sum x_i y}.</li>
   * </ul>
   */
  private void accumulate(Mx data, Vec target, TIntArrayList points,
                          Mx cov, Vec covTargetWithFeatures) {
    final int dim = data.columns();
    for (int i = 0; i < dim; ++i) {
      final Vec feature = data.col(i);
      cov.adjust(i, i, multiply(feature, feature, points));
      covTargetWithFeatures.adjust(i, multiply(feature, target, points));
      for (int j = i + 1; j < dim; ++j) {
        final double val = multiply(feature, data.col(j), points);
        cov.adjust(i, j, val);
        cov.adjust(j, i, val);
      }
    }
  }

  /**
   * Computes the residual sum of squares of {@code model} over {@code points}, divided by the
   * degrees of freedom {@code points.size() - model.dim()}.
   *
   * @return the score, or {@link Double#POSITIVE_INFINITY} when there are no degrees of freedom
   *     (a negative or zero denominator would otherwise yield a meaningless score that could be
   *     selected as "best")
   */
  private double score(Linear model, Mx data, Vec target, TIntArrayList points) {
    final int degreesOfFreedom = points.size() - model.dim();
    if (degreesOfFreedom <= 0) {
      return Double.POSITIVE_INFINITY;
    }
    double rss = 0;
    for (int i = 0; i < points.size(); ++i) {
      final int point = points.get(i);
      rss += sqr(model.value(data.row(point)) - target.get(point));
    }
    return rss / degreesOfFreedom;
  }

  /** Returns the dot product of {@code left} and {@code right} restricted to the indices in {@code points}. */
  private double multiply(Vec left, Vec right, TIntArrayList points) {
    double res = 0;
    for (int i = 0; i < points.size(); ++i) {
      final int ind = points.get(i);
      res += left.get(ind) * right.get(ind);
    }
    return res;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy