All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.trees.RidgeGreedyObliviousTree Maven / Gradle / Ivy

package com.expleague.ml.methods.trees;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.Binarize;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.linearRegressionExperiments.RidgeRegression;
import com.expleague.ml.models.ModelTools;
import com.expleague.ml.models.ObliviousTree;

import java.util.ArrayList;
import java.util.List;

/**
 * Optimization that first fits an oblivious tree with a wrapped {@link GreedyObliviousTree}
 * and then re-estimates the values of the compiled tree entries with {@link RidgeRegression}.
 *
 * <p>NOTE(review): the type name {@code Loss} is used in method signatures but is not declared
 * or imported in the visible source, and {@code Pair}, {@code List} and {@code Ensemble} appear
 * as raw types. This looks like generic type parameters were lost when the file was extracted;
 * confirm against the original source before compiling.
 *
 * User: noxoomo
 * Date: 02.12.2015
 */

public class RidgeGreedyObliviousTree extends VecOptimization.Stub {
  // Underlying tree learner; its grid is also used for binarization in fit() and filter().
  private final GreedyObliviousTree base;
  // Value passed to the RidgeRegression constructor in fit().
  final double lambda;

  /**
   * Stores the base learner and the ridge parameter.
   *
   * @param base   learner used to build the initial oblivious tree
   * @param lambda value handed to {@link RidgeRegression} when refitting entry weights
   */
  public RidgeGreedyObliviousTree(GreedyObliviousTree base, double lambda) {
    this.base = base;
    this.lambda = lambda;
  }

  /**
   * Selects the indices of the data points used for learning.
   *
   * @param loss loss being optimized; if it is a {@link WeightedLoss} its {@code points()} are used
   * @param ds   data set; only its length is read here
   * @return {@code loss.points()} for a weighted loss, otherwise
   *         {@code ArrayTools.sequence(0, ds.length())} (presumably every index of the data set;
   *         verify against ArrayTools)
   */
  private int[] learnPoints(Loss loss, VecDataSet ds) {
    if (loss instanceof WeightedLoss) {
      return ((WeightedLoss) loss).points();
    } else return ArrayTools.sequence(0, ds.length());
  }

  /**
   * Builds the regression inputs for the given points: a matrix with one row per point and one
   * column per compiled entry, plus a target vector with one component per point.
   *
   * <p>NOTE(review): the body of this method is truncated in the visible source (see the marked
   * line below), so how the matrix and target are filled cannot be determined from here.
   *
   * @param entryList    compiled ensemble entries; its size fixes the number of matrix columns
   * @param bds          binarized data set (its use falls in the truncated part)
   * @param sourceTarget source target vector (its use falls in the truncated part)
   * @param points       indices of the data points to include; its length fixes the row count
   * @return a pair built from the matrix and the target vector
   */
  @SuppressWarnings("Duplicates")
  protected Pair filter(final List entryList, final BinarizedDataSet bds, Vec sourceTarget, int[] points) {
    // Scratch buffer sized by the number of rows of the base learner's grid.
    final byte[] binary = new byte[base.grid.rows()];
    Mx otData = new VecBasedMx(points.length, entryList.size());
    Vec target = new ArrayVec(points.length);


    for (int i=0;i <  points.length;++i) {
      final int ind = points[i];
      // NOTE(review): the next line is garbled -- the inner loop header runs straight into the
      // tail of the return statement, so the loop bodies are missing. Restore from the original
      // source; do not treat this line as intended code.
      for (int f=0; f (otData, target);
  }

  /**
   * Fits a tree with the base learner, then replaces the values of its compiled entries with
   * weights obtained from ridge regression.
   *
   * @param ds   data set to learn on
   * @param loss loss to optimize; also supplies the target and, if weighted, the learning points
   * @return a compiled ensemble whose entries keep the original entries' {@code getBfIndices()}
   *         and take their values from the regression weights, built on the fitted tree's grid
   */
  @Override
  public ModelTools.CompiledOTEnsemble  fit(final VecDataSet ds, final Loss loss) {
    ObliviousTree tree = base.fit(ds, loss);

    // Wrap the single tree into an ensemble (weight vector filled with 1.0) so it can be compiled.
    Ensemble ensemble = new Ensemble<>(new ObliviousTree[]{tree}, VecTools.fill(new SingleValueVec(1), 1.0));
    ModelTools.CompiledOTEnsemble compiledOTEnsemble = ModelTools.compile(ensemble);
    List entryList = compiledOTEnsemble.getEntries();

    // Binarized view of the data set on the base learner's grid, obtained through the data set cache.
    final BinarizedDataSet bds =  ds.cache().cache(Binarize.class, VecDataSet.class).binarize(base.grid);

    Pair compiledLearn = filter(entryList, bds, loss.target(), learnPoints(loss, ds));
    RidgeRegression ridgeRegression = new RidgeRegression(lambda);
    Vec weights = ridgeRegression.fit(compiledLearn.first, compiledLearn.second);
    ArrayList newEntries  = new ArrayList<>();
    // The i-th weight is paired with the i-th original entry.
    for (int i=0; i < weights.dim();++i) {
      newEntries.add(new ModelTools.CompiledOTEnsemble.Entry(entryList.get(i).getBfIndices(), weights.get(i)));
    }
    return new ModelTools.CompiledOTEnsemble(newEntries, tree.grid());
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy