All downloads are free. Search and download functionality uses the official Maven repository.

water.util.MRUtils Maven / Gradle / Ivy

There is a newer version: 3.8.2.9
Show newest version
package water.util;

import water.*;
import water.H2O.H2OCallback;
import water.H2O.H2OCountedCompleter;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

import static water.util.RandomUtils.getRNG;

public class MRUtils {


  /**
   * Sample rows from a frame.
   * Can be unlucky for small sampling fractions - will continue calling itself until at least 1 row is returned.
   * @param fr Input frame
   * @param rows Approximate number of rows to sample (across all chunks)
   * @param seed Seed for RNG
   * @return Sampled frame
   */
  /**
   * Draws an approximately uniform random subset of the rows of a frame.
   * <p>
   * Each row is kept independently with probability {@code rows / fr.numRows()}. The RNG is re-seeded for
   * every row from {@code seed} plus the chunk's start offset plus the in-chunk row offset. The decision for
   * a row therefore does not depend on the other rows that were visited before it. If a non-empty chunk would
   * otherwise contribute nothing, its last row is kept anyway. For very small fractions the result can
   * therefore contain noticeably more than {@code rows} rows. If the result is still empty, the call is
   * repeated with {@code seed + 1}.
   *
   * @param fr   input frame; {@code null} is passed through as {@code null}
   * @param rows approximate number of rows to keep (across all chunks); a non-positive value keeps everything
   * @param seed seed for the per-row RNG
   * @return the sampled frame, or {@code fr} itself when the requested fraction is at least 1
   */
  public static Frame sampleFrame(Frame fr, final long rows, final long seed) {
    if (fr == null) return null;

    final float fraction;
    if (rows <= 0) {
      fraction = 1.f;
    } else {
      fraction = (float) rows / fr.numRows();
    }
    if (fraction >= 1.f) return fr;  // asking for at least as many rows as exist: nothing to sample

    // Output key is "<key>.temporary.sample.<pct>", or "<key>.sample.<pct>" when the name already mentions "temporary".
    Key newKey = null;
    if (fr._key != null) {
      final String base = fr._key.toString();
      final String suffix = base.contains("temporary") ? ".sample." : ".temporary.sample.";
      newKey = Key.make(base + suffix + PrettyPrint.formatPct(fraction).replace(" ", ""));
    }

    final Frame sampled = new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        final Random rng = getRNG(0);  // seed is irrelevant: it is reset for every row below
        final long chunkStart = cs[0].start();
        final int lastRow = cs[0]._len - 1;
        boolean keptAny = false;
        for (int row = 0; row <= lastRow; row++) {
          rng.setSeed(seed + row + chunkStart);  // decision depends only on seed and row position
          boolean keep = rng.nextFloat() < fraction;
          if (!keep && !keptAny && row == lastRow) keep = true;  // a non-empty chunk never comes out empty
          if (!keep) continue;
          keptAny = true;
          for (int col = 0; col < ncs.length; col++) {
            ncs[col].addNum(cs[col].atd(row));
          }
        }
      }
    }.doAll(fr.types(), fr).outputFrame(newKey, fr.names(), fr.domains());

    if (sampled.numRows() != 0) return sampled;

    Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ").");
    Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
    return sampleFrame(fr, rows, seed + 1);
  }

  /**
   * Row-wise shuffle of a frame (only shuffles rows inside of each chunk)
   * @param fr Input frame
   * @return Shuffled frame
   */
  public static Frame shuffleFramePerChunk(Frame fr, final long seed) {
    return new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        int[] idx = new int[cs[0]._len];
        for (int r=0; r {
    final int _nclass;
    protected double[] _ys;
    public ClassDist(final Vec label) { _nclass = label.domain().length; }
    public ClassDist(int n) { _nclass = n; }

    public final double[] dist() { return _ys; }
    public final double[] rel_dist() {
      final double sum = ArrayUtils.sum(_ys);
      return ArrayUtils.div(Arrays.copyOf(_ys, _ys.length), sum);
    }
    @Override public void map(Chunk ys) {
      _ys = new double[_nclass];
      for( int i=0; i {
    private IcedHashMap _dist;
    /**
     * Counts how often each distinct non-NA value occurs in this chunk.
     * Starts a fresh map for the chunk; NA rows are ignored.
     *
     * @param ys the chunk to scan
     */
    @Override public void map(Chunk ys) {
      _dist = new IcedHashMap<>();
      final int n = ys._len;
      for (int row = 0; row < n; row++) {
        if (ys.isNA(row)) continue;
        final double value = ys.atd(row);
        final Integer previous = _dist.putIfAbsent(value, 1);
        if (previous != null) _dist.put(value, previous + 1);  // value seen before: bump its count
      }
    }

    /**
     * Merges the value counts of another task into this one.
     * The smaller map is folded into the larger one. The merged map becomes this task's map, and the other
     * task's reference is cleared. Nothing happens when both tasks already share the same map object.
     *
     * @param mrt the task whose counts are absorbed
     */
    @Override public void reduce(Dist mrt) {
      if (_dist == mrt._dist) return;
      IcedHashMap<Double, Integer> into = _dist;
      IcedHashMap<Double, Integer> from = mrt._dist;
      if (into.size() < from.size()) {  // iterate over the smaller of the two maps
        into = mrt._dist;
        from = _dist;
      }
      for (Double key : from.keySet()) {
        final Integer add = from.get(key);
        final Integer existing = into.putIfAbsent(key, add);
        if (existing != null) into.put(key, existing + add);
      }
      _dist = into;
      mrt._dist = null;
    }
    /**
     * Returns the occurrence counts of the distinct values, in the map's {@code values()} iteration order.
     * NOTE(review): this pairs element-wise with {@link #keys()} only if {@code values()} and {@code keySet()}
     * iterate in the same order on an unmodified map. That is presumably true for IcedHashMap; confirm.
     *
     * @return one count per distinct value
     */
    public double[] dist() {
      final double[] counts = new double[_dist.size()];
      int pos = 0;
      for (Integer count : _dist.values()) counts[pos++] = count;
      return counts;
    }
    /**
     * Returns the distinct values that were counted, in the map's {@code keySet()} iteration order.
     *
     * @return one entry per distinct value
     */
    public double[] keys() {
      final double[] values = new double[_dist.size()];
      int pos = 0;
      for (Double key : _dist.keySet()) values[pos++] = key;
      return values;
    }
  }


  /**
   * Stratified sampling for classifiers - FIXME: For weights, this is not accurate, as the sampling is done with uniform weights
   * @param fr Input frame
   * @param label Label vector (must be categorical)
   * @param weights Weights vector, can be null
   * @param sampling_ratios Optional: array containing the requested sampling ratios per class (in order of domains), will be overwritten if it contains all 0s
   * @param maxrows Maximum number of rows in the returned frame
   * @param seed RNG seed for sampling
   * @param allowOversampling Allow oversampling of minority classes
   * @param verbose Whether to print verbose info
   * @return Sampled frame, with approximately the same number of samples from each class (or given by the requested sampling ratios)
   */
  public static Frame sampleFrameStratified(final Frame fr, Vec label, Vec weights, float[] sampling_ratios, long maxrows, final long seed, final boolean allowOversampling, final boolean verbose) {
    if (fr == null) return null;
    assert(label.isCategorical());
    if (maxrows < label.domain().length) {
      Log.warn("Attempting to do stratified sampling to fewer samples than there are class labels - automatically increasing to #rows == #labels (" + label.domain().length + ").");
      maxrows = label.domain().length;
    }

    ClassDist cd = new ClassDist(label);
    double[] dist = weights != null ? cd.doAll(label, weights).dist() : cd.doAll(label).dist();
    assert(dist.length > 0);
    Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
    if (verbose)
      for (int i=0; i= 0); //can have no matching rows in case of sparse data where we had to fill in a makeZero() vector
    Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + (actualnumrows < numrows ? " (limited by max_after_balance_size).":"."));

    if (actualnumrows != numrows) {
      ArrayUtils.mult(sampling_ratios, (float)actualnumrows/numrows); //adjust the sampling_ratios by the global rescaling factor
      if (verbose)
        Log.info("Downsampling majority class by " + (float)actualnumrows/numrows
                + " to limit number of rows to " + String.format("%,d", maxrows));
    }
    for (int i=0;i= 0);
    final int weightsidx = fr.find(weights); //which column is the weight?

    final boolean poisson = false; //beta feature

    //FIXME - this is doing uniform sampling, even if the weights are given
    Frame r = new MRTask() {
      /**
       * Emits each labelled row of the chunk zero or more times, according to the sampling ratio of its class.
       * <p>
       * For {@code s = sampling_ratios[label]}, the row is copied {@code (int) s} times. One extra copy is added
       * with probability {@code s - (int) s}. Ratios above 1 therefore oversample, and ratios below 1 undersample.
       * Rows with a missing label are dropped.
       * The RNG is re-seeded for every row from {@code cs[0].start() + r + seed}, so a row's outcome does not
       * depend on the other rows of the chunk.
       * NOTE(review): the weights column is not consulted here; sampling is uniform within each class.
       */
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        // The initial seed is irrelevant: the RNG is re-seeded for every row below.
        final Random rng = getRNG(seed);
        // The loop index `r` shadows the enclosing `Frame r` that receives this task's output.
        for (int r = 0; r < cs[0]._len; r++) {
          if (cs[labelidx].isNA(r)) continue; //skip missing labels
          rng.setSeed(cs[0].start()+r+seed);
          // Categorical level index of the row's class; it shadows the enclosing `Vec label` parameter.
          final int label = (int)cs[labelidx].at8(r);
          assert(sampling_ratios.length > label && label >= 0);
          int sampling_reps;
          if (poisson) {
            // Poisson resampling is not implemented; `poisson` is a constant false in the enclosing method.
            throw H2O.unimpl();
//            sampling_reps = ArrayUtils.getPoisson(sampling_ratios[label], rng);
          } else {
            // Integer part of the ratio, plus one extra copy with probability equal to the fractional part.
            final float remainder = sampling_ratios[label] - (int)sampling_ratios[label];
            sampling_reps = (int)sampling_ratios[label] + (rng.nextFloat() < remainder ? 1 : 0);
          }
          // Every column receives the same number of copies, so the output rows stay aligned.
          for (int i = 0; i < ncs.length; i++) {
            for (int j = 0; j < sampling_reps; ++j) {
              ncs[i].addNum(cs[i].atd(r));
            }
          }
        }
      }
    }.doAll(fr.types(), fr).outputFrame(fr.names(), fr.domains());

    // Confirm the validity of the distribution
    Vec lab = r.vecs()[labelidx];
    Vec wei = weightsidx != -1 ? r.vecs()[weightsidx] : null;
    double[] dist = wei != null ? new ClassDist(lab).doAll(lab, wei).dist() : new ClassDist(lab).doAll(lab).dist();

    // if there are no training labels in the test set, then there is no point in sampling the test set
    if (dist == null) return fr;

    if (debug) {
      double sumdist = ArrayUtils.sum(dist);
      Log.info("After stratified sampling: " + sumdist + " rows.");
      for (int i=0; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy