All downloads are free. The search and download functionality uses the official Maven repository.

com.expleague.ml.methods.LassoRegionsForest Maven / Gradle / Ivy

package com.expleague.ml.methods;

import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.func.Linear;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.greedyRegion.BinaryRegion;
import com.expleague.ml.methods.greedyRegion.RegionBasedOptimization;
import com.expleague.ml.models.Region;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.func.Ensemble;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;

public class LassoRegionsForest extends WeakListenerHolderImpl implements VecOptimization {
  protected final FastRandom rnd;
  private final int count;
  private final RegionBasedOptimization> weak;
  private double lambda;
  private final double alpha;
  private final double tolerance = 1e-5;

  public LassoRegionsForest(RegionBasedOptimization> weak, FastRandom rnd,
                            final int count, final double lambda, final double alpha) {
    this.count = count;
    this.rnd = rnd;
    this.weak = new BinaryRegion<>(weak);
    this.lambda = lambda;
    this.alpha = alpha;
  }

  public LassoRegionsForest(RegionBasedOptimization> weak, FastRandom rnd, final int count) {
    this(weak, rnd, count, 1e-3, 1.0);
  }

  private static final ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Lasso forest thread", -1);

  @Override
  public Trans fit(final VecDataSet learn, final Loss globalLoss) {
    final Region[] weakModels = new Region[count];
    final Mx transformedData = new VecBasedMx(learn.data().rows(), count);
    final CountDownLatch latch = new CountDownLatch(count);
    for (int i = 0; i < count; ++i) {
      final int ind = i;
      exec.submit(new Runnable() {
        @Override
        public void run() {
          weakModels[ind] = weak.fit(learn, DataTools.bootstrap(globalLoss, rnd));
          Mx applied = weakModels[ind].transAll(learn.data());
          for (int row = 0; row < learn.data().rows(); ++row) {
            transformedData.set(row, ind, applied.get(row, 0));
          }
          latch.countDown();
        }
      });
    }
    try {
      latch.await();
    } catch (Exception e) {
      System.err.println("fit error");
    }
    ElasticNetMethod lasso = new ElasticNetMethod(tolerance, alpha, lambda);
    Vec init = new ArrayVec(count);
    VecTools.fill(init, 0.0);
    Linear model = (Linear) lasso.fit(new VecDataSetImpl(transformedData, learn), globalLoss, init);
    return new Ensemble(weakModels, model.weights);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy