All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.trees.BFOptimizationSubset Maven / Gradle / Ivy

package com.expleague.ml.methods.trees;

import com.expleague.ml.data.Aggregate;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.ml.BFGrid;
import gnu.trove.list.array.TIntArrayList;


/**
 * A mutable subset of the points of a {@link BinarizedDataSet}, kept together with the
 * {@link Aggregate} built over exactly those points.
 *
 * <p>{@link #split(BFGrid.BinaryFeature)} narrows the subset in place, so instances are not
 * immutable. This class performs no synchronization of its own.
 *
 * User: solar
 * Date: 10.09.13
 * Time: 12:16
 */
@SuppressWarnings("unchecked")
public class BFOptimizationSubset {
  private final BinarizedDataSet bds;
  /** Indices of the points currently in this subset; replaced by {@link #split}. */
  private int[] points;
  private final StatBasedLoss oracle;
  /** Aggregate over {@link #points}; updated in place by {@link #split}. */
  public final Aggregate aggregate;

  /**
   * Creates a subset over the given points and builds its aggregate from
   * {@code oracle.statsFactory()}.
   *
   * @param bds    binarized data set the point indices refer to
   * @param oracle loss whose statistics factory is used for the aggregate and for
   *               {@link #visitSplit}
   * @param points indices of the points in this subset; the array is stored as is (not copied)
   */
  public BFOptimizationSubset(final BinarizedDataSet bds, final StatBasedLoss oracle, final int[] points) {
    this.bds = bds;
    this.points = points;
    this.oracle = oracle;
    this.aggregate = new Aggregate(bds, oracle.statsFactory(), points);
  }

  /**
   * Splits this subset in two by the given binary feature.
   *
   * <p>Points for which {@code feature.value(bin)} is {@code true} are moved into a newly
   * created subset, which is returned. This instance is modified to keep only the remaining
   * points, and the returned subset's aggregate is removed from this instance's aggregate.
   * The relative order of points is preserved on both sides.
   *
   * @param feature binary feature to split by
   * @return the new subset holding the points for which the feature is {@code true}
   */
  public BFOptimizationSubset split(final BFGrid.BinaryFeature feature) {
    final TIntArrayList left = new TIntArrayList(points.length);
    final TIntArrayList right = new TIntArrayList(points.length);
    final byte[] bins = bds.bins(feature.findex);
    for (final int i : points) {
      if (feature.value(bins[i])) {
        right.add(i);
      } else {
        left.add(i);
      }
    }
    final BFOptimizationSubset rightBro = new BFOptimizationSubset(bds, oracle, right.toArray());
    // Subtract the sibling's aggregate instead of rebuilding this one over the left points.
    aggregate.remove(rightBro.aggregate);
    points = left.toArray();
    return rightBro;
  }

  /**
   * @return the number of points currently in this subset
   */
  public int size() {
    return points.length;
  }

  /**
   * Passes the visitor to {@link Aggregate#visit} of this subset's aggregate.
   *
   * @param visitor callback handed to the aggregate
   */
  public void visitAllSplits(final Aggregate.SplitVisitor visitor) {
    aggregate.visit(visitor);
  }

  /**
   * Reports the two statistics of a single split to the visitor.
   *
   * <p>The first statistic is {@code aggregate.combinatorForFeature(bf.bfIndex)}. The second is
   * a fresh statistic from the oracle's factory to which the aggregate total is appended and
   * from which the first statistic is then removed, i.e. the complement of the first within
   * the total.
   *
   * <p>NOTE(review): the visitor parameter is kept exactly as it appears in the original
   * listing. If {@code Aggregate.SplitVisitor} is generic it should presumably be declared as
   * {@code SplitVisitor<T>} - confirm against {@code Aggregate}.
   *
   * @param bf      binary feature identifying the split
   * @param visitor callback receiving {@code (bf, left, right)}
   * @param <T>     statistics type the values are cast to before being passed to the visitor
   */
  public <T extends AdditiveStatistics> void visitSplit(final BFGrid.BinaryFeature bf, final Aggregate.SplitVisitor visitor) {
    final T left = (T) aggregate.combinatorForFeature(bf.bfIndex);
    final T right = (T) oracle.statsFactory().create().append(aggregate.total()).remove(left);
    visitor.accept(bf, left, right);
  }

  /**
   * @return the total statistics of this subset's aggregate
   */
  public AdditiveStatistics total() {
    return aggregate.total();
  }

  /**
   * @return a copy of the point indices currently in this subset
   */
  public int[] getPoints() {
    return points.clone();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy