All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.trees.GreedyObliviousTreeWithWeakLearner Maven / Gradle / Ivy

package com.expleague.ml.methods.trees;

import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.ObliviousTree;
import com.expleague.ml.BFGrid;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.linearRegressionExperiments.WeakLeastAngle;

import java.util.List;
import java.util.Set;
import java.util.TreeSet;

/**
 * User: noxoomo
 */

public class GreedyObliviousTreeWithWeakLearner extends VecOptimization.Stub {
  private final GreedyObliviousTree> base;
  private final FastRandom rand;

  public GreedyObliviousTreeWithWeakLearner(
    final GreedyObliviousTree> base,
    final FastRandom rand) {
    this.base = base;
    this.rand = rand;
  }



  private int[] learnPoints(WeightedLoss loss) {
    return loss.points();
  }

  @Override
  public Trans fit(final VecDataSet ds, final Loss loss) {
    final WeightedLoss bsLoss = DataTools.bootstrap(loss, rand);
    final Trans[] result = new Trans[2];
    result[0] = base.fit(ds, bsLoss);

    final List conditions = ((ObliviousTree)result[0]).features();
    //damn java 7 without unique, filters, etc and autoboxing overhead…
    Set uniqueFeatures = new TreeSet<>();
    for (BFGrid.BinaryFeature bf : conditions) {
      if (!bf.row().empty()
        )
        uniqueFeatures.add(bf.findex);
    }
//    //prototype
    while (uniqueFeatures.size() < 10) {
      final int feature = rand.nextInt(ds.data().columns());
      if (!base.grid.row(feature).empty())
        uniqueFeatures.add(feature);
    }

    Vec newTarget = VecTools.copy(loss.target());
    Vec predictions = result[0].transAll(ds.data()).col(0);
    for (int i = 0; i < predictions.dim(); ++i)
      newTarget.adjust(i, -predictions.get(i));


    final int[] features = new int[uniqueFeatures.size()];
    {
      int j = 0;
      for (Integer i : uniqueFeatures) {
        features[j++] = i;
      }
    }

    L2 localLoss = DataTools.newTarget(L2.class,newTarget,ds);

    WeakLeastAngle regression = new WeakLeastAngle(learnPoints(bsLoss), features);
    result[1] = regression.fit(ds,localLoss);

    Vec weights = new ArrayVec(2);
    VecTools.fill(weights,1.0);

    Ensemble ensemble = new Ensemble(result, weights);
    return ensemble;
}

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy