All Downloads are FREE. Search and download functionalities are using the official Maven repository.

water.rapids.ASTGroup Maven / Gradle / Ivy

There is a newer version: 3.8.2.9
Show newest version
package water.rapids;

import water.H2O;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;

import java.util.Arrays;

/** GroupBy
 *  Group the rows of 'data' by unique combinations of '[group-by-cols]'.
 *  Apply function 'fcn' to a Frame for each group, with a single column
 *  argument, and a NA-handling flag.  Sets of tuples {fun,col,na} are allowed.
 *
 *  'fcn' must be one of a small set of functions (see {@link FCN}), all of
 *  them reductions.  'GB' returns one row per unique group: the first
 *  columns are the grouping columns, and the last column(s) hold the
 *  reduction result(s).
 *
 */
public class ASTGroup extends ASTPrim {
  /**
   * NA-handling mode attached to each {fun,col,na} aggregate. apply() parses it from
   * the third element of each triple by upper-casing the string and calling valueOf.
   * The exact per-mode semantics live in AGG (not visible here).
   */
  public enum NAHandling {
    ALL,
    RM,
    IGNORE
  }

  /**
   * Reductions handled by GroupBy.
   *
   * <p>Each function keeps a small {@code double[]} accumulator per group:
   * {@link #initVal} creates it, {@link #op} folds in one value,
   * {@link #atomic_op} merges another partial accumulator into it, and
   * {@link #postPass} turns the accumulator plus the group's row count into
   * the final result.
   */
  public enum FCN {
    /** Row count; accumulator {@code [count]} (the folded value itself is ignored). */
    nrow() {
      @Override public void op(double[] d0s, double d1) { d0s[0]++; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { d0s[0] += d1s[0]; }
      @Override public double postPass(double[] ds, long n) { return ds[0]; }
    },
    /** Arithmetic mean; accumulator {@code [sum]}, divided by the row count at the end. */
    mean() {
      @Override public void op(double[] d0s, double d1) { d0s[0] += d1; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { d0s[0] += d1s[0]; }
      @Override public double postPass(double[] ds, long n) { return ds[0] / n; }
    },
    /** Sum; accumulator {@code [sum]}. */
    sum() {
      @Override public void op(double[] d0s, double d1) { d0s[0] += d1; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { d0s[0] += d1s[0]; }
      @Override public double postPass(double[] ds, long n) { return ds[0]; }
    },
    /** Sum of squares; accumulator {@code [sum of squares]}. */
    sumSquares() {
      @Override public void op(double[] d0s, double d1) { d0s[0] += d1 * d1; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { d0s[0] += d1s[0]; }
      @Override public double postPass(double[] ds, long n) { return ds[0]; }
    },
    /** Sample variance (n-1 denominator); accumulator {@code [sum of squares, sum]}. */
    var() {
      @Override public void op(double[] d0s, double d1) { d0s[0] += d1 * d1; d0s[1] += d1; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { ArrayUtils.add(d0s, d1s); }
      @Override public double postPass(double[] ds, long n) { return sampleVariance(ds, n); }
      @Override public double[] initVal(int ignored) { return new double[2]; /* 0 -> sum_squares; 1 -> sum */ }
    },
    /** Sample standard deviation; same accumulator as {@link #var}. */
    sdev() {
      @Override public void op(double[] d0s, double d1) { d0s[0] += d1 * d1; d0s[1] += d1; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { ArrayUtils.add(d0s, d1s); }
      @Override public double postPass(double[] ds, long n) { return Math.sqrt(sampleVariance(ds, n)); }
      @Override public double[] initVal(int ignored) { return new double[2]; /* 0 -> sum_squares; 1 -> sum */ }
    },
    /** Minimum; accumulator {@code [running min]}. */
    min() {
      @Override public void op(double[] d0s, double d1) { d0s[0] = Math.min(d0s[0], d1); }
      @Override public void atomic_op(double[] d0s, double[] d1s) { op(d0s, d1s[0]); }
      @Override public double postPass(double[] ds, long n) { return ds[0]; }
      // Start from +Infinity, the identity of min. Double.MAX_VALUE would make a group
      // whose values are all +Infinity report MAX_VALUE.
      @Override public double[] initVal(int maxx) { return new double[]{Double.POSITIVE_INFINITY}; }
    },
    /** Maximum; accumulator {@code [running max]}. */
    max() {
      @Override public void op(double[] d0s, double d1) { d0s[0] = Math.max(d0s[0], d1); }
      @Override public void atomic_op(double[] d0s, double[] d1s) { op(d0s, d1s[0]); }
      @Override public double postPass(double[] ds, long n) { return ds[0]; }
      // Start from -Infinity, the identity of max, for the same reason as min.
      @Override public double[] initVal(int maxx) { return new double[]{Double.NEGATIVE_INFINITY}; }
    },
    /**
     * Most frequent level of a categorical column. The accumulator is a per-level
     * histogram of length {@code maxx}, and the result is the index of the largest bin
     * as chosen by ArrayUtils.maxIndex.
     */
    mode() {
      @Override public void op(double[] d0s, double d1) { d0s[(int) d1]++; }
      @Override public void atomic_op(double[] d0s, double[] d1s) { ArrayUtils.add(d0s, d1s); }
      @Override public double postPass(double[] ds, long n) { return ArrayUtils.maxIndex(ds); }
      @Override public double[] initVal(int maxx) { return new double[maxx]; }
    },
    ;

    /** Folds one value {@code d1} into accumulator {@code d0}. */
    public abstract void op(double[] d0, double d1);

    /** Merges the partial accumulator {@code d1} into {@code d0}. */
    public abstract void atomic_op(double[] d0, double[] d1);

    /**
     * Computes the final result.
     *
     * @param ds the fully merged accumulator
     * @param n  the row count for this aggregate
     */
    public abstract double postPass(double[] ds, long n);

    /**
     * Creates a fresh accumulator. The default is a single zero.
     *
     * @param maxx size hint; only {@link #mode} uses it, as the histogram length
     *             (apply() supplies the column's max()+1, presumably routed here via AGG)
     */
    public double[] initVal(int maxx) { return new double[]{0}; }

    /**
     * Computes the one-pass sample variance from {@code [sum of squares, sum]}.
     * Returns NaN when {@code n < 2}.
     */
    private static double sampleVariance(double[] ds, long n) {
      double numerator = ds[0] - ds[1] * ds[1] / n;
      // The one-pass formula suffers from cancellation, with rounding error on the order
      // of n ulps of the sum of squares. Snap anything below that (including slightly
      // negative residue, which would make sdev NaN) to zero. The bound is relative on
      // purpose: a fixed absolute cutoff wrongly zeroes the variance of small-magnitude
      // data, e.g. values around 1e-3.
      if (numerator < n * Math.ulp(ds[0])) numerator = 0;
      return numerator / (n - 1);
    }
  }

  /** The argument count is not fixed: (GB data [group-by-cols] {fcn col "na"}...). */
  @Override int nargs() {
    return -1;
  }

  /** Argument names are not fixed, so a single "..." placeholder is reported. */
  @Override public String[] args() {
    final String[] placeholder = {"..."};
    return placeholder;
  }

  /** The operator's name in Rapids expressions. */
  @Override public String str() {
    return "GB";
  }
  @Override public Val apply(Env env, Env.StackHelp stk, AST asts[]) {
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    int ncols = fr.numCols();

    ASTNumList groupby = check(ncols, asts[2]);
    final int[] gbCols = groupby.expand4();

    // Count of aggregates; knock off the first 4 ASTs (GB data [group-by] [order-by]...),
    // then count by triples.
    int naggs = (asts.length-3)/3;
    final AGG[] aggs = new AGG[naggs];
    for( int idx = 3; idx < asts.length; idx += 3 ) {
      Val v = asts[idx].exec(env);
      String fn = v instanceof ValFun ? v.getFun().str() : v.getStr();
      FCN fcn = FCN.valueOf(fn);
      ASTNumList col = check(ncols,asts[idx+1]);
      if( col.cnt() != 1 ) throw new IllegalArgumentException("Group-By functions take only a single column");
      int agg_col = (int)col.min(); // Aggregate column
      if( fcn==FCN.mode && !fr.vec(agg_col).isCategorical() )
        throw new IllegalArgumentException("Mode only allowed on categorical columns");
      NAHandling na = NAHandling.valueOf(asts[idx+2].exec(env).getStr().toUpperCase());
      aggs[(idx-3)/3] = new AGG(fcn,agg_col,na, (int)fr.vec(agg_col).max()+1);
    }

    // do the group by work now
    IcedHashMap gss = doGroups(fr,gbCols,aggs);
    final G[] grps = gss.keySet().toArray(new G[gss.size()]);

    // apply an ORDER by here...
    if( gbCols.length > 0 )
      Arrays.sort(grps,new java.util.Comparator() {
          // Compare 2 groups.  Iterate down _gs, stop when _gs[i] > that._gs[i],
          // or _gs[i] < that._gs[i].  Order by various columns specified by
          // gbCols.  NaN is treated as least
          @Override public int compare( G g1, G g2 ) {
            for( int i=0; i doGroups(Frame fr, int[] gbCols, AGG[] aggs) {
    // do the group by work now
    long start = System.currentTimeMillis();
    GBTask p1 = new GBTask(gbCols, aggs).doAll(fr);
    Log.info("Group By Task done in " + (System.currentTimeMillis() - start)/1000. + " (s)");
    return p1._gss;
  }

  /**
   * Utility for ASTDdply: builds a single aggregate for counting rows per group.
   *
   * @return a one-element array holding an {@code nrow} aggregate over column 0,
   *         created with {@code NAHandling.IGNORE} and a size hint of 0
   */
  static AGG[] aggNRows() {
    AGG rowCounter = new AGG(FCN.nrow, 0, NAHandling.IGNORE, 0);
    return new AGG[]{rowCounter};
  }

  // Build output frame from the multi-column results
  static Frame buildOutput(int[] gbCols, int noutCols, Frame fr, String[] fcnames, int ngrps, MRTask mrfill) {
    // Build the output!
    // the names of columns
    final int nCols = gbCols.length+noutCols;
    String[] names = new String[nCols];
    String[][] domains = new String[nCols][];
    for( int i=0;i {
    final IcedHashMap _gss; // Shared per-node, common, racy
    private final int[] _gbCols; // Columns used to define group
    private final AGG[] _aggs;   // Aggregate descriptions
    GBTask(int[] gbCols, AGG[] aggs) { _gbCols=gbCols; _aggs=aggs; _gss = new IcedHashMap<>(); }
    @Override public void map(Chunk[] cs) {
      // Groups found in this Chunk
      IcedHashMap gs = new IcedHashMap<>();
      G gWork = new G(_gbCols.length,_aggs); // Working Group
      G gOld;                   // Existing Group to be filled in
      for( int row=0; row r ) {
      for( G rg : r.keySet() )
        if( _gss.putIfAbsent(rg,"")!=null ) {
          G lg = _gss.getk(rg);
          for( int i=0; i<_aggs.length; i++ )
            _aggs[i].atomic_op(lg._dss,lg._ns,i, rg._dss[i], rg._ns[i]); // Need to atomically merge groups here
        }
    }
  }

  // Groups!  Contains a Group Key - an array of doubles (often just 1 entry
  // long) that defines the Group.  Also contains an array of doubles for the
  // aggregate results, one per aggregate.
  public static class G extends Iced {
    final double _gs[];  // Group Key: Array is final; contents change with the "fill"
    int _hash;           // Hash is not final; changes with the "fill"

    public final double _dss[][];      // Aggregates: usually sum or sum*2
    public final long   _ns[];         // row counts per aggregate, varies by NA handling and column

    public G( int ncols, AGG[] aggs ) {
      _gs = new double[ncols]; 
      int len = aggs==null ? 0 : aggs.length;
      _dss= new double[len][];
      _ns = new long  [len]; 
      for( int i=0; i>>20) ^ (h>>>12);
      h ^= (h>>> 7) ^ (h>>> 4);
      return (int)((h^(h>>32))&0x7FFFFFFF);
    }
    /** Two groups are equal when their group-key arrays match element-wise (Arrays.equals). */
    @Override public boolean equals( Object o ) {
      if( !(o instanceof G) ) return false;
      final G other = (G) o;
      return Arrays.equals(_gs, other._gs);
    }

    /** Returns the cached {@code _hash}, which changes whenever the key is refilled. */
    @Override public int hashCode() {
      return _hash;
    }

    /** Renders the group key, e.g. {@code [1.0, 2.0]}. */
    @Override public String toString() {
      final double[] key = _gs;
      return Arrays.toString(key);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy