All Downloads are FREE. Search and download functionalities use the official Maven repository.

com.expleague.ml.methods.greedyRegion.BFWeakConditionsStochasticOptimizationRegion Maven / Gradle / Ivy

package com.expleague.ml.methods.greedyRegion;

import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.random.FastRandom;
import com.expleague.ml.BFGrid;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.impl.BinaryFeatureImpl;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.methods.trees.BFOptimizationSubset;
import com.expleague.commons.util.ArrayTools;
import gnu.trove.list.array.TIntArrayList;


/**
 * Weak-conditions optimization region in which points of the critical range are removed
 * stochastically instead of deterministically.
 *
 * <p>When a condition is applied via {@link #split}, points in the non-critical range only get
 * their failure counters incremented, while points in the critical range are filtered by
 * {@link #splitCritical}, which rejects each of them with a probability that depends on how far
 * (in rank) the point's raw feature value lies from the condition threshold. The steepness of
 * these probabilities is controlled by {@link #alpha} and {@link #beta}.
 *
 * User: solar
 * Date: 10.09.13
 * Time: 12:16
 */
@SuppressWarnings("unchecked")
public class BFWeakConditionsStochasticOptimizationRegion extends BFWeakConditionsOptimizationRegion {
  private final FastRandom random = new FastRandom();
  /** Steepness of the rejection probability for critical points that violate the condition. */
  double alpha = 0.02;
  /** Steepness of the rejection probability for critical points that satisfy the condition. */
  double beta = 0.5;

  public BFWeakConditionsStochasticOptimizationRegion(
      final BinarizedDataSet bds, final StatBasedLoss oracle, final int[] points, final BFGrid.Feature[] features, final boolean[] masks, final int maxFailed) {
    super(bds, oracle, points, features, masks, maxFailed);
  }

  /**
   * Applies the condition defined by {@code feature}/{@code mask} to this region and returns the
   * points removed from it.
   *
   * <p>Points at indices {@code [0, failedBorders[maxFailed - 1])} that violate the condition get
   * their failure counter incremented; those whose counter reaches {@code maxFailed} are added to
   * {@code aggregate} and subtracted from {@code nonCriticalTotal}. Points at indices
   * {@code [failedBorders[maxFailed - 1], failedBorders[maxFailed])} are filtered by
   * {@link #splitCritical}; the rejected ones are appended to {@code excluded}, have their counter
   * incremented, and form the returned subset, whose statistics are removed from
   * {@code aggregate}. Finally points are re-sorted by failure count and the borders refreshed.
   *
   * @param feature binary feature defining the condition
   * @param mask    {@code true} if a point passes when its bin is greater than
   *                {@code feature.bin()}, {@code false} if it passes when its bin is not greater
   * @return subset containing the points removed from this region
   */
  @Override
  public BFOptimizationSubset split(final BFGrid.Feature feature, final boolean mask) {
    final TIntArrayList out = new TIntArrayList(points.length);
    final byte[] bins = bds.bins(feature.findex());
    final TIntArrayList newCriticalPoints = new TIntArrayList();
    final AdditiveStatistics newCritical = oracle.statsFactory().create();

    final int nonCriticalEnd = maxFailed > 0 ? failedBorders[maxFailed - 1] : 0;
    // failedBorders is only modified by updateFailedBorders at the very end, so caching is safe.
    final int criticalEnd = failedBorders[maxFailed];

    for (int i = 0; i < nonCriticalEnd; ++i) {
      final int index = points[i];
      if ((bins[index] > feature.bin()) != mask) {
        failedCount[i]++;
        if (failedCount[i] == maxFailed) {
          // This point has just moved into the critical range and now counts towards aggregate.
          newCriticalPoints.add(index);
          newCritical.append(index, 1);
        }
      }
    }

    final boolean[] rejected = splitCritical(points, nonCriticalEnd, criticalEnd, feature, mask);
    for (int i = nonCriticalEnd; i < criticalEnd; ++i) {
      if (rejected[i - nonCriticalEnd]) {
        final int index = points[i];
        excluded.append(index, 1);
        failedCount[i]++;
        out.add(index);
      }
    }

    final BFOptimizationSubset outRegion = new BFOptimizationSubset(bds, oracle, out.toArray());
    aggregate.remove(outRegion.aggregate);
    aggregate.append(newCriticalPoints.toArray());
    nonCriticalTotal.remove(newCritical);
    // NOTE(review): passing criticalEnd - 1 assumes ArrayTools.parallelSort treats its upper
    // bound as inclusive; confirm, otherwise the last critical point is left out of the sort.
    ArrayTools.parallelSort(failedCount, points, 0, criticalEnd - 1);
    updateFailedBorders(failedCount, failedBorders);
    return outRegion;
  }

  /**
   * Stochastically decides which of {@code points[left, right)} are rejected by the condition
   * {@code feature}/{@code mask}.
   *
   * <p>The decision uses the raw feature values of the original data set. The values are sorted and
   * mid-ranked, and {@code split} is the number of values not exceeding
   * {@code feature.condition()}. With {@code diff} being the rank distance from that threshold, a
   * point violating the condition is rejected when {@code nextDouble() >= 0.5^(alpha * diff)} and a
   * point satisfying it is rejected when {@code nextDouble() <= 0.5^(beta * diff)}. One random
   * number is drawn per point, in ascending value order.
   *
   * @param points  point indices into the data set
   * @param left    first offset (inclusive) of the range to decide
   * @param right   last offset (exclusive) of the range to decide
   * @param feature feature whose raw column and threshold are used
   * @param mask    {@code true} if a point passes when its value is greater than the threshold
   * @return array of length {@code right - left}; element {@code k} is {@code true} when
   *         {@code points[left + k]} is rejected
   */
  private boolean[] splitCritical(final int[] points, final int left, final int right, final BFGrid.Feature feature, final boolean mask) {
    final int count = right - left;
    final boolean[] result = new boolean[count];
    final double[] values = new double[count];
    final Vec featureValues = ((VecDataSet) bds.original()).data().col(feature.findex());
    final int[] order = ArrayTools.sequence(0, count);
    for (int i = 0; i < count; ++i) {
      values[i] = featureValues.get(points[left + i]);
    }
    // Sort the values and permute order alongside, so that order[i] is the original offset of
    // values[i] (presumed ArrayTools.parallelSort semantics).
    ArrayTools.parallelSort(values, order);
    final double[] ranks = rank(values);
    final double threshold = feature.condition();
    // Number of values <= threshold, i.e. the size of the "not greater" side.
    final int split = upperBound(values, threshold);
    for (int i = 0; i < count; ++i) {
      final boolean violates = (values[i] > threshold) != mask;
      if (violates) {
        // Rank distance past the threshold on the wrong side:
        // if mask, then diff = split - rank + 1; if !mask, then diff = rank - split.
        final double diff = mask ? split - ranks[i] + 1 : ranks[i] - split;
        result[order[i]] = random.nextDouble() >= Math.pow(0.5, alpha * diff);
      } else {
        // Rank distance from the threshold on the passing side:
        // if mask, then diff = rank - split; if !mask, then diff = split - rank + 1.
        final double diff = mask ? ranks[i] - split : split - ranks[i] + 1;
        result[order[i]] = random.nextDouble() <= Math.pow(0.5, beta * diff);
      }
    }
    return result;
  }

  /**
   * Computes mid-ranks of an ascending-sorted sample.
   *
   * <p>Consecutive elements whose difference is below {@code 1e-9} form one tie group; every
   * element of a group starting at index {@code s} with size {@code g} gets rank
   * {@code s + 0.5 * g} (i.e. #elements before the group + half the group size).
   *
   * @param sortedSample sample sorted in ascending order
   * @return ranks aligned with {@code sortedSample}
   */
  private static double[] rank(final double[] sortedSample) {
    final int n = sortedSample.length;
    final double[] ranks = new double[n];
    int groupStart = 0;
    while (groupStart < n) {
      int groupEnd = groupStart + 1;
      while (groupEnd < n && Math.abs(sortedSample[groupEnd] - sortedSample[groupEnd - 1]) < 1e-9) {
        ++groupEnd;
      }
      final double midRank = groupStart + 0.5 * (groupEnd - groupStart);
      for (int k = groupStart; k < groupEnd; ++k) {
        ranks[k] = midRank;
      }
      groupStart = groupEnd;
    }
    return ranks;
  }

  /**
   * Returns the first index whose value is strictly greater than {@code key}, or
   * {@code arr.length} if there is none. Equivalently, the number of elements {@code <= key}.
   *
   * <p>Hand-written because {@code Arrays.binarySearch} does not specify which of several equal
   * elements it finds.
   *
   * @param arr array sorted in ascending order
   * @param key value to search for
   * @return index in {@code [0, arr.length]}
   */
  private static int upperBound(final double[] arr, final double key) {
    int lo = 0;
    int hi = arr.length;
    // Invariant: arr[0, lo) <= key and arr[hi, length) > key.
    while (lo < hi) {
      final int mid = (lo + hi) >>> 1;
      if (arr[mid] <= key)
        lo = mid + 1;
      else
        hi = mid;
    }
    return lo;
  }

  /**
   * Returns the first index whose value is greater than or equal to {@code key}, or
   * {@code arr.length} if there is none. Equivalently, the number of elements {@code < key}.
   *
   * @param arr array sorted in ascending order
   * @param key value to search for
   * @return index in {@code [0, arr.length]}
   */
  private static int lowerBound(final double[] arr, final double key) {
    int lo = 0;
    int hi = arr.length;
    // Invariant: arr[0, lo) < key and arr[hi, length) >= key.
    while (lo < hi) {
      final int mid = (lo + hi) >>> 1;
      if (arr[mid] < key)
        lo = mid + 1;
      else
        hi = mid;
    }
    return lo;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy