All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.trees.LassoGreedyObliviousTree Maven / Gradle / Ivy

package com.expleague.ml.methods.trees;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.func.Linear;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.ElasticNetMethod;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.Binarize;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.ModelTools;
import com.expleague.ml.models.ObliviousTree;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * User: noxoomo
 * Date: 02.12.2015
 */

public class LassoGreedyObliviousTree extends VecOptimization.Stub {
  private final GreedyObliviousTree base;
  final int nlambda;
  final double alpha;

  public LassoGreedyObliviousTree(GreedyObliviousTree base, int nlambda, double alpha) {
    this.base = base;
    this.nlambda = nlambda;
    this.alpha = alpha;
  }

  private int[] learnPoints(Loss loss, VecDataSet ds) {
    if (loss instanceof WeightedLoss) {
      return ((WeightedLoss) loss).points();
    } else return ArrayTools.sequence(0, ds.length());
  }

  private int[] validationPoints(Loss loss) {
    if (loss instanceof WeightedLoss) {
      return ((WeightedLoss) loss).zeroPoints();
    } else {
      throw new RuntimeException("Wrong target type. No validation points");
    }
  }

  //TODO: noxoomo, remove duplicates
  //TODO: noxoomo, no intercept regularization…
  @SuppressWarnings("Duplicates")
  private Pair filter(final List entryList, final BinarizedDataSet bds, Vec sourceTarget, int[] points) {
    final byte[] binary = new byte[base.grid.rows()];
    Mx otData = new VecBasedMx(points.length, entryList.size());
    Vec target = new ArrayVec(points.length);


    for (int i = 0; i < points.length; ++i) {
      final int ind = points[i];
      for (int f = 0; f < base.grid.rows(); ++f) {
        binary[f] = bds.bins(f)[ind];
      }

      for (int j = 0; j < otData.columns(); ++j) {
        final int[] bfIndices = entryList.get(j).getBfIndices();
        double increment = 1.0;
        for (int k = 0; k < bfIndices.length; k++) {
          if (!base.grid.bf(bfIndices[k]).value(binary)) {
            increment = 0;
            break;
          }
        }
        otData.set(i, j, increment);
        target.set(i, sourceTarget.get(ind));
      }
    }
    return new Pair<>(otData, target);
  }

  @Override
  public ModelTools.CompiledOTEnsemble fit(final VecDataSet ds, final Loss loss) {
    ObliviousTree tree = base.fit(ds, loss);
    Ensemble ensemble = new Ensemble<>(new ObliviousTree[]{tree}, VecTools.fill(new SingleValueVec(1), 1.0));
    ModelTools.CompiledOTEnsemble compiledOTEnsemble = ModelTools.compile(ensemble);
    List entryList = compiledOTEnsemble.getEntries().stream()
            .filter(entry -> entry.getBfIndices().length > 0 && entry.getValue() != 0).collect(Collectors.toList());

    Vec target = VecTools.copy(loss.target());
    double bias = 0;
    for (int i = 0; i < target.dim(); ++i) {
      bias += target.get(i);
    }
    bias /= target.dim();
    for (int i = 0; i < target.dim(); ++i) {
      target.adjust(i, -bias);
    }


    final BinarizedDataSet bds = ds.cache().cache(Binarize.class, VecDataSet.class).binarize(base.grid);

    Pair compiledLearn = filter(entryList, bds, target, learnPoints(loss, ds));

    Vec entryBias = colMean(compiledLearn.first);
    center(compiledLearn.first, entryBias);


    Pair compiledValidate = filter(entryList, bds, target, validationPoints(loss));
    center(compiledValidate.first, entryBias);

    ElasticNetMethod lasso = new ElasticNetMethod(1e-4, alpha, 1.0);
    List weightsPath = lasso.fit(compiledLearn.first, compiledLearn.second, nlambda);
    double[] scores = weightsPath.parallelStream().mapToDouble(linear -> {
      final double testL2 = VecTools.distanceL2(linear.transAll(compiledValidate.first), compiledValidate.second) / compiledValidate.first.rows();
      final double learnL2 = VecTools.distanceL2(linear.transAll(compiledLearn.first), compiledLearn.second) / compiledLearn.first.rows();
      return 0.63 * testL2 + (1-0.63) * learnL2;
    }).toArray();
    int best = 0;
    double bestScore = scores[0];
    for (int i = 0; i < scores.length; ++i) {
      if (scores[i] < bestScore) {
        bestScore = scores[i];
        best = i;
      }
    }
    Vec weights = weightsPath.get(best).weights;
    ArrayList newEntries = new ArrayList<>();
    for (int i = 0; i < weights.dim(); ++i) {
      if (weights.get(i) != 0) {
        newEntries.add(new ModelTools.CompiledOTEnsemble.Entry(entryList.get(i).getBfIndices(), weights.get(i)));
        bias -= weights.get(i) * entryBias.get(i);
      }
    }
    if (bias != 0) {
      newEntries.add(new ModelTools.CompiledOTEnsemble.Entry(new int[0], bias));
    }

    System.out.println("Next entries batch: " + newEntries.size() + " of " + (entryList.size() + 1) + " nonZero");
    return new ModelTools.CompiledOTEnsemble(newEntries, tree.grid());
  }

  private Vec colMean(Mx data) {
    Vec result = new ArrayVec(data.columns());
    for (int i = 0; i < data.rows(); ++i) {
      for (int j = 0; j < data.columns(); ++j) {
        result.adjust(j, data.get(i, j));
      }
    }
    return VecTools.scale(result, 1.0 / data.rows());
  }

  private void center(Mx data, Vec bias) {
    for (int j = 0; j < data.rows(); ++j) {
      for (int i = 0; i < data.columns(); ++i) {
        data.adjust(i, j, -bias.get(i));
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy