All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.ElasticNetMethod Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods;

import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.func.Linear;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.Trans;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.models.ShifftedTrans;
import org.apache.commons.math3.util.FastMath;

import java.util.ArrayList;
import java.util.List;

import static com.expleague.commons.math.vectors.VecTools.adjust;
import static com.expleague.commons.math.vectors.VecTools.copy;

/**
 * Elastic-net linear regression solved by (pathwise) cyclic coordinate descent, following
 * Friedman, Hastie and Tibshirani, "Regularization Paths for Generalized Linear Models via
 * Coordinate Descent".
 *
 * <p>The per-coordinate update in {@link ElasticNetCache#update(int)} is the closed-form
 * one-coordinate minimizer of
 * {@code (1 / 2N) * ||y - X b||^2 + lambda * (alpha * ||b||_1 + (1 - alpha) / 2 * ||b||_2^2)},
 * where {@code N} is the number of rows of {@code X}.
 *
 * <p>User: noxoomo
 * Date: 5.02.2015
 * Time: 13:58
 */
public class ElasticNetMethod extends VecOptimization.Stub {
  /**
   * Sweeps stop once two consecutive coefficient snapshots are closer than this,
   * as measured by {@code VecTools.distance}.
   */
  private final double tolerance;
  /** Mixing weight: {@code 1} is a pure L1 penalty, {@code 0} a pure L2 penalty. */
  private final double alpha;
  /** Overall regularization strength. */
  private final double lambda;

  /**
   * Number of full coordinate sweeps performed between two convergence checks in
   * {@link #fit(ElasticNetCache)}.
   */
  public int checkIterations = 2;

  /**
   * @param tolerance convergence threshold on the distance between successive coefficient vectors
   * @param alpha     mixing weight between the L1 penalty ({@code alpha = 1}) and the L2 penalty
   *                  ({@code alpha = 0})
   * @param lambda    overall regularization strength
   */
  public ElasticNetMethod(final double tolerance, final double alpha, final double lambda) {
    this.tolerance = tolerance;
    this.lambda = lambda;
    this.alpha = alpha;
  }

  /**
   * Fits the model on {@code ds} against {@code loss.target}. The target is centered by its mean
   * before fitting, and the mean is added back as the intercept of the returned transform.
   * An empty target yields a zero intercept (instead of NaN).
   *
   * <p>NOTE(review): only the target is centered, not the feature columns of {@code ds}; the
   * intercept equals the least-squares intercept only if the features are already centered —
   * confirm with callers.
   */
  @Override
  public Trans fit(final VecDataSet ds, final L2 loss) {
    final Vec target = loss.target;
    final int n = target.dim();
    double intercept = 0;
    if (n > 0) {
      for (int i = 0; i < n; ++i) {
        intercept += target.get(i);
      }
      intercept /= n;
    }
    final Vec centeredTarget = copy(target);
    adjust(centeredTarget, -intercept);
    final ElasticNetCache cache = new ElasticNetCache(ds.data(), centeredTarget, alpha, lambda);
    final Trans result = fit(cache);
    return new ShifftedTrans(result, intercept);
  }

  /**
   * Fits starting from the coefficients in {@code init} (warm start). Unlike
   * {@link #fit(VecDataSet, L2)}, the target is used as-is (not centered) and no intercept is
   * added.
   *
   * <p>NOTE: {@code init} is used directly as the solver's coefficient storage, so it is
   * overwritten with the fitted coefficients.
   */
  public Trans fit(final VecDataSet ds, final L2 loss, final Vec init) {
    final ElasticNetCache cache = new ElasticNetCache(ds.data(), loss.target, init, alpha, lambda);
    return fit(cache);
  }

  /**
   * Computes a regularization path: {@code nlambda} models for lambda values evenly spaced from
   * {@code lambdaMax = max_k |x_k . target| / (alpha * N)} (the smallest lambda for which every
   * coefficient stays zero when starting from zero) down to {@code lambdaMax / nlambda}. Each fit
   * warm-starts from the coefficients of the previous one.
   *
   * <p>The grid is generated from an integer index, so exactly {@code nlambda} models are
   * returned; repeatedly subtracting the step instead could leave a spurious extra grid point
   * with lambda close to zero because of rounding.
   *
   * @param data    design matrix (rows are samples, columns are features)
   * @param target  regression target; it is not centered here
   * @param nlambda number of lambda values on the path
   * @return list of {@link Linear} models ordered from the largest to the smallest lambda; empty
   *         when {@code nlambda == 0} or when {@code lambdaMax} is not a finite positive number
   *         (e.g. {@code alpha <= 0}, or a target orthogonal to every feature)
   * @throws IllegalArgumentException if {@code nlambda < 0}
   */
  public final List fit(final Mx data, final Vec target, int nlambda) {
    if (nlambda < 0) {
      throw new IllegalArgumentException("nlambda must be non-negative: " + nlambda);
    }
    final ElasticNetCache cache = new ElasticNetCache(data, target, alpha, lambda);
    double maxCorrelation = 0.0;
    for (int i = 0; i < data.columns(); ++i) {
      maxCorrelation = FastMath.max(FastMath.abs(cache.targetProduct(i)), maxCorrelation);
    }
    final double lambdaMax = maxCorrelation * (1.0 / (alpha * data.rows()));
    final double lambdaMin = 0.0;
    final List<Linear> path = new ArrayList<>(nlambda);
    // Degenerate grids (NaN, non-positive, or infinite lambdaMax) would produce NaN thresholds.
    if (!(lambdaMax > lambdaMin) || Double.isInfinite(lambdaMax)) {
      return path;
    }
    final double step = (lambdaMax - lambdaMin) / nlambda;
    for (int i = 0; i < nlambda; ++i) {
      cache.setLambda(lambdaMax - i * step);
      path.add(fit(cache));
    }
    return path;
  }

  /**
   * Runs blocks of {@link #checkIterations} coordinate-descent sweeps over the active features of
   * {@code cache}. It stops when the first sweep of a block changes no coefficient, or when the
   * coefficient vector moves less than {@code tolerance} across a block. The cache's coefficients
   * are updated in place, so later calls on the same cache warm-start from them.
   *
   * <p>NOTE(review): there is no cap on the number of blocks.
   *
   * @return a linear model over a copy of the final coefficients
   */
  public Linear fit(ElasticNetCache cache) {
    Vec betas = cache.betas();
    boolean updated = true;
    while (updated) {
      updated = false;
      final Vec prev = betas;
      for (int sweep = 0; sweep < checkIterations; ++sweep) {
        for (int k = 0; k < cache.dim(); ++k) {
          updated = cache.update(k) || updated;
        }
        // `updated` accumulates over the block, so this exits early only when the very first
        // sweep of the block changed nothing.
        if (!updated) {
          break;
        }
      }
      betas = cache.betas();
      if (VecTools.distance(betas, prev) < tolerance) {
        break;
      }
    }
    return new Linear(betas);
  }


  /**
   * Mutable coordinate-descent state for a fixed design matrix and target. It holds the current
   * coefficients and per-coordinate partial-residual correlations ("gradients"). It also holds
   * lazily computed entries of {@code X^T X} and {@code X^T y}.
   */
  public static class ElasticNetCache {
    /** Coefficient changes whose magnitude does not exceed this value are not applied. */
    private static final double EQUALS_TOLERANCE = 1e-10;

    private final Mx data;
    private final Vec target;
    // Lazily filled upper triangle (i <= j) of X^T X, row-major with stride betas.dim().
    private final boolean[] isFeaturesProductCached;
    private final double[] featureProducts;
    // Lazily filled entries of X^T y.
    private final boolean[] isTargetCached;
    private final double[] targetProducts;
    // gradient[k] = x_k . (y - sum over active j != k of b_j * x_j): the correlation of feature k
    // with the residual that excludes feature k's own contribution.
    private final double[] gradient;
    // Current coefficients; stored by reference (see constructor).
    private final Vec betas;
    // Number of leading features currently taking part in the optimization.
    private int dim;
    private double alpha;
    private double lambda;

    /**
     * @param data   design matrix (rows are samples, columns are features)
     * @param target regression target
     * @param init   initial coefficients; stored by reference and overwritten by updates. Its
     *               dimension fixes the size of the internal caches.
     * @param dim    number of leading features to activate; values above {@code init.dim()}
     *               cause an index error
     * @param alpha  L1/L2 mixing weight
     * @param lambda regularization strength
     */
    public ElasticNetCache(final Mx data, final Vec target, final Vec init, int dim, double alpha, double lambda) {
      this.alpha = alpha;
      this.lambda = lambda;
      this.data = data;
      this.target = target;
      this.betas = init;
      this.dim = 0;
      final int p = betas.dim();
      isFeaturesProductCached = new boolean[p * p];
      isTargetCached = new boolean[p];
      featureProducts = new double[p * p];
      targetProducts = new double[p];
      gradient = new double[p];
      this.updateDim(dim);
    }

    /** Same as the full constructor with every feature of {@code init} active. */
    public ElasticNetCache(final Mx data, final Vec target, final Vec init, double alpha, double lambda) {
      this(data, target, init, init.dim(), alpha, lambda);
    }

    /** Starts from all-zero coefficients with every column of {@code data} active. */
    public ElasticNetCache(final Mx data, final Vec target, double alpha, double lambda) {
      this(data, target, new ArrayVec(data.columns()), alpha, lambda);
    }

    /** Starts from all-zero coefficients with the first {@code dim} columns active. */
    public ElasticNetCache(final Mx data, final Vec target,int dim, double alpha, double lambda) {
      this(data, target, new ArrayVec(data.columns()),dim, alpha, lambda);
    }


    /** Returns the current coefficient {@code i}. */
    public double beta(int i) {
      return betas.get(i);
    }

    /** Returns the number of active leading features. */
    public int dim() {
      return dim;
    }

    /**
     * Sets the number of active leading features to {@code newDim}. For newly activated features
     * it computes their gradients from scratch. From the already-active gradients it subtracts
     * the contribution of the newly activated features' current coefficients.
     *
     * <p>NOTE(review): shrinking ({@code newDim < dim()}) only lowers the counter; the removed
     * features' contributions remain in the gradients of the remaining ones — confirm intended.
     */
    public void updateDim(int newDim) {
      final int oldDim = dim;
      dim = newDim;
      for (int i = oldDim; i < dim; ++i) {
        double res = targetProduct(i);
        // Zero coefficients are skipped so their Gram entries are never computed.
        for (int j = 0; j < i; ++j) {
          final double beta = betas.get(j);
          res -= beta != 0 ? beta * featureProduct(j, i) : 0;
        }
        for (int j = i + 1; j < dim; ++j) {
          final double beta = betas.get(j);
          res -= beta != 0 ? beta * featureProduct(i, j) : 0;
        }
        gradient[i] = res;
      }
      for (int i = 0; i < oldDim; ++i) {
        for (int j = oldDim; j < dim; ++j) {
          final double beta = betas.get(j);
          gradient[i] -= beta != 0 ? beta * featureProduct(i, j) : 0;
        }
      }
    }

    /** Returns the partial-residual correlation of feature {@code k}. */
    public double gradient(int k) {
      return gradient[k];
    }

    /** Returns the dot product of columns {@code i} and {@code j} of {@code data}. */
    private double dot(Mx data, int i, int j) {
      final int rows = data.rows();
      final int length = 4 * (rows / 4);
      double result = 0;
      // Unrolled by 4; products are combined as ((p0 + p2) + (p1 + p3)) before accumulation.
      for (int k = 0; k < length; k += 4) {
        final double p0 = data.get(k, i) * data.get(k, j);
        final double p1 = data.get(k + 1, i) * data.get(k + 1, j);
        final double p2 = data.get(k + 2, i) * data.get(k + 2, j);
        final double p3 = data.get(k + 3, i) * data.get(k + 3, j);
        result += (p0 + p2) + (p1 + p3);
      }
      for (int k = length; k < rows; ++k) {
        result += data.get(k, i) * data.get(k, j);
      }
      return result;
    }

    // jvm vectorization http://hg.openjdk.java.net/hsx/hotspot-comp/hotspot/rev/006050192a5a
    /** Returns the dot product of column {@code i} of {@code data} with {@code target}. */
    private double targetDot(Mx data, int i, Vec target) {
      final int rows = data.rows();
      final int length = 4 * (rows / 4);
      double result = 0;
      // Unrolled by 4; products are combined as ((p0 + p2) + (p1 + p3)) before accumulation.
      for (int k = 0; k < length; k += 4) {
        final double p0 = data.get(k, i) * target.get(k);
        final double p1 = data.get(k + 1, i) * target.get(k + 1);
        final double p2 = data.get(k + 2, i) * target.get(k + 2);
        final double p3 = data.get(k + 3, i) * target.get(k + 3);
        result += (p0 + p2) + (p1 + p3);
      }
      for (int k = length; k < rows; ++k) {
        result += data.get(k, i) * target.get(k);
      }
      return result;
    }

    /** Returns {@code x_i . x_j}, computing and caching it on first use (symmetric in i, j). */
    private double featureProduct(int i, int j) {
      if (i > j) {
        return featureProduct(j, i);
      }
      final int index = i * betas.dim() + j;
      if (!isFeaturesProductCached[index]) {
        featureProducts[index] = dot(data, i, j);
        isFeaturesProductCached[index] = true;
      }
      return featureProducts[index];
    }

    /** Returns {@code x_k . target}, computing and caching it on first use. */
    private double targetProduct(int k) {
      if (!isTargetCached[k]) {
        targetProducts[k] = targetDot(data, k, target);
        isTargetCached[k] = true;
      }
      return targetProducts[k];
    }

    public void setLambda(double lambda) {
      this.lambda = lambda;
    }

    public void setAlpha(double alpha) {
      this.alpha = alpha;
    }

    /**
     * Performs one coordinate-descent step on coefficient {@code k}:
     * {@code b_k = S(gradient_k, N*lambda*alpha) / (x_k.x_k + N*lambda*(1 - alpha))}, where
     * {@code S} is soft-thresholding and {@code N} is the number of rows.
     *
     * @return {@code true} if the coefficient changed by more than the equality tolerance and was
     *         applied; {@code false} otherwise (including when the new value is NaN)
     */
    public boolean update(int k) {
      final int n = data.rows();
      double newBeta = softThreshold(gradient(k), n * lambda * alpha);
      newBeta /= (featureProduct(k, k) + n * lambda * (1 - alpha));
      if (Math.abs(newBeta - betas.get(k)) > EQUALS_TOLERANCE) {
        update(k, newBeta);
        return true;
      }
      return false;
    }

    /**
     * Sets coefficient {@code k} to {@code newBeta} and subtracts the coefficient change times
     * {@code x_i . x_k} from the gradients of the other coordinates (all {@code i < k}, and
     * active {@code i > k}).
     */
    private void update(final int k, final double newBeta) {
      final double diff = newBeta - betas.get(k);
      for (int i = 0; i < k; ++i) {
        gradient[i] -= diff * featureProduct(i, k);
      }
      for (int i = k + 1; i < dim; ++i) {
        gradient[i] -= diff * featureProduct(k, i);
      }
      betas.set(k, newBeta);
    }

    /** Returns a copy of the current coefficients. */
    public Vec betas() {
      return copy(betas);
    }

    /** Soft-thresholding operator: {@code sign(z) * max(|z| - j, 0)}. */
    private double softThreshold(final double z, final double j) {
      final double sgn = Math.signum(z);
      return sgn * Math.max(sgn * z - j, 0);
    }
  }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy