water.rapids.ASTGroupedPermute

package water.rapids;

import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.IcedHashMap;
import water.util.Log;

import java.util.HashMap;


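/**
 * Rapids primitive {@code grouped_permute}: rows are grouped by the first group-by column and,
 * within each group, split into two buckets by the permuteBy column (categorical level "D"
 * versus everything else).  Every pairing of a bucket-0 row with a bucket-1 row is emitted,
 * keyed by the permCol values and carrying the (summed) keepCol amounts, giving output rows of
 * the form {group key, In, Out, InAmnt, OutAmnt}.
 *
 * An invocation presumably looks like {@code (grouped_permute fr 2 [0] 3 4)}; the column
 * indices in this example are purely illustrative.
 */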
public class ASTGroupedPermute extends ASTPrim {
  //   .newExpr("grouped_permute", fr, permCol, permByCol, groupByCols, keepCol)
  @Override public String[] args() { return new String[]{"ary", "permCol", "groupBy", "permuteBy", "keepCol"}; }  // currently only allow 2 items in permuteBy
  @Override int nargs() { return 1 + 5; } // (grouped_permute ary permCol groupBy permuteBy keepCol)
  @Override public String str() { return "grouped_permute"; }
  @Override
  public Val apply(Env env, Env.StackHelp stk, AST asts[]) {
    Frame fr = stk.track(asts[1].exec(env)).getFrame();
    final int permCol = (int) asts[2].exec(env).getNum();
    ASTNumList groupby = ASTGroup.check(fr.numCols(), asts[3]);
    final int[] gbCols = groupby.expand4();
    final int permuteBy = (int) asts[4].exec(env).getNum();
    final int keepCol = (int) asts[5].exec(env).getNum();

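    // Output schema: the group-by columns followed by the permuted pair and its amounts.
    // "In"/"Out" reuse permCol's domain; "InAmnt"/"OutAmnt" reuse keepCol's domain.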
    String[] names = new String[gbCols.length + 4];
    int i = 0;
    for (; i < gbCols.length; ++i)
      names[i] = fr.name(gbCols[i]);
    names[i++] = "In";
    names[i++] = "Out";
    names[i++] = "InAmnt";
    names[i] = "OutAmnt";

    String[][] domains = new String[names.length][];
    int d = 0;
    for (; d < gbCols.length; d++)
      domains[d] = fr.domains()[gbCols[d]];
    domains[d++] = fr.domains()[permCol];
    domains[d++] = fr.domains()[permCol];
    domains[d++] = fr.domains()[keepCol];
    domains[d] = fr.domains()[keepCol];
    long s = System.currentTimeMillis();
    BuildGroups t = new BuildGroups(gbCols, permuteBy, permCol, keepCol).doAll(fr);
    Log.info("Elapsed time: " + (System.currentTimeMillis() - s) / 1000. + "s");
    s = System.currentTimeMillis();
    SmashGroups sg;
    H2O.submitTask(sg=new SmashGroups(t._grps)).join();
    Log.info("Elapsed time: " + (System.currentTimeMillis() - s) / 1000. + "s");
    return new ValFrame(buildOutput(sg._res.values().toArray(new double[0][][]), names, domains));
  }

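  // Flatten the per-group permutation arrays into a Frame: a sequence Vec supplies one row per
  // group, and the map() below expands each group index into all of that group's output rows.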
  private static Frame buildOutput(final double[][][] a, String[] names, String[][] domains) {
    Frame dVec = new Frame(Vec.makeSeq(0, a.length));
    long s = System.currentTimeMillis();
    Frame res = new MRTask() {
      @Override public void map(Chunk[] cs, NewChunk[] ncs) {
        for(int i=0;i<cs[0]._len;++i) {
          double[][] rows = a[(int)cs[0].at8(i)];  // all permutation rows for this group
          for(double[] row: rows)
            for(int c=0;c<ncs.length;++c)
              ncs[c].addNum(row[c]);
        }
      }
    }.doAll(names.length, dVec).outputFrame(names, domains);
    Log.info("Elapsed time: " + (System.currentTimeMillis() - s) / 1000. + "s");
    dVec.delete();
    return res;
  }

  private static class BuildGroups extends MRTask<BuildGroups> {
    IcedHashMap<Long, IcedHashMap<Long, double[]>[]> _grps;  // shared per node (all grps with permutations atomically inserted)

    private final int _gbCols[];
    private final int _permuteBy;
    private final int _permuteCol;
    private final int _amntCol;

    BuildGroups(int[] gbCols, int permuteBy, int permuteCol, int amntCol) {
      _gbCols = gbCols;
      _permuteBy = permuteBy;
      _permuteCol = permuteCol;
      _amntCol = amntCol;
    }

    @Override public void setupLocal() { _grps = new IcedHashMap<>(); }
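    // Per-chunk pass: key each row by its group id (first group-by column) and file it into one
    // of two maps keyed by the permCol row id: index 0 for rows whose permuteBy level is "D",
    // index 1 for everything else.  A repeated row id accumulates its keepCol amount.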
    @Override public void map(Chunk[] chks) {
      String[] dom = chks[_permuteBy].vec().domain();
      IcedHashMap<Long, IcedHashMap<Long, double[]>[]> grps = new IcedHashMap<>();
      for (int row = 0; row < chks[0]._len; ++row) {
        long jid = chks[_gbCols[0]].at8(row);
        long rid = chks[_permuteCol].at8(row);
        double[] aci = new double[]{rid,  chks[_amntCol].atd(row)};
        int type = dom[(int) chks[_permuteBy].at8(row)].equals("D") ? 0 : 1;
        if( grps.containsKey(jid) ) {
          IcedHashMap<Long, double[]>[] dcWork = grps.get(jid);
          if(dcWork[type].putIfAbsent(rid, aci)!=null)
            dcWork[type].get(rid)[1] += aci[1];
        } else {
          IcedHashMap<Long, double[]>[] dcAcnts = new IcedHashMap[2];
          dcAcnts[0] = new IcedHashMap<>();
          dcAcnts[1] = new IcedHashMap<>();
          dcAcnts[type].put(rid, aci);
          grps.put(jid,dcAcnts);
        }
      }
      reduce(grps);
    }

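    // Merge another task's group map into the node-shared map, summing the amounts of row ids
    // that appear on both sides of a group.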
    @Override public void reduce(BuildGroups t) { if (_grps != t._grps) reduce(t._grps); }
    private void reduce(IcedHashMap<Long, IcedHashMap<Long, double[]>[]> r) {
      for (Long l : r.keySet()) {
        if (_grps.putIfAbsent(l, r.get(l)) != null) {
          IcedHashMap<Long, double[]>[] rdbls = r.get(l);
          IcedHashMap<Long, double[]>[] ldbls = _grps.get(l);

          for(Long rr: rdbls[0].keySet())
            if( ldbls[0].putIfAbsent(rr, rdbls[0].get(rr))!=null)
              ldbls[0].get(rr)[1]+=rdbls[0].get(rr)[1];

          for(Long rr: rdbls[1].keySet())
            if( ldbls[1].putIfAbsent(rr, rdbls[1].get(rr))!=null)
              ldbls[1].get(rr)[1]+=rdbls[1].get(rr)[1];
        }
      }
    }
  }

  private static class SmashGroups extends H2O.H2OCountedCompleter<SmashGroups> {
    private final IcedHashMap<Long, IcedHashMap<Long, double[]>[]> _grps;
    private final HashMap<Integer, Long> _map;
    private int _hi;
    private int _lo;

    SmashGroups _left;
    SmashGroups _rite;

    private IcedHashMap<Long, double[][]> _res;

    SmashGroups(IcedHashMap<Long, IcedHashMap<Long, double[]>[]> grps) {
      _grps = grps;
      _lo = 0;
      _hi = _grps.size();
      _res = new IcedHashMap<>();
      _map = new HashMap<>();
      int i=0;
      for(Long l: _grps.keySet())
        _map.put(i++,l);
    }

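    // Fork/join divide-and-conquer: split the [_lo, _hi) range of group indices until a single
    // group remains, then smash() it into its permutation rows.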
    @Override public void compute2() {
      assert _left == null && _rite == null;
      if ((_hi - _lo) >= 2) { // divide/conquer down to 1 IHM
        final int mid = (_lo + _hi) >>> 1; // Mid-point
        _left = copyAndInit();
        _rite = copyAndInit();
        _left._hi = mid;          // Reset mid-point
        _rite._lo = mid;          // Also set self mid-point
        addToPendingCount(1);     // One fork awaiting completion
        _left.fork();             // Runs in another thread/FJ instance
        _rite.compute2();         // Runs in THIS F/J thread
        return;
      }
      if( _hi > _lo ) {
        smash();
      }
      tryComplete();
    }

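    // Cross the "D" bucket with the non-"D" bucket of one group: every (In, Out) combination
    // becomes one output row {groupKey, In, Out, InAmnt, OutAmnt}.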
    private void smash() {
      long key = _map.get(_lo);
      IcedHashMap[] pair = _grps.get(key);
      double[][] res = new double[pair[0].size() * pair[1].size()][]; // all combos
      int d0=0;
      for( double[] ds0: pair[0].values()) {
        for(double[] ds1: pair[1].values())
          res[d0++] = new double[]{key, ds0[0], ds1[0], ds0[1], ds1[1]};
      }
      _res.put(key, res);
    }
    private SmashGroups copyAndInit() {
      SmashGroups x = SmashGroups.this.clone();
      x.setCompleter(this);
      x._left = x._rite = null;
      x.setPendingCount(0);
      return x;
    }
  }
}



