
// water.rapids.ASTGroupBy -- Maven / Gradle / Ivy
package water.rapids;
import sun.misc.Unsafe;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashSet;
import water.nbhm.UtilUnsafe;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* GROUPBY: Single pass aggregation by columns.
*
* NA handling:
*
* AGG.T_IG: case 0
* Count NA rows, but discard values in sums, mins, maxs
* FIRST/LAST return the first nonNA first/last, or NA if all NA
*
* AGG.T_RM: case 1
* Count NA rows separately, discard values in sums, mins, maxs and compute aggregates less NA row counts
* FIRST/LAST treated as above
*
* AGG.T_ALL: case 2
* Include NA in all aggregates -- any NA encountered forces aggregate to be NA.
* FIRST/LAST return first/last row regardless of NAs.
*
* Aggregates:
* MIN
* MAX
* MEAN
* COUNT
* SUM
* SD
* VAR
* COUNT_DISTINCT
* FIRST
* LAST
* MODE
* Aggregations on time and numeric columns only.
*/
public class ASTGroupBy extends ASTUniPrefixOp {
// AST: (GB fr cols AGGS ORDERBY)
// (GB %k (llist #1;#3) (AGGS #2 "min" #4 "mean" #6) ()) for no order by..., otherwise is a llist or single number
private long[] _gbCols; // group by columns
private AGG[] _agg;
private AST[] _gbColsDelayed;
private String[] _gbColsDelayedByName;
private long[] _orderByCols;
/** Creates a new group-by op, passing {@code null} to the superclass constructor. */
ASTGroupBy() { super(null); }
/** @return the operator token of this op, {@code "GB"}. */
@Override String opStr() { return "GB"; }
/** @return a fresh {@code ASTGroupBy} instance. */
@Override ASTOp make() {return new ASTGroupBy();}
ASTGroupBy parse_impl(Exec E) {
AST ary = E.parse();
// parse gby columns
AST s = E.parse();
if( s instanceof ASTLongList ) _gbCols = ((ASTLongList)s)._l;
else if( s instanceof ASTNum ) _gbCols = new long[]{(long)((ASTNum)s)._d};
else if( s instanceof ASTAry ) _gbColsDelayed = ((ASTAry)s)._a;
else if( s instanceof ASTStringList) _gbColsDelayedByName = ((ASTStringList)s)._s;
else if( s instanceof ASTDoubleList ) {
double[] d = ((ASTDoubleList)s)._d;
_gbCols = new long[d.length];
for(int i=0;i()));
if( e.isAry() ) res[i++] = f.find(e.popAry().anyVec());
else if( e.isNum() ) res[i++] = (int)e.popDbl();
else if( e.isStr() ) res[i++] = f.find(e.popStr());
else throw new IllegalArgumentException("Don't know what to do with: " + ast.getClass() + "; " + e.pop());
}
return res;
}
/**
 * Fills in the column index ({@code _c}) of every aggregate that does not have one yet.
 *
 * Resolution order per aggregate: by column name ({@code _delayedColByName}) via
 * {@code f.find}, otherwise by evaluating a delayed expression and interpreting the
 * result as a frame (looked up through one of its vecs), a number (used as the index),
 * or a string (looked up as a column name).
 *
 * @param aggs aggregates to resolve; entries whose {@code _c} is already set are left alone
 * @param f    frame the column names / vecs are looked up in
 * @throws IllegalArgumentException if an aggregate has neither a name nor a delayed
 *         expression, or if the evaluated expression is not a frame, number or string
 */
private void computeCols(AGG[] aggs, Frame f) {
for(AGG a:aggs ) {
if( a._c == null ) {
if( a._delayedColByName!=null ) a._c = f.find(a._delayedColByName);
else if( a._delayedCol!=null ) {
// NOTE(review): treeWalk is invoked without a receiver here, i.e. not on a._delayedCol,
// and HashSet is raw -- this line may have lost text during extraction; confirm against
// the upstream source before relying on it.
Env e = treeWalk(new Env(new HashSet()));
if( e.isAry() ) a._c = f.find(e.popAry().anyVec());
else if( e.isNum() ) a._c = (int)e.popDbl();
else if( e.isStr() ) a._c = f.find(e.popStr());
else throw new IllegalArgumentException("No column found for: " + e.pop());
}
else throw new IllegalArgumentException("Missing column for aggregate: " + a._name);
}
}
}
public static class IcedNBHS extends Iced implements Iterable {
NonBlockingHashSet _g;
/** Creates a wrapper around a new, empty {@code NonBlockingHashSet}. */
IcedNBHS() {_g=new NonBlockingHashSet<>();}
/** Delegates to the backing set's {@code add}. */
boolean add(T t) { return _g.add(t); }
// NOTE(review): the parameter type below looks truncated (a wildcard opener appears to be
// missing before "extends"); confirm against the upstream source.
/** Delegates to the backing set's {@code addAll}. */
boolean addAll(Collection extends T> c) { return _g.addAll(c); }
/** Delegates to the backing set's {@code get}. */
T get(T g) { return _g.get(g); }
/** @return the size reported by the backing set. */
int size() { return _g.size(); }
/**
 * Serializes the set: a 4-byte element count followed by each element.
 * A {@code null} backing set is written as a count of 0.
 *
 * @param ab buffer to write into
 * @return the buffer that was written to
 */
@Override public AutoBuffer write_impl( AutoBuffer ab ) {
if( _g == null ) return ab.put4(0);
// The count is taken from size() and the elements from a separate iteration.
ab.put4(_g.size());
for( T g: _g ) ab.put(g);
return ab;
}
@Override public IcedNBHS read_impl(AutoBuffer ab) {
int len = ab.get4();
if( len == 0 ) return this;
_g = new NonBlockingHashSet<>();
for( int i=0;i iterator() {return _g.iterator(); }
}
public static class GBTask extends MRTask {
IcedHashMap _g;
private long[] _gbCols;
private AGG[] _agg;
/**
 * @param gbCols group-by columns, stored in {@code _gbCols}
 * @param agg    aggregates, stored in {@code _agg}
 */
GBTask(long[] gbCols, AGG[] agg) { _gbCols=gbCols; _agg=agg; }
/** Replaces {@code _g} with a new, empty {@code IcedHashMap}. */
@Override public void setupLocal() {_g=new IcedHashMap<>();}
@Override public void map(Chunk[] c) {
long start = c[0].start();
G g = new G(_gbCols.length,_agg);
G gOld; // fill this one in for all the CAS'ing
for( int i=0;i l = _g;
IcedHashMap r = t._g;
if( l.size() < r.size() ) { l=r; r=_g; } // larger on the left
// loop over the smaller set of grps
for( G rg:r.keySet() ) {
G lg = l.getk(rg);
if( l.putIfAbsent(rg,"")!=null ) {
assert lg!=null;
long R = lg._N;
while (!G.CAS_N(lg, R, R + rg._N))
R = lg._N;
reduceGroup(_agg, lg, rg);
}
}
_g=l;
t._g=null;
}
}
// task helper functions
// Row-update entry point: forwards to the six-argument perRow with no second group (null).
private static void perRow(AGG[] agg, int chkRow, long rowOffset, Chunk[] c, G g) { perRow(agg,chkRow,rowOffset,c,g,null); }
// Group-merge entry point: forwards to the six-argument perRow with row -1, offset -1,
// no chunks (null), and `that` as the second group.
private static void reduceGroup(AGG[] agg, G g, G that) { perRow(agg,-1,-1,null,g,that);}
private static void perRow(AGG[] agg, int chkRow, long rowOffset, Chunk[] c, G g, G that) {
byte type; int col;
for( int i=0;i o && !G.CAS_l(g, G.longRawIdx(c), o, v))
o = g._l[c];
}
/**
 * Lowers {@code g._min[c]} to the candidate value if the candidate is smaller,
 * retrying the compare-and-swap until it succeeds or the stored value is no
 * longer larger than the candidate.
 *
 * @param g group whose minimum is updated
 * @param v candidate value, encoded as the long bit pattern of a double
 * @param c index of the aggregate slot
 */
private static void setMin(G g, long v, int c) {
  final double candidate = Double.longBitsToDouble(v);
  final long slot = G.doubleRawIdx(c);
  double current = g._min[c];
  for (;;) {
    // Nothing to do once the stored value is not greater (also covers a NaN candidate).
    if (!(candidate < current)) return;
    if (G.CAS_min(g, slot, Double.doubleToRawLongBits(current), v)) return;
    current = g._min[c];  // lost the race: re-read and try again
  }
}
/**
 * Raises {@code g._max[c]} to the candidate value if the candidate is larger,
 * retrying the compare-and-swap until it succeeds or the stored value is no
 * longer smaller than the candidate.
 *
 * @param g group whose maximum is updated
 * @param v candidate value, encoded as the long bit pattern of a double
 * @param c index of the aggregate slot
 */
private static void setMax(G g, long v, int c) {
  final double candidate = Double.longBitsToDouble(v);
  final long slot = G.doubleRawIdx(c);
  double current = g._max[c];
  for (;;) {
    // Nothing to do once the stored value is not smaller (also covers a NaN candidate).
    if (!(candidate > current)) return;
    if (G.CAS_max(g, slot, Double.doubleToRawLongBits(current), v)) return;
    current = g._max[c];  // lost the race: re-read and try again
  }
}
/**
 * Adds a value to {@code g._sum[c]} with a compare-and-swap retry loop.
 *
 * @param g  group whose running sum is updated
 * @param vv value to add, encoded as the long bit pattern of a double
 * @param c  index of the aggregate slot
 */
private static void setSum(G g, long vv, int c) {
  final double addend = Double.longBitsToDouble(vv);
  final long slot = G.doubleRawIdx(c);
  double seen;
  do {
    seen = g._sum[c];  // (re-)read the current total before each attempt
  } while (!G.CAS_sum(g, slot,
                      Double.doubleToRawLongBits(seen),
                      Double.doubleToRawLongBits(seen + addend)));
}
/**
 * Accumulates into the sum-of-squares slot {@code g._ss[c]} with a
 * compare-and-swap retry loop.
 *
 * @param g        group whose sum of squares is updated
 * @param vv       value, encoded as the long bit pattern of a double
 * @param c        index of the aggregate slot
 * @param isReduce when {@code true} the value is added as-is; when
 *                 {@code false} its square is added
 */
private static void setSS(G g, long vv, int c, boolean isReduce) {
  final double value = Double.longBitsToDouble(vv);
  // The amount added is fixed for the whole call; only the base is re-read on retry.
  final double increment = isReduce ? value : value * value;
  final long slot = G.doubleRawIdx(c);
  double seen;
  do {
    seen = g._ss[c];
  } while (!G.CAS_ss(g, slot,
                     Double.doubleToRawLongBits(seen),
                     Double.doubleToRawLongBits(seen + increment)));
}
/**
 * Adds {@code n} to the NA counter {@code g._NA[c]} with a compare-and-swap
 * retry loop.
 *
 * @param g group whose NA count is updated
 * @param n amount to add
 * @param c index of the aggregate slot
 */
private static void setNA(G g, long n, int c) {
  final long slot = G.longRawIdx(c);
  long seen;
  do {
    seen = g._NA[c];  // (re-)read the current count before each attempt
  } while (!G.CAS_NA(g, slot, seen, seen + n));
}
}
private static class GTask extends H2O.H2OCountedCompleter {
private final G _g;
private final long[] _orderByCols;
private NonBlockingHashSet[] _modeDomain;
GTask(H2O.H2OCountedCompleter cc, G g,long[] orderByCols, NonBlockingHashSet[] modeDomain) { super(cc); _g=g; _orderByCols=orderByCols; _modeDomain=modeDomain; }
@Override protected void compute2() {
_g.close();
int[] orderByCols = _orderByCols==null?null:new int[_orderByCols.length];
if( orderByCols != null )
for(int i=0;i {
private final G[] _g;
private final int _ngrps;
private final long[] _orderByCols;
private final int _maxP=50*1000; // burn 50K at a time
private final AtomicInteger _ctr;
private NonBlockingHashSet[] _modeDomain;
ParallelPostGlobal(G[] g, int ngrps, long[] orderByCols) { _g=g; _ctr=new AtomicInteger(_maxP-1); _ngrps=ngrps; _orderByCols = orderByCols;
_modeDomain=_g[0]._aggs==null ? null :new NonBlockingHashSet[_g[0]._aggs.length];
if( _modeDomain!=null )
for( int i=0;i<_modeDomain.length;++i ) _modeDomain[i] = new NonBlockingHashSet<>();
}
@Override protected void compute2(){
addToPendingCount(_g.length-1);
for( int i=0;i();
int s = ab.get4();
if( s==0 ) continue;
for(int j=0;j {
public int[] _orderByCols; // set during the ParallelPostGlobal if there is to be any order by
public final double _ds[]; // Array is final; contents change with the "fill"
public int _hash; // Hash is not final; changes with the "fill"
private AGG[] _aggs; // the aggs
public G fill(int row, Chunk chks[], long cols[]) {
for( int c=0; c>>20) ^ (h>>>12);
h ^= (h>>> 7) ^ (h>>> 4);
return (int)((h^(h>>32))&0x7FFFFFFF);
}
/**
 * Two groups are equal when their key arrays {@code _ds} are element-wise equal
 * (as defined by {@link Arrays#equals(double[], double[])}).
 */
@Override public boolean equals( Object o ) {
  if (!(o instanceof G)) return false;
  final G other = (G) o;
  return Arrays.equals(_ds, other._ds);
}
/** @return the stored hash field {@code _hash}; it is not recomputed here. */
@Override public int hashCode() { return _hash; }
/** @return the key array {@code _ds} rendered with {@code Arrays.toString}. */
@Override public String toString() { return Arrays.toString(_ds); }
// compare 2 groups
// iterate down _ds, stop when _ds[i] > that._ds[i], or _ds[i] < that._ds[i]
// order by various columns specified by _orderByCols
// NaN is treated as least
/**
 * Orders groups by the key columns listed in {@code _orderByCols}, in that order.
 * NaN sorts before every non-NaN value; two NaNs in the same column tie on that
 * column and the comparison moves on to the next one.
 *
 * Previously a NaN on this side returned -1 even when the other side was NaN too,
 * so {@code a.compareTo(b)} and {@code b.compareTo(a)} could both be negative,
 * which breaks the {@code Comparable} contract relied on by sorting.
 *
 * @param g group to compare against
 * @return negative, zero or positive as this group sorts before, with, or after {@code g}
 */
@Override public int compareTo(G g) {
  for (int i : _orderByCols) {
    final double mine = _ds[i];
    final double theirs = g._ds[i];
    final boolean mineNaN = Double.isNaN(mine);
    final boolean theirsNaN = Double.isNaN(theirs);
    if (mineNaN || theirsNaN) {
      if (mineNaN && theirsNaN) continue;  // both missing: tie on this column
      return mineNaN ? -1 : 1;             // NaN is treated as least
    }
    if (mine < theirs) return -1;
    if (mine > theirs) return 1;
  }
  return 0;
}
public long _N; // number of rows in the group, updated atomically
public long[] _ND; // count of distincts, built from the NBHS
public long[] _NA; // count of NAs for each aggregate, updated atomically
public long[] _f; // first row, updated atomically
public long[] _l; // last row, atomically updated
public double[] _min; // updated atomically
public double[] _max; // updated atomically
public double[] _sum; // sum, updated atomically
public double[] _ss; // sum of squares, updated atomically
public double[] _avs; // means, computed in the close
public double[] _vars; // vars, computed in the close
public double[] _sdevs; // sds, computed in the close
public long[/*aggs*/][/*level cnts*/] _m; // atomically aggregate counts of each domain
public String[] _mode; // finalize the _m array into _mode
// private NBHSAD _nd; // count distinct helper data structure
private byte[] _NAMethod;
// offset crud for unsafe
private static final Unsafe U = UtilUnsafe.getUnsafe();
private static final long _NOffset;
// long[] offset and scale
private static final int _8B = U.arrayBaseOffset(long[].class);
private static final int _8S = U.arrayIndexScale(long[].class);
// double[] offset and scale
private static final int _dB = U.arrayBaseOffset(double[].class);
private static final int _dS = U.arrayIndexScale(double[].class);
// get the raw indices for the long[] and double[]
/**
 * @param i element index into a {@code long[]}
 * @return the raw offset of element {@code i}: array base offset plus {@code i} times the
 *         index scale. The product is computed in {@code long} so it cannot overflow
 *         {@code int} before widening.
 */
private static long longRawIdx(int i) { return _8B + (long) _8S * i; }
/**
 * @param i element index into a {@code double[]}
 * @return the raw offset of element {@code i}: array base offset plus {@code i} times the
 *         index scale. The product is computed in {@code long} so it cannot overflow
 *         {@code int} before widening.
 */
private static long doubleRawIdx(int i) { return _dB + (long) _dS * i; }
static {
try {
_NOffset = U.objectFieldOffset(G.class.getDeclaredField("_N"));
} catch(Exception e) { throw H2O.fail(); }
}
/**
 * Allocates a group via the {@code (len, aggs, naMethod)} constructor with
 * {@code len = cols.length}, then fills its key from the given row.
 *
 * @param row      row passed to {@code fill}
 * @param cs       chunks passed to {@code fill}
 * @param cols     columns passed to {@code fill}; its length sizes the key array
 * @param aggs     number of aggregate slots to allocate
 * @param naMethod stored as the per-aggregate NA handling array
 */
G(int row, Chunk[] cs, long[] cols,int aggs, byte[] naMethod) {
this(cols.length,aggs,naMethod);
fill(row, cs, cols);
}
/**
 * Allocates a group with a key of {@code len} doubles and only those accumulator
 * arrays that the given aggregates need, as reported by {@code AGG.types(aggs)}.
 * Every allocated accumulator has one slot per aggregate. Minimums start at
 * positive infinity and maximums at negative infinity.
 *
 * @param len  number of group-by key columns
 * @param aggs aggregates this group accumulates
 */
G(int len, AGG[] aggs) {
  _aggs = aggs;
  _ds = new double[len];
  _NAMethod = AGG.naMethods(aggs);
  final int nAggs = aggs.length;
  _NA = new long[nAggs];
  for (byte kind : AGG.types(aggs)) {
    if (kind == AGG.T_ND) {
      _ND = new long[nAggs];
    } else if (kind == AGG.T_F) {
      _f = new long[nAggs];
    } else if (kind == AGG.T_L) {
      _l = new long[nAggs];
    } else if (kind == AGG.T_MIN) {
      _min = new double[nAggs];
      Arrays.fill(_min, Double.POSITIVE_INFINITY);
    } else if (kind == AGG.T_MAX) {
      _max = new double[nAggs];
      Arrays.fill(_max, Double.NEGATIVE_INFINITY);
    } else if (kind == AGG.T_SUM) {
      _sum = new double[nAggs];
    } else if (kind == AGG.T_SS) {
      _ss = new double[nAggs];
    } else if (kind == AGG.T_MODE) {
      // One counter per domain level, only for aggregates that carry a domain.
      _m = new long[nAggs][];
      for (int k = 0; k < nAggs; k++) {
        final String[] domain = aggs[k]._domainsForMode;
        _m[k] = domain == null ? null : new long[domain.length];
      }
    }
  }
}
// deprecated... much better to use above constructor
/**
 * Allocates a group with a key of {@code len} doubles and every accumulator array
 * sized to {@code aggs}, whether or not it will be used. Minimums start at
 * positive infinity and maximums at negative infinity.
 *
 * @param len      number of group-by key columns
 * @param aggs     number of aggregate slots
 * @param naMethod stored as the per-aggregate NA handling array
 */
G(int len, int aggs, byte[] naMethod) {
  _ds = new double[len];
  _NAMethod = naMethod;

  // long-valued accumulators
  _ND = new long[aggs];
  _NA = new long[aggs];
  _f = new long[aggs];
  _l = new long[aggs];

  // double-valued accumulators
  _min = new double[aggs];
  _max = new double[aggs];
  _sum = new double[aggs];
  _ss = new double[aggs];

  // results derived later
  _avs = new double[aggs];
  _vars = new double[aggs];
  _sdevs = new double[aggs];
  _mode = new String[aggs];
  _m = new long[aggs][];

  Arrays.fill(_min, Double.POSITIVE_INFINITY);
  Arrays.fill(_max, Double.NEGATIVE_INFINITY);
}
/** Allocates a key array of {@code len} doubles; no accumulator arrays are allocated. */
G(int len) {_ds=new double[len];}
/** Creates a group whose key array {@code _ds} is null. */
G(){ _ds=null;}
/** Uses {@code ds} directly as the key array (the array is not copied). */
G(double[] ds) { _ds=ds; }
/**
 * Derives the final per-aggregate results from the accumulators:
 * means from {@code _sum}, variances from {@code _ss} and {@code _sum}
 * (computed as {@code (ss - sum*sum/n)/n}), standard deviations as the square root
 * of the variances, and modes as the domain level with the largest count.
 * The row count {@code n} excludes NA rows only for aggregates whose NA method is
 * {@code AGG.T_RM}. Result arrays whose inputs were never allocated are set to null.
 */
private void close() {
  final int nAggs = _NAMethod.length;
  final boolean haveSum = _sum != null;
  final boolean haveVar = haveSum && _ss != null;

  _avs = haveSum ? new double[nAggs] : null;
  _vars = haveVar ? new double[_avs.length] : null;
  _sdevs = haveVar ? new double[_vars.length] : null;

  for (int k = 0; k < nAggs; k++) {
    long rows = _N;
    if (_NAMethod[k] == AGG.T_RM) rows -= _NA[k];
    if (haveSum) _avs[k] = _sum[k] / rows;
    if (haveVar) {
      _vars[k] = (_ss[k] - (_sum[k] * _sum[k]) / rows) / rows;
      _sdevs[k] = Math.sqrt(_vars[k]);
    }
  }

  if (_m == null || _aggs == null) return;
  _mode = new String[nAggs];
  for (int k = 0; k < _m.length; k++) {
    final long[] levelCounts = _m[k];
    if (levelCounts == null) continue;
    _mode[k] = _aggs[k]._domainsForMode[ArrayUtils.maxIndex(levelCounts)];
  }
}
// Compare-and-swap helpers. Each returns the result of Unsafe.compareAndSwapLong:
// o is the expected old value, n the new value.
// CAS_N targets the _N field of g (via _NOffset); the others target one element of the
// named array, with `off` being a raw offset as produced by longRawIdx / doubleRawIdx.
protected static boolean CAS_N (G g, long o, long n ) { return U.compareAndSwapLong(g,_NOffset,o,n); }
private static boolean CAS_NA(G g, long off, long o, long n) { return U.compareAndSwapLong(g._NA,off,o,n); }
private static boolean CAS_f (G g, long off, long o, long n) { return U.compareAndSwapLong(g._f,off,o,n); }
private static boolean CAS_l (G g, long off, long o, long n) { return U.compareAndSwapLong(g._l,off,o,n); }
// CAS_m takes the aggregate index c and the level index lvl and computes the raw offset itself.
private static boolean CAS_m (G g, int c, int lvl, long o, long n) { return U.compareAndSwapLong(g._m[c], G.longRawIdx(lvl),o,n); }
// doubles are toRawLongBits'ized, and passed as longs
private static boolean CAS_min(G g, long off, long o, long n) { return U.compareAndSwapLong(g._min,off,o,n);}
private static boolean CAS_max(G g, long off, long o, long n) { return U.compareAndSwapLong(g._max,off,o,n);}
private static boolean CAS_sum(G g, long off, long o, long n) { return U.compareAndSwapLong(g._sum,off,o,n);}
private static boolean CAS_ss (G g, long off, long o, long n) { return U.compareAndSwapLong(g._ss ,off,o,n);}
}
static class AGG extends AST {
/** @return a fresh {@code AGG} built with the no-argument constructor. */
@Override AGG make() { return new AGG(); }
// (AGG "agg" #col "na" "agg" #col "na" => string num string string num string
// NOTE(review): parse_impl reads four tokens per aggregate (type, column, NA method, name),
// so the sketch above appears to omit the name token -- confirm.
/** @return the operator token, {@code "agg"}. */
String opStr() { return "agg"; }
private AGG[] _aggs;
/**
 * Parses a sequence of aggregate specifications until the end marker and stores
 * them in {@code _aggs}. Each specification is read as four tokens, in order:
 * the aggregate type (string), the column (any AST), the NA-handling method
 * (string) and the result name (string).
 *
 * The column is taken as an index when it is a number, as a column name when it
 * is a string, and is otherwise kept as a delayed expression to be resolved later.
 *
 * The list is typed as {@code ArrayList<AGG>}: with a raw list,
 * {@code toArray(new AGG[..])} yields {@code Object[]} and cannot be assigned to
 * {@code _aggs}.
 *
 * @param E parser state to read tokens from
 * @return this node, with {@code _aggs} populated
 */
AGG parse_impl(Exec E) {
  ArrayList<AGG> aggs = new ArrayList<>();
  while( !E.isEnd() ) {
    String type = E.parseString(E.peekPlus());
    AST colast = E.parse();
    Integer col=null;
    AST delayedCol=null;
    String delayedColByName=null;
    if( colast instanceof ASTNum ) col = (int)((ASTNum)colast)._d;
    else if( colast instanceof ASTString ) delayedColByName = ((ASTString)colast)._s;
    else delayedCol = colast; // check for badness sometime later...
    String na = E.parseString(E.peekPlus());
    String name = E.parseString(E.peekPlus());
    aggs.add(new AGG(type,col,na,name,delayedColByName,delayedCol));
  }
  _aggs = aggs.toArray(new AGG[aggs.size()]);
  E.eatEnd();
  return this;
}
// Aggregate types
private static final byte T_N = 0;
private static final byte T_ND = 1;
private static final byte T_F = 2;
private static final byte T_L = 3;
private static final byte T_MIN = 4;
private static final byte T_MAX = 5;
private static final byte T_AVG = 6;
private static final byte T_SD = 7;
private static final byte T_VAR = 8;
private static final byte T_SUM = 9;
private static final byte T_SS = 10;
private static final byte T_MODE= 11;
// How to handle NAs
private static final byte T_ALL = 0;
private static final byte T_IG = 1;
private static final byte T_RM = 2;
private static transient HashMap TM = new HashMap<>();
static{
// aggregates
TM.put("count", (byte)0);
TM.put("nrow", (byte)0);
TM.put("count_unique",(byte)1);
TM.put("first", (byte)2);
TM.put("last", (byte)3);
TM.put("min", (byte)4);
TM.put("max", (byte)5);
TM.put("mean", (byte)6);
TM.put("avg", (byte)6);
TM.put("sd", (byte)7);
TM.put("stdev", (byte)7);
TM.put("var", (byte)8);
TM.put("sum", (byte)9);
TM.put("ss", (byte)10);
TM.put("mode", (byte)11);
TM.put("most", (byte)11);
// na handling
TM.put("all" ,(byte)0);
TM.put("ignore" ,(byte)1);
TM.put("rm" ,(byte)2);
}
private final byte _type;
private Integer _c;
private final String _name;
private final byte _na_handle;
private AST _delayedCol;
private String _delayedColByName;
private String[] _domainsForMode;
/** Creates a placeholder aggregate: type 0, column -1, no name, NA handle 0. */
AGG() {_type=0;_c=-1;_name=null;_na_handle=0;}
AGG(String s, Integer c, String na, String name, String delayedColByName, AST delayedCol) { // big I Integer allows for nullness
_type=TM.get(s.toLowerCase());
_c=c;
_delayedCol = delayedCol;
_delayedColByName = delayedColByName;
_name=(name==null || name.equals(""))?s+"_C"+(c+1):name;
if( !TM.keySet().contains(na) ) {
Log.info("Unknown NA handle type given: `" + na + "`. Switching to \"ignore\" method.");
_na_handle=0;
} else _na_handle = TM.get(na);
}
private static String[] names(AGG[] _agg) {
String[] names = new String[_agg.length];
for(int i=0;i typesHS = new HashSet<>();
for(AGG a: agg) {
switch(a._type) {
case T_AVG: typesHS.add(T_SUM); break;
case T_VAR:
typesHS.add(T_SUM);
typesHS.add(T_SS);
break;
case T_SD:
typesHS.add(T_SUM);
typesHS.add(T_SS);
break;
default: typesHS.add(a._type);
}
}
byte[] types = new byte[typesHS.size()];
int i=0;
for( byte b: typesHS) types[i++]=b;
return types;
}
// NA-handling predicates; each compares _na_handle with a literal.
// NOTE(review): the literals here (ignore=0, remove=1, all=2) do not match the named
// constants T_ALL=0, T_IG=1, T_RM=2 nor the TM entries "all"->0, "ignore"->1, "rm"->2.
// Confirm which mapping is intended before changing either side.
private boolean isIgnore() { return _na_handle == 0; }
private boolean isRemove() { return _na_handle == 1; }
private boolean isAll() { return _na_handle == 2; }
// satisfy the extends
/** Not executable: always throws the result of {@code H2O.fail()}. */
@Override void exec(Env e) { throw H2O.fail();}
/** @return the literal {@code "agg"}. */
@Override String value() { return "agg"; }
/** @return always 0. */
@Override int type() { return 0; }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy