All Downloads are FREE. Search and download functionalities are using the official Maven repository.

water.rapids.ASTGroupBy Maven / Gradle / Ivy

package water.rapids;


import sun.misc.Unsafe;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashSet;
import water.nbhm.UtilUnsafe;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;


/**
 * GROUPBY: Single pass aggregation by columns.
 *
 * NA handling:
 *
 *  AGG.T_IG: case 1 ("ignore")
 *    Count NA rows, but discard NA values in sums, mins, maxes
 *      FIRST/LAST return the first/last non-NA value, or NA if all are NA
 *
 *  AGG.T_RM: case 2 ("rm")
 *    Count NA rows separately, discard NA values in sums, mins, maxes and compute aggregates less NA row counts
 *      FIRST/LAST treated as above
 *
 *  AGG.T_ALL: case 0 ("all")
 *    Include NA in all aggregates -- any NA encountered forces the aggregate to be NA.
 *      FIRST/LAST return first/last row regardless of NAs.
 *
 * Aggregates:
 *  MIN
 *  MAX
 *  MEAN
 *  COUNT
 *  SUM
 *  SD
 *  VAR
 *  COUNT_DISTINCT
 *  FIRST
 *  LAST
 *  MODE
 *  Aggregations on time and numeric columns only.
 */
  public class ASTGroupBy extends ASTUniPrefixOp {
  // AST: (GB fr cols AGGS ORDERBY)
  //      (GB %k (llist #1;#3) (AGGS #2 "min" #4 "mean" #6) ())  for no order by..., otherwise is a llist or single number
  private long[] _gbCols; // group by columns
  private AGG[] _agg;
  private AST[] _gbColsDelayed;
  private String[] _gbColsDelayedByName;
  private long[] _orderByCols;
  /** Creates an unparsed group-by node; passes {@code null} to the superclass constructor. */
  ASTGroupBy() { super(null); }
  /** @return the operator token for this node, {@code "GB"}. */
  @Override String opStr() { return "GB"; }
  /** @return a fresh, unparsed {@code ASTGroupBy} instance. */
  @Override ASTOp make() {return new ASTGroupBy();}
  ASTGroupBy parse_impl(Exec E) {
    AST ary = E.parse();
    // parse gby columns
    AST s = E.parse();
    if( s instanceof ASTLongList ) _gbCols = ((ASTLongList)s)._l;
    else if( s instanceof ASTNum ) _gbCols = new long[]{(long)((ASTNum)s)._d};
    else if( s instanceof ASTAry ) _gbColsDelayed = ((ASTAry)s)._a;
    else if( s instanceof  ASTStringList) _gbColsDelayedByName = ((ASTStringList)s)._s;
    else if( s instanceof ASTDoubleList ) {
      double[] d = ((ASTDoubleList)s)._d;
      _gbCols = new long[d.length];
      for(int i=0;i()));
      if( e.isAry()      ) res[i++] = f.find(e.popAry().anyVec());
      else if( e.isNum() ) res[i++] = (int)e.popDbl();
      else if( e.isStr() ) res[i++] = f.find(e.popStr());
      else throw new IllegalArgumentException("Don't know what to do with: " + ast.getClass() + "; " + e.pop());
    }
    return res;
  }

  /**
   * Resolves the column index of every aggregate whose {@code _c} is still unset.
   *
   * Resolution order per aggregate: a delayed column name is looked up in the
   * frame; otherwise a delayed column AST is evaluated and the result (frame,
   * number or string) is turned into an index.
   *
   * @param aggs aggregates to resolve; entries are updated in place
   * @param f    frame used for the name / vec lookups
   * @throws IllegalArgumentException if an aggregate has neither a name nor a
   *         delayed AST, or the evaluated result is not a frame, number or string
   */
  private void computeCols(AGG[] aggs, Frame f) {
    for( AGG agg : aggs ) {
      if( agg._c != null ) continue;  // already resolved at parse time

      if( agg._delayedColByName != null ) {
        agg._c = f.find(agg._delayedColByName);
        continue;
      }

      if( agg._delayedCol == null )
        throw new IllegalArgumentException("Missing column for aggregate: " + agg._name);

      // NOTE(review): treeWalk is invoked on this node, not on agg._delayedCol --
      // confirm that is intended.
      Env env = treeWalk(new Env(new HashSet()));
      if(      env.isAry() ) agg._c = f.find(env.popAry().anyVec());
      else if( env.isNum() ) agg._c = (int)env.popDbl();
      else if( env.isStr() ) agg._c = f.find(env.popStr());
      else throw new IllegalArgumentException("No column found for: " + env.pop());
    }
  }

  public static class IcedNBHS extends Iced implements Iterable {
    NonBlockingHashSet _g;
    /** Starts with an empty backing set. */
    IcedNBHS() {_g=new NonBlockingHashSet<>();}
    /** Delegates to the backing set's {@code add}. */
    boolean add(T t) { return _g.add(t); }
    /** Delegates to the backing set's {@code addAll}. */
    boolean addAll(Collection c) { return _g.addAll(c); }
    /** Delegates to the backing set's {@code get}. */
    T get(T g) { return _g.get(g); }
    /** @return the number of elements in the backing set. */
    int size() { return _g.size(); }
    /**
     * Serializes the set as a 4-byte element count followed by each element.
     * A {@code null} backing set is written as a count of 0.
     */
    @Override public AutoBuffer write_impl( AutoBuffer ab ) {
      if( _g == null ) return ab.put4(0);
      ab.put4(_g.size());
      for( T g: _g ) ab.put(g);
      return ab;
    }
    @Override public IcedNBHS read_impl(AutoBuffer ab) {
      int len = ab.get4();
      if( len == 0 ) return this;
      _g = new NonBlockingHashSet<>();
      for( int i=0;i iterator() {return _g.iterator(); }
  }

  public static class GBTask extends MRTask {
    IcedHashMap _g;
    private long[] _gbCols;
    private AGG[] _agg;
    /**
     * @param gbCols indices of the group-by columns
     * @param agg    aggregates to compute for each group
     */
    GBTask(long[] gbCols, AGG[] agg) { _gbCols=gbCols; _agg=agg; }
    /** Allocates a fresh, empty group map before mapping starts. */
    @Override public void setupLocal() {_g=new IcedHashMap<>();}
    @Override public void map(Chunk[] c) {
      long start = c[0].start();
      G g = new G(_gbCols.length,_agg);
      G gOld;  // fill this one in for all the CAS'ing
      for( int i=0;i l = _g;
        IcedHashMap r = t._g;
        if( l.size() < r.size() ) { l=r; r=_g; }  // larger on the left
        // loop over the smaller set of grps
        for( G rg:r.keySet() ) {
          G lg = l.getk(rg);
          if( l.putIfAbsent(rg,"")!=null ) {
            assert lg!=null;
            long R = lg._N;
            while (!G.CAS_N(lg, R, R + rg._N))
              R = lg._N;
            reduceGroup(_agg, lg, rg);
          }
        }
        _g=l;
        t._g=null;
      }
    }
    // task helper functions
    /** Map-side entry: forwards to the 6-argument {@code perRow} with a {@code null} "that" group. */
    private static void perRow(AGG[] agg, int chkRow, long rowOffset, Chunk[] c, G g) { perRow(agg,chkRow,rowOffset,c,g,null); }
    /** Reduce-side entry: forwards to the 6-argument {@code perRow} with no chunks ({@code null}) and -1 for the row arguments, folding {@code that} into {@code g}. */
    private static void reduceGroup(AGG[] agg, G g, G that) { perRow(agg,-1,-1,null,g,that);}
    private static void perRow(AGG[] agg, int chkRow, long rowOffset, Chunk[] c, G g, G that) {
      byte type; int col;
      for( int i=0;i o && !G.CAS_l(g, G.longRawIdx(c), o, v))
        o = g._l[c];
    }
    /**
     * Lowers {@code g._min[c]} to the candidate value via a CAS retry loop.
     * Nothing is written once the stored value is no longer greater than the candidate.
     *
     * @param g group being updated
     * @param v candidate minimum, encoded as the long bits of a double
     * @param c aggregate index
     */
    private static void setMin(G g, long v, int c) {
      final double candidate = Double.longBitsToDouble(v);
      double current = g._min[c];
      while( candidate < current ) {
        if( G.CAS_min(g, G.doubleRawIdx(c), Double.doubleToRawLongBits(current), v) ) break;
        current = g._min[c];  // lost the race: re-read and re-check
      }
    }
    /**
     * Raises {@code g._max[c]} to the candidate value via a CAS retry loop.
     * Nothing is written once the stored value is no longer smaller than the candidate.
     *
     * @param g group being updated
     * @param v candidate maximum, encoded as the long bits of a double
     * @param c aggregate index
     */
    private static void setMax(G g, long v, int c) {
      final double candidate = Double.longBitsToDouble(v);
      double current = g._max[c];
      while( candidate > current ) {
        if( G.CAS_max(g, G.doubleRawIdx(c), Double.doubleToRawLongBits(current), v) ) break;
        current = g._max[c];  // lost the race: re-read and re-check
      }
    }
    /**
     * Adds a value to {@code g._sum[c]}, retrying the CAS until it succeeds.
     *
     * @param g  group being updated
     * @param vv addend, encoded as the long bits of a double
     * @param c  aggregate index
     */
    private static void setSum(G g, long vv, int c) {
      final double addend = Double.longBitsToDouble(vv);
      double current;
      do {
        current = g._sum[c];
      } while( !G.CAS_sum(g, G.doubleRawIdx(c),
                          Double.doubleToRawLongBits(current),
                          Double.doubleToRawLongBits(current + addend)) );
    }
    /**
     * Accumulates into the sum-of-squares slot {@code g._ss[c]} with a CAS retry loop.
     *
     * @param g        group being updated
     * @param vv       value encoded as the long bits of a double
     * @param c        aggregate index
     * @param isReduce when true the value is added as-is (it is already a sum of
     *                 squares); when false the value is squared before being added
     */
    private static void setSS(G g, long vv, int c, boolean isReduce) {
      final double v = Double.longBitsToDouble(vv);
      final double increment = isReduce ? v : v * v;
      double current;
      do {
        current = g._ss[c];
      } while( !G.CAS_ss(g, G.doubleRawIdx(c),
                         Double.doubleToRawLongBits(current),
                         Double.doubleToRawLongBits(current + increment)) );
    }
    /**
     * Adds {@code n} to the NA counter {@code g._NA[c]}, retrying the CAS until it succeeds.
     *
     * @param g group being updated
     * @param n number of NA rows to add
     * @param c aggregate index
     */
    private static void setNA(G g, long n, int c) {
      long current;
      do {
        current = g._NA[c];
      } while( !G.CAS_NA(g, G.longRawIdx(c), current, current + n) );
    }
  }

  private static class GTask extends H2O.H2OCountedCompleter {
    private final G _g;
    private final long[] _orderByCols;
    private NonBlockingHashSet[] _modeDomain;
    /**
     * @param cc          parent completer, handed to the superclass constructor
     * @param g           the group this task works on
     * @param orderByCols order-by column indices, may be {@code null}
     * @param modeDomain  per-aggregate sets stored for later use by this task
     */
    GTask(H2O.H2OCountedCompleter cc, G g,long[] orderByCols, NonBlockingHashSet[] modeDomain) { super(cc); _g=g; _orderByCols=orderByCols; _modeDomain=modeDomain; }
    @Override protected void compute2() {
      _g.close();
      int[] orderByCols = _orderByCols==null?null:new int[_orderByCols.length];
      if( orderByCols != null )
        for(int i=0;i {
    private final G[] _g;
    private final int _ngrps;
    private final long[] _orderByCols;
    private final int _maxP=50*1000; // burn 50K at a time
    private final AtomicInteger _ctr;
    private NonBlockingHashSet[] _modeDomain;
    ParallelPostGlobal(G[] g, int ngrps, long[] orderByCols) { _g=g; _ctr=new AtomicInteger(_maxP-1); _ngrps=ngrps; _orderByCols = orderByCols;
      _modeDomain=_g[0]._aggs==null ? null :new NonBlockingHashSet[_g[0]._aggs.length];
      if( _modeDomain!=null )
        for( int i=0;i<_modeDomain.length;++i ) _modeDomain[i] = new NonBlockingHashSet<>();
    }

    @Override protected void compute2(){
      addToPendingCount(_g.length-1);
      for( int i=0;i();
        int s = ab.get4();
        if( s==0 ) continue;
        for(int j=0;j {
    public int[] _orderByCols;  // set during the ParallelPostGlobal if there is to be any order by
    public final double _ds[];  // Array is final; contents change with the "fill"
    public int _hash;           // Hash is not final; changes with the "fill"
    private AGG[] _aggs;        // the aggs
    public G fill(int row, Chunk chks[], long cols[]) {
      for( int c=0; c>>20) ^ (h>>>12);
      h ^= (h>>> 7) ^ (h>>> 4);
      return (int)((h^(h>>32))&0x7FFFFFFF);
    }
    /** Two groups are equal when their key arrays {@code _ds} are element-wise equal. */
    @Override public boolean equals( Object o ) {
      if( !(o instanceof G) ) return false;
      G other = (G) o;
      return Arrays.equals(_ds, other._ds);
    }

    /** @return the cached hash stored in {@code _hash}. */
    @Override public int hashCode() {
      return _hash;
    }

    /** @return the key array rendered by {@link Arrays#toString(double[])}. */
    @Override public String toString() {
      return Arrays.toString(_ds);
    }

    /**
     * Orders two groups by the key columns listed in {@code _orderByCols},
     * comparing column by column and stopping at the first difference.
     *
     * NaN is treated as the least value.  Two NaN keys in the same column
     * compare as equal and the comparison moves on to the next column; this
     * keeps the ordering antisymmetric (required by {@link Comparable}) and
     * consistent with {@link #equals}, which also treats NaN keys as equal.
     *
     * @param g the group to compare against
     * @return negative, zero or positive as this group sorts before, with or after {@code g}
     */
    @Override public int compareTo(G g) {
      for( int i : _orderByCols ) {
        final double mine   = _ds[i];
        final double theirs = g._ds[i];
        final boolean mineNaN   = Double.isNaN(mine);
        final boolean theirsNaN = Double.isNaN(theirs);
        if( mineNaN || theirsNaN ) {
          if( mineNaN && theirsNaN ) continue;  // tie on this column
          return mineNaN ? -1 : 1;              // the NaN side sorts first
        }
        if( mine < theirs ) return -1;
        if( mine > theirs ) return  1;
      }
      return 0;
    }

    public long     _N;         // number of rows in the group, updated atomically
    public long[]   _ND;        // count of distincts, built from the NBHS
    public long[]   _NA;        // count of NAs for each aggregate, updated atomically
    public long[]   _f;         // first row, updated atomically
    public long[]   _l;         // last row, atomically updated
    public double[] _min;       // updated atomically
    public double[] _max;       // updated atomically
    public double[] _sum;       // sum, updated atomically
    public double[] _ss;        // sum of squares, updated atomically
    public double[] _avs;       // means, computed in the close
    public double[] _vars;      // vars, computed in the close
    public double[] _sdevs;     // sds,  computed in the close
    public long[/*aggs*/][/*level cnts*/] _m;       // atomically aggregate counts of each domain
    public String[] _mode;      // finalize the _m array into _mode
//    private NBHSAD _nd;         // count distinct helper data structure
    private byte[] _NAMethod;

    // offset crud for unsafe
    private static final Unsafe U = UtilUnsafe.getUnsafe();
    private static final long _NOffset;

    // long[] offset and scale
    private static final int _8B = U.arrayBaseOffset(long[].class);
    private static final int _8S = U.arrayIndexScale(long[].class);
    // double[] offset and scale
    private static final int _dB = U.arrayBaseOffset(double[].class);
    private static final int _dS = U.arrayIndexScale(double[].class);

    // Raw Unsafe byte offsets of element i in a long[] / double[].
    // The scale is widened to long before multiplying so the product cannot
    // overflow int arithmetic for large indices.
    private static long longRawIdx(int i)   { return _8B + (long)_8S * i; }
    private static long doubleRawIdx(int i) { return _dB + (long)_dS * i; }

    // Resolve the Unsafe field offset of _N once at class-load time; CAS_N uses it.
    static {
      try {
        _NOffset   = U.objectFieldOffset(G.class.getDeclaredField("_N"));
      } catch(Exception e) { throw H2O.fail(); }  // NOTE(review): the caught exception is not passed along
    }

    /**
     * Allocates a group via the {@code (len, aggs, naMethod)} constructor and
     * immediately fills its key from one row.
     *
     * @param row      row within the chunks to read the key from
     * @param cs       chunks holding the data
     * @param cols     group-by column indices; its length sizes the key array
     * @param aggs     number of aggregates
     * @param naMethod per-aggregate NA handling codes
     */
    G(int row, Chunk[] cs, long[] cols,int aggs, byte[] naMethod) {
      this(cols.length,aggs,naMethod);
      fill(row, cs, cols);
    }

    /**
     * Builds a group that allocates only the accumulator arrays required by
     * the requested aggregates (as reported by {@code AGG.types}).
     * Min slots start at +infinity and max slots at -infinity.
     *
     * @param len  number of group-by key columns
     * @param aggs aggregate definitions; each allocated array has one slot per aggregate
     */
    G(int len, AGG[] aggs) {
      _aggs = aggs;
      _ds = new double[len];
      _NAMethod = AGG.naMethods(aggs);
      final int nAggs = aggs.length;
      _NA = new long[nAggs];
      for( byte kind : AGG.types(aggs) ) {
        switch( kind ) {
          case AGG.T_ND:  _ND  = new long[nAggs];   break;
          case AGG.T_F:   _f   = new long[nAggs];   break;
          case AGG.T_L:   _l   = new long[nAggs];   break;
          case AGG.T_SUM: _sum = new double[nAggs]; break;
          case AGG.T_SS:  _ss  = new double[nAggs]; break;
          case AGG.T_MIN:
            _min = new double[nAggs];
            Arrays.fill(_min, Double.POSITIVE_INFINITY);
            break;
          case AGG.T_MAX:
            _max = new double[nAggs];
            Arrays.fill(_max, Double.NEGATIVE_INFINITY);
            break;
          case AGG.T_MODE:
            // one counter per domain level, only for aggregates that carry a domain
            _m = new long[nAggs][];
            for( int a=0; a<nAggs; ++a ) {
              String[] domain = aggs[a]._domainsForMode;
              _m[a] = (domain == null) ? null : new long[domain.length];
            }
            break;
        }
      }
    }

    /**
     * Deprecated: prefer the {@code (int, AGG[])} constructor, which allocates
     * only what the aggregates need.  This variant allocates every accumulator
     * array with {@code aggs} slots.  Min slots start at +infinity and max
     * slots at -infinity.
     *
     * @param len      number of group-by key columns
     * @param aggs     number of aggregates
     * @param naMethod per-aggregate NA handling codes
     */
    G(int len, int aggs, byte[] naMethod) {
      _ds       = new double[len];
      _NAMethod = naMethod;

      // counters
      _ND = new long[aggs];
      _NA = new long[aggs];
      _f  = new long[aggs];
      _l  = new long[aggs];

      // running statistics
      _min = new double[aggs];
      _max = new double[aggs];
      _sum = new double[aggs];
      _ss  = new double[aggs];

      // values derived later in close()
      _avs   = new double[aggs];
      _vars  = new double[aggs];
      _sdevs = new double[aggs];

      // mode bookkeeping
      _mode = new String[aggs];
      _m    = new long[aggs][];

      Arrays.fill(_min, Double.POSITIVE_INFINITY);
      Arrays.fill(_max, Double.NEGATIVE_INFINITY);
    }

    /** Key-only group with {@code len} key slots; no accumulator arrays are allocated. */
    G(int len) {_ds=new double[len];}
    /** Empty group with a {@code null} key array. */
    G(){ _ds=null;}
    /** Key-only group that adopts the given array as its key (no copy is made). */
    G(double[] ds) { _ds=ds; }

    /**
     * Finalizes the derived per-aggregate statistics from the accumulators.
     *
     * Means need {@code _sum}; variances and standard deviations need both
     * {@code _sum} and {@code _ss}; anything whose inputs are missing is left
     * {@code null}.  The divisor is the group row count, less the NA count for
     * aggregates whose NA method is {@code AGG.T_RM}.  The variance is computed
     * as {@code (ss - sum*sum/n)/n}.  When mode counters and aggregate
     * definitions are both present, the most frequent domain level is stored
     * in {@code _mode}.
     */
    private void close() {
      final int nAggs = _NAMethod.length;
      final boolean haveMean = _sum != null;
      final boolean haveVar  = haveMean && _ss != null;

      _avs   = haveMean ? new double[nAggs] : null;
      _vars  = haveVar  ? new double[nAggs] : null;
      _sdevs = haveVar  ? new double[nAggs] : null;

      for( int a=0; a<nAggs; ++a ) {
        final long n = (_NAMethod[a] == AGG.T_RM) ? _N - _NA[a] : _N;
        if( haveMean ) _avs[a] = _sum[a]/n;
        if( haveVar ) {
          _vars[a]  = (_ss[a] - (_sum[a]*_sum[a])/n)/n;
          _sdevs[a] = Math.sqrt(_vars[a]);
        }
      }

      if( _m == null || _aggs == null ) return;
      _mode = new String[nAggs];
      for( int a=0; a<_m.length; ++a ) {
        final long[] levelCounts = _m[a];
        if( levelCounts != null )
          _mode[a] = _aggs[a]._domainsForMode[ArrayUtils.maxIndex(levelCounts)];
      }
    }

    // Compare-and-swap helpers.  Each returns the result of Unsafe.compareAndSwapLong.
    // CAS_N targets the _N field via its field offset; the array variants take "off",
    // a raw element offset as produced by longRawIdx / doubleRawIdx.
    protected static boolean CAS_N (G g, long o, long n        ) { return U.compareAndSwapLong(g,_NOffset,o,n); }
    private static boolean CAS_NA(G g, long off, long o, long n) { return U.compareAndSwapLong(g._NA,off,o,n);  }
    private static boolean CAS_f (G g, long off, long o, long n) { return U.compareAndSwapLong(g._f,off,o,n);   }
    private static boolean CAS_l (G g, long off, long o, long n) { return U.compareAndSwapLong(g._l,off,o,n);   }
    // CAS_m takes the aggregate index c and the level index lvl, and computes the raw offset itself.
    private static boolean CAS_m (G g, int c, int lvl,  long o, long n) { return U.compareAndSwapLong(g._m[c], G.longRawIdx(lvl),o,n); }

    // doubles are toRawLongBits'ized, and passed as longs
    private static boolean CAS_min(G g, long off, long o, long n) { return U.compareAndSwapLong(g._min,off,o,n);}
    private static boolean CAS_max(G g, long off, long o, long n) { return U.compareAndSwapLong(g._max,off,o,n);}
    private static boolean CAS_sum(G g, long off, long o, long n) { return U.compareAndSwapLong(g._sum,off,o,n);}
    private static boolean CAS_ss (G g, long off, long o, long n) { return U.compareAndSwapLong(g._ss ,off,o,n);}
  }

  static class AGG extends AST {
    /** @return a fresh, empty {@code AGG} node. */
    @Override AGG make() { return new AGG(); }
    // Each aggregate is parsed as four tokens: "type" col "na" "name",
    // where col is a number, a string, or an arbitrary AST (see parse_impl).
    /** @return the operator token for this node, {@code "agg"}. */
    String opStr() { return "agg";  }
    private AGG[] _aggs;
    /**
     * Parses a run of aggregate specifications until the end marker.
     *
     * Each specification is four tokens, read in this order: the aggregate
     * type (string), the column, the NA handling mode (string) and the output
     * name (string).  The column is a number (used directly as the index), a
     * string (resolved later by name) or any other AST (evaluated later).
     *
     * @param E the parser state to read from
     * @return this node, with {@code _aggs} holding one entry per specification
     */
    AGG parse_impl(Exec E) {
      // Typed list: a raw list's toArray(..) would be typed Object[] and could
      // not be assigned to the AGG[] field below.
      List<AGG> parsed = new ArrayList<>();
      while( !E.isEnd() ) {
        String type = E.parseString(E.peekPlus());
        AST colast = E.parse();

        Integer col = null;
        AST delayedCol = null;
        String delayedColByName = null;
        if(      colast instanceof ASTNum    ) col = (int)((ASTNum)colast)._d;
        else if( colast instanceof ASTString ) delayedColByName = ((ASTString)colast)._s;
        else                                   delayedCol = colast;  // validated when the column is resolved

        String na   = E.parseString(E.peekPlus());
        String name = E.parseString(E.peekPlus());
        parsed.add(new AGG(type,col,na,name,delayedColByName,delayedCol));
      }
      _aggs = parsed.toArray(new AGG[parsed.size()]);
      E.eatEnd();
      return this;
    }

    // Aggregate types
    private static final byte T_N   = 0;
    private static final byte T_ND  = 1;
    private static final byte T_F   = 2;
    private static final byte T_L   = 3;
    private static final byte T_MIN = 4;
    private static final byte T_MAX = 5;
    private static final byte T_AVG = 6;
    private static final byte T_SD  = 7;
    private static final byte T_VAR = 8;
    private static final byte T_SUM = 9;
    private static final byte T_SS  = 10;
    private static final byte T_MODE= 11;

    // How to handle NAs
    private static final byte T_ALL = 0;
    private static final byte T_IG  = 1;
    private static final byte T_RM  = 2;

    // Keyword -> byte code table shared by aggregate names and NA-handling
    // names.  Keys are lower case.  Values are the T_* constants declared
    // above, so the table cannot drift from them.
    private static final transient HashMap<String,Byte> TM = new HashMap<>();
    static{
      // aggregates
      TM.put("count",        T_N);
      TM.put("nrow",         T_N);
      TM.put("count_unique", T_ND);
      TM.put("first",        T_F);
      TM.put("last",         T_L);
      TM.put("min",          T_MIN);
      TM.put("max",          T_MAX);
      TM.put("mean",         T_AVG);
      TM.put("avg",          T_AVG);
      TM.put("sd",           T_SD);
      TM.put("stdev",        T_SD);
      TM.put("var",          T_VAR);
      TM.put("sum",          T_SUM);
      TM.put("ss",           T_SS);
      TM.put("mode",         T_MODE);
      TM.put("most",         T_MODE);
      // na handling
      TM.put("all",          T_ALL);
      TM.put("ignore",       T_IG);
      TM.put("rm",           T_RM);
    }

    private final byte _type;
    private Integer _c;
    private final String _name;
    private final byte _na_handle;
    private AST _delayedCol;
    private String _delayedColByName;
    private String[] _domainsForMode;
    AGG() {_type=0;_c=-1;_name=null;_na_handle=0;}
    AGG(String s, Integer c, String na, String name, String delayedColByName, AST delayedCol) {  // big I Integer allows for nullness
      _type=TM.get(s.toLowerCase());
      _c=c;
      _delayedCol = delayedCol;
      _delayedColByName = delayedColByName;
      _name=(name==null || name.equals(""))?s+"_C"+(c+1):name;
      if( !TM.keySet().contains(na) ) {
        Log.info("Unknown NA handle type given: `" + na + "`. Switching to \"ignore\" method.");
        _na_handle=0;
      } else _na_handle = TM.get(na);
    }

    private static String[] names(AGG[] _agg) {
      String[] names = new String[_agg.length];
      for(int i=0;i typesHS = new HashSet<>();
      for(AGG a: agg) {
        switch(a._type) {
          case T_AVG: typesHS.add(T_SUM); break;
          case T_VAR:
            typesHS.add(T_SUM);
            typesHS.add(T_SS);
            break;
          case T_SD:
            typesHS.add(T_SUM);
            typesHS.add(T_SS);
            break;
          default: typesHS.add(a._type);
        }
      }
      byte[] types = new byte[typesHS.size()];
      int i=0;
      for( byte b: typesHS) types[i++]=b;
      return types;
    }

    // NA-mode predicates.  They compare against the named T_* codes (the same
    // codes the keyword table assigns to "ignore", "rm" and "all") rather
    // than bare numbers, so each predicate matches the mode its name states.
    private boolean isIgnore() { return _na_handle == T_IG;  }
    private boolean isRemove() { return _na_handle == T_RM;  }
    private boolean isAll()    { return _na_handle == T_ALL; }

    // Overrides required by the AST superclass.  An AGG is never executed
    // directly: exec always throws via H2O.fail().
    @Override void exec(Env e) { throw H2O.fail();}
    /** @return the constant string {@code "agg"}. */
    @Override String value() { return "agg"; }
    /** @return the constant type code 0. */
    @Override int type() { return 0; }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy