All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.greedyRegion.BFOptimizationRegion Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.greedyRegion;

import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.math.AnalyticFunc;
import com.expleague.ml.BFGrid;
import com.expleague.ml.data.Aggregate;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.impl.BinaryFeatureImpl;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import gnu.trove.list.array.TIntArrayList;


/**
 * Mutable state of a region that is narrowed greedily over a {@link BinarizedDataSet}.
 *
 * <p>Holds the indices of the points currently inside the region together with an
 * {@link Aggregate} built over exactly those points using the oracle's statistics factory.
 * {@link #split(BFGrid.Feature, boolean)} narrows the region and keeps the aggregate in sync;
 * {@link #visitAllSplits} and {@link #visitSplit} report candidate split statistics to a visitor.
 *
 * User: solar
 * Date: 10.09.13
 * Time: 12:16
 */
// Needed for the (T) casts in visitSplit and the raw generic types used in this class.
@SuppressWarnings("unchecked")
public class BFOptimizationRegion {
  /** Binarized data set that the point indices refer to. */
  protected final BinarizedDataSet bds;
  /** Indices of the points currently inside the region; replaced (never mutated) by split. */
  protected int[] pointsInside;
  /** Loss whose statistics factory builds {@link #aggregate} and the right-hand split stats. */
  protected final StatBasedLoss oracle;
  /** Statistics over {@link #pointsInside}; rebuilt or reduced by {@link #split}. */
  protected Aggregate aggregate;

  /**
   * Creates a region initially containing {@code points}.
   *
   * @param bds    binarized data set the point indices refer to
   * @param oracle loss providing the statistics factory used for the aggregate
   * @param points indices of the initial points; the array is stored as-is, not copied
   */
  public BFOptimizationRegion(final BinarizedDataSet bds,
                              final StatBasedLoss oracle,
                              final int[] points) {
    this.bds = bds;
    this.pointsInside = points;
    this.oracle = oracle;
    this.aggregate = new Aggregate(bds, oracle.statsFactory(), points);
  }

  /**
   * Narrows the region to the points on the {@code mask} side of a binary feature.
   *
   * <p>A point stays inside exactly when {@code (bin > feature.bin()) == mask}. With
   * {@code mask == true}, the points kept are those whose bin is strictly above the threshold
   * bin. With {@code mask == false}, the remaining points are kept. The aggregate is updated
   * from whichever half is smaller:
   * <ul>
   *   <li>if the kept half is strictly smaller, the aggregate is rebuilt from it;</li>
   *   <li>otherwise, the statistics of the removed points are subtracted from it.</li>
   * </ul>
   *
   * <p>NOTE(review): bins are read as signed {@code byte}s. If a feature can have more than 127
   * bins, higher bins would compare as negative here. Confirm the grid's per-feature bin limit.
   *
   * @param feature binary feature supplying the feature index and the threshold bin
   * @param mask    which side of the threshold to keep
   */
  public void split(final BFGrid.Feature feature, final boolean mask) {
    final byte[] bins = bds.bins(feature.findex());
    final int threshold = feature.bin();

    final TIntArrayList newInside = new TIntArrayList();
    final TIntArrayList newOutside = new TIntArrayList();

    for (final int index : pointsInside) {
      if ((bins[index] > threshold) == mask) {
        newInside.add(index);
      } else {
        newOutside.add(index);
      }
    }
    pointsInside = newInside.toArray();
    if (newInside.size() < newOutside.size()) {
      // Kept half is the smaller one: rebuild the aggregate from it.
      aggregate = new Aggregate(bds, oracle.statsFactory(), pointsInside);
    } else {
      // Removed half is the smaller (or equal) one: subtract its statistics instead.
      aggregate.remove(new Aggregate(bds, oracle.statsFactory(), newOutside.toArray()));
    }
  }

  /**
   * Returns the number of points currently inside the region.
   */
  int size() {
    return pointsInside.length;
  }

  /**
   * Passes {@code visitor} to {@code aggregate.visit}.
   *
   * <p>Presumably the aggregate reports every candidate binary-feature split of the current
   * region to the visitor; see {@code Aggregate.visit} to confirm.
   *
   * @param visitor callback receiving the split statistics
   */
  public void visitAllSplits(final Aggregate.SplitVisitor visitor) {
    aggregate.visit(visitor);
  }

  /**
   * Reports the statistics of a single split to {@code visitor}.
   *
   * <p>{@code left} is what the aggregate returns for {@code bf}'s index. {@code right} is the
   * region total minus {@code left}. It is computed in a freshly created statistics object, so
   * the aggregate's total itself is not modified here.
   *
   * @param bf      binary feature whose split is reported
   * @param visitor callback receiving {@code (bf, left, right)}
   * @param <T>     concrete statistics type produced by the oracle's statistics factory
   */
  public <T extends AdditiveStatistics> void visitSplit(final BinaryFeatureImpl bf,
                                                        final Aggregate.SplitVisitor visitor) {
    final T left = (T) aggregate.combinatorForFeature(bf.bfIndex);
    final T right = (T) oracle.statsFactory().create().append(aggregate.total()).remove(left);
    visitor.accept(bf, left, right);
  }

  /**
   * Returns the aggregate's total statistics for the points inside the region.
   *
   * <p>This is the object returned by {@code aggregate.total()}; no copy is made here.
   */
  public AdditiveStatistics total() {
    return aggregate.total();
  }

  /**
   * One-dimensional analytic function of {@code x} built on {@code Aggregate.visitND}.
   *
   * <p>Each visited step {@code k} is mapped to a data point via {@code order[k]}. That point's
   * weight is scaled by {@code N_k / D_k}, as supplied by the aggregate.
   */
  public static class PermutationWeightedFunc extends AnalyticFunc.Stub {
    /** Passed through as the first argument of {@code aggregate.visitND}. */
    private final int c;
    private final Aggregate aggregate;
    private final WeightedLoss loss;
    /** Maps the visit index {@code k} to a point index; its length is the visit count. */
    private final int[] order;

    /**
     * @param c         first argument forwarded to {@code aggregate.visitND}
     * @param order     point index for each visit step; {@code order.length} steps are visited
     * @param aggregate aggregate that enumerates the steps and supplies their coefficients
     * @param loss      weighted loss supplying targets and per-point weights
     */
    public PermutationWeightedFunc(int c, int[] order, Aggregate aggregate, WeightedLoss loss) {
      this.c = c;
      this.order = order;
      this.aggregate = aggregate;
      this.loss = loss;
    }

    /**
     * Returns the weighted sum of squared deviations of the targets at {@code x}.
     *
     * <p>With {@code w_k = loss.weight(order[k]) * N_k / D_k}, the result is
     * {@code sum(w_k * y_k^2) - sum(w_k * y_k)^2 / sum(w_k)}. This equals
     * {@code sum(w_k * (y_k - mean)^2)} for the weighted mean of the targets. When nothing is
     * visited, all sums are zero and the result is NaN (0/0).
     */
    @Override
    public double value(double x) {
      // Array accumulators because the lambda can only capture effectively-final locals:
      // [0] = sum of w*y^2, [1] = sum of w*y, [2] = sum of w.
      double[] params = new double[]{0, 0, 0};
      aggregate.visitND(c, order.length, x, (k, N_k, D_k, P_k, S_k) -> {
        final int index = order[k];
        final double y_k = loss.target().get(index);
        final double w_k = loss.weight(index) * N_k / D_k;

        params[0] += w_k * y_k * y_k;
        params[1] += w_k * y_k;
        params[2] += w_k;
      });
      double sum2 = params[0];
      double sum = params[1];
      double weights = params[2];
      return sum2 - sum * sum / weights;
    }

    /**
     * Returns {@code sum(w_k * dLdw * dwdl)} over the steps visited at {@code x}.
     *
     * <p>{@code dLdw} uses the sum and weight of the L2 statistics wrapped inside
     * {@code aggregate.total()}. These are read once, before the loop, so they do not vary
     * with {@code x}. {@code dwdl} is {@code (S_k * D_k - P_k * N_k) / N_k^2}.
     *
     * <p>NOTE(review): this is not obviously the derivative of {@link #value}. Two
     * discrepancies:
     * <ul>
     *   <li>{@code value} uses the x-dependent weighted sums, whereas {@code dLdw} uses the
     *       x-independent totals;</li>
     *   <li>the quotient rule for {@code N_k / D_k} would divide by {@code D_k^2}, not
     *       {@code N_k^2}.</li>
     * </ul>
     * Whether this is correct depends on what {@code visitND} passes as {@code P_k}/{@code S_k}.
     * TODO: confirm with a finite-difference check.
     *
     * @throws ClassCastException if {@code aggregate.total()} is not a {@code WeightedLoss.Stat}
     *                            whose {@code inside} is an {@code L2.Stat}
     */
    @Override
    public double gradient(double x) {
      final double[] params = new double[]{0};
      final WeightedLoss.Stat stat = (WeightedLoss.Stat) aggregate.total();
      final L2.Stat l2Stat = (L2.Stat)stat.inside;
      aggregate.visitND(c, order.length, x, (k, N_k, D_k, P_k, S_k) -> {
        final int index = order[k];
        final double y_k = loss.target().get(index);
        final double w_k = loss.weight(index) * N_k / D_k;

        final double dLdw = y_k * y_k - 2 * (y_k * l2Stat.sum * l2Stat.weight - l2Stat.sum * l2Stat.sum) / l2Stat.weight / l2Stat.weight / l2Stat.weight;
        final double dwdl = (S_k * D_k - P_k * N_k) / N_k / N_k;
        params[0] += w_k * dLdw * dwdl;
      });
      return params[0];
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy