All downloads are free. Search and download functionality uses the official Maven repository.

water.util.MRUtils Maven / Gradle / Ivy

package water.util;

import static water.util.RandomUtils.getRNG;

import water.*;
import water.H2O.H2OCallback;
import water.H2O.H2OCountedCompleter;
import water.fvec.*;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

public class MRUtils {


  /**
   * Randomly samples rows of a frame.
   * <p>
   * Each row is kept independently with probability {@code rows / fr.numRows()}, drawn from an RNG
   * seeded with {@code seed + chunk index}. In any chunk where no row was picked, that chunk's last
   * row is kept anyway, so every non-empty chunk contributes at least one row; with many chunks the
   * result can therefore be larger than {@code rows}. If the sample still comes back empty, a
   * warning is logged and sampling is retried with {@code seed + 1}.
   *
   * @param fr   input frame; {@code null} yields {@code null}
   * @param rows approximate number of rows to sample across all chunks; a non-positive value, or one
   *             giving a sampling fraction of at least 1, returns {@code fr} itself
   * @param seed seed for the per-chunk RNGs
   * @return the sampled frame; when {@code fr} has a key, the new key is {@code fr}'s key plus
   *         {@code .sample.<pct>} (if the key already contains "temporary") or
   *         {@code .temporary.sample.<pct>} otherwise
   */
  public static Frame sampleFrame(Frame fr, final long rows, final long seed) {
    if (fr == null) {
      return null;
    }
    final float fraction = (rows <= 0) ? 1.f : (float) rows / fr.numRows();
    if (fraction >= 1.f) {
      return fr; // nothing would be dropped
    }

    Key newKey = null;
    if (fr._key != null) {
      final String srcName = fr._key.toString();
      final String marker = srcName.contains("temporary") ? ".sample." : ".temporary.sample.";
      newKey = Key.make(srcName + marker + PrettyPrint.formatPct(fraction).replace(" ", ""));
    }

    final Frame sampled = new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        final Random rng = getRNG(seed + cs[0].cidx());
        final int len = cs[0]._len;
        boolean pickedAny = false;
        for (int row = 0; row < len; row++) {
          // Draw for every row, so the RNG advances once per row regardless of the fallback below.
          final boolean drawn = rng.nextFloat() < fraction;
          final boolean fallback = !pickedAny && row == len - 1;
          if (!drawn && !fallback) {
            continue;
          }
          pickedAny = true;
          for (int col = 0; col < ncs.length; col++) {
            ncs[col].addNum(cs[col].atd(row));
          }
        }
      }
    }.doAll(fr.numCols(), fr).outputFrame(newKey, fr.names(), fr.domains());

    if (sampled.numRows() != 0) {
      return sampled;
    }
    Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ").");
    Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
    return sampleFrame(fr, rows, seed + 1);
  }

  /**
   * Row-wise shuffle of a frame (only shuffles rows inside of each chunk)
   * @param fr Input frame
   * @return Shuffled frame
   */
  public static Frame shuffleFramePerChunk(Frame fr, final long seed) {
    return new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        int[] idx = new int[cs[0]._len];
        for (int r=0; r {
    final int _nclass;
    protected double[] _ys;
    /** Creates a task for a categorical label; the number of classes is the size of the label's domain. */
    public ClassDist(final Vec label) { this(label.domain().length); }
    /** Creates a task that tallies {@code n} classes. */
    public ClassDist(int n) { _nclass = n; }

    /**
     * @return the per-class totals held in {@code _ys} (not copied), or {@code null} if none were
     *         assigned
     */
    public final double[] dist() { return _ys; }

    /**
     * Returns the class distribution normalized by its total, as a fresh copy of {@code _ys}.
     *
     * @return the normalized distribution, or {@code null} when {@link #dist()} is {@code null}
     *         (e.g. no rows were processed); previously this case threw a NullPointerException
     */
    public final double[] rel_dist() {
      if (_ys == null) {
        return null; // mirror dist(); _ys.length below would otherwise dereference null
      }
      final double sum = ArrayUtils.sum(_ys);
      return ArrayUtils.div(Arrays.copyOf(_ys, _ys.length), sum);
    }
    @Override public void map(Chunk ys) {
      _ys = new double[_nclass];
      for( int i=0; i 0);
    Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
    if (verbose)
      for (int i=0; i= 0); //can have no matching rows in case of sparse data where we had to fill in a makeZero() vector
    Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + (actualnumrows < numrows ? " (limited by max_after_balance_size).":"."));

    if (actualnumrows != numrows) {
      ArrayUtils.mult(sampling_ratios, (float)actualnumrows/numrows); //adjust the sampling_ratios by the global rescaling factor
      if (verbose)
        Log.info("Downsampling majority class by " + (float)actualnumrows/numrows
                + " to limit number of rows to " + String.format("%,d", maxrows));
    }
    for (int i=0;i= 0);
    final int weightsidx = fr.find(weights); //which column is the weight?

    final boolean poisson = false; //beta feature

    //FIXME - this is doing uniform sampling, even if the weights are given
    Frame r = new MRTask() {
      /**
       * Replicates every labeled row of the chunk according to its class's sampling ratio: the
       * integer part of the ratio gives the guaranteed number of copies, and one extra copy is added
       * with probability equal to the fractional part. Rows with a missing label are dropped.
       */
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        final Random rng = getRNG(seed + cs[0].cidx());
        final int nrows = cs[0]._len;
        for (int row = 0; row < nrows; row++) {
          if (cs[labelidx].isNA(row)) {
            continue; // missing label: cannot be assigned to a stratum
          }
          final int cls = (int) cs[labelidx].at8(row);
          assert sampling_ratios.length > cls && cls >= 0;
          final int copies;
          if (!poisson) {
            final int whole = (int) sampling_ratios[cls];
            final float fractional = sampling_ratios[cls] - whole;
            copies = rng.nextFloat() < fractional ? whole + 1 : whole;
          } else {
            throw H2O.unimpl(); // Poisson-distributed replication is not implemented
          }
          if (copies == 0) {
            continue;
          }
          for (int col = 0; col < ncs.length; col++) {
            final double value = cs[col].atd(row);
            for (int k = 0; k < copies; k++) {
              ncs[col].addNum(value);
            }
          }
        }
      }
    }.doAll(fr.numCols(), fr).outputFrame(fr.names(), fr.domains());

    // Confirm the validity of the distribution
    Vec lab = r.vecs()[labelidx];
    Vec wei = weightsidx != -1 ? r.vecs()[weightsidx] : null;
    double[] dist = wei != null ? new ClassDist(lab).doAll(lab, wei).dist() : new ClassDist(lab).doAll(lab).dist();

    // if there are no training labels in the test set, then there is no point in sampling the test set
    if (dist == null) return fr;

    if (debug) {
      double sumdist = ArrayUtils.sum(dist);
      Log.info("After stratified sampling: " + sumdist + " rows.");
      for (int i=0; i> extends H2OCountedCompleter {
    public transient final T [] _tasks;
    transient final public int _maxP;
    transient private AtomicInteger _nextTask;

    /** Runs {@code tsks} with an initial parallelism equal to the current cloud size. */
    public ParallelTasks(H2OCountedCompleter cmp, T[] tsks){
      this(cmp, tsks, H2O.CLOUD.size());
    }

    /**
     * @param cmp  parent completer, passed to {@code super}
     * @param tsks tasks to run
     * @param maxP number of tasks started up front by {@link #compute2()}
     */
    public ParallelTasks(H2OCountedCompleter cmp, T[] tsks, int maxP){
      super(cmp);
      _tasks = tsks;
      _maxP = maxP;
      // One pending unit per task beyond the first: this completer finishes after
      // exactly _tasks.length task completions.
      addToPendingCount(tsks.length - 1);
    }

    /** Starts task {@code i} on the cloud node selected round-robin by task index. */
    private void forkDTask(int i){
      final int node = i % H2O.CLOUD.size();
      final H2ONode target = H2O.CLOUD._memary[node];
      forkDTask(i, target);
    }
    /**
     * Starts task {@code i} on node {@code n}: local tasks are submitted to this node's pool, remote
     * tasks are shipped via {@link RPC}. In both cases completion is routed through a {@link Callback}
     * bound to {@code n}, so the same node picks up the next unstarted task.
     * <p>
     * Remote completions used to go straight to this completer, which meant remote nodes never got
     * follow-up tasks, and if no task ran locally (e.g. {@code maxP == 1} with a remote first node)
     * the remaining tasks were never forked and this completer never finished.
     */
    private void forkDTask(final int i, H2ONode n){
      final Callback cb = new Callback(n, i);
      if(n == H2O.SELF) {
        _tasks[i].setCompleter(cb);
        H2O.submitTask(_tasks[i]);
      } else
        new RPC(n,_tasks[i]).addCompleter(cb).call();
    }
    /** Completion hook for one task: logs it and forks the next unstarted task, if any, on node {@link #n}. */
    class Callback extends H2OCallback {
      /** Index of the task this callback is attached to (used only for logging). */
      final int i;
      /** Node on which the follow-up task is forked. */
      final H2ONode n;

      public Callback(H2ONode n, int i){
        super(ParallelTasks.this);
        this.i = i;
        this.n = n;
      }

      @Override public void callback(H2OCountedCompleter cc){
        Log.info("callback for task " + i);
        final int next = _nextTask.getAndIncrement();
        if (next >= _tasks.length) {
          return; // every task has already been started
        }
        forkDTask(next, n);
      }
    }
    /**
     * Forks the first {@code min(_maxP, _tasks.length)} tasks round-robin over the cloud; further
     * tasks are forked from {@link Callback}s as earlier ones complete.
     * <p>
     * The constructor's pending count means this completer finishes only after exactly
     * {@code _tasks.length} task completions. An empty task array therefore completes here directly,
     * and a non-positive {@code _maxP} is treated as 1 so at least one task starts; both cases
     * previously left this completer pending forever.
     */
    @Override public void compute2(){
      if (_tasks.length == 0) {
        // Constructor added (0 - 1) pending; restore it to 0 and complete, since no task will.
        addToPendingCount(1);
        tryComplete();
        return;
      }
      final int n = Math.min(Math.max(_maxP, 1), _tasks.length);
      _nextTask = new AtomicInteger(n);
      for(int i = 0; i < n; ++i)
        forkDTask(i);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy