All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.FrameTask Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex;

import water.*;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.RandomUtils;

import java.util.Arrays;
import java.util.Random;

public abstract class FrameTask> extends MRTask{
  protected boolean _sparse;
  protected transient DataInfo _dinfo;
  public DataInfo dinfo() { return _dinfo; }
  final Key _dinfoKey;
  final int [] _activeCols;
  final protected Key _jobKey;

  protected float _useFraction = 1.0f;
  private final long _seed;
  protected boolean _shuffle = false;
  private final int _iteration;

  public FrameTask(Key jobKey, DataInfo dinfo) {
    this(jobKey, dinfo, 0xDECAFBEE, -1, false);
  }
  public FrameTask(Key jobKey, DataInfo dinfo, long seed, int iteration, boolean sparse) {
    this(jobKey,dinfo==null?null:dinfo._key,dinfo==null?null:dinfo._activeCols,seed,iteration, sparse,null);
  }
  public FrameTask(Key jobKey, DataInfo dinfo, long seed, int iteration, boolean sparse, H2O.H2OCountedCompleter cmp) {
    this(jobKey,dinfo==null?null:dinfo._key,dinfo==null?null:dinfo._activeCols,seed,iteration, sparse,cmp);
  }
  private FrameTask(Key jobKey, Key dinfoKey, int [] activeCols,long seed, int iteration, boolean sparse, H2O.H2OCountedCompleter cmp) {
    super(cmp);
    _jobKey = jobKey;
    _dinfoKey = dinfoKey;
    _activeCols = activeCols;
    _seed = seed;
    _iteration = iteration;
    _sparse = sparse;
  }
  @Override protected void setupLocal(){
    DataInfo dinfo = DKV.get(_dinfoKey).get();
    _dinfo = _activeCols == null?dinfo:dinfo.filterExpandedColumns(_activeCols);
  }
  @Override protected void closeLocal(){ _dinfo = null;}

  /**
   * Method to process one row of the data. See for separate mini-batch logic below.
   * Numeric and categorical values are passed separately, as is response.
   * Categoricals are passed as absolute indexes into the expanded beta vector, 0-levels are skipped
   * (so the number of passed categoricals will not be the same for every row).
   *
   * Categorical expansion/indexing:
   *   Categoricals are placed in the beginning of the beta vector.
   *   Each cat variable with n levels is expanded into n-1 independent binary variables.
   *   Indexes in cats[] will point to the appropriate coefficient in the beta vector, so e.g.
   *   assume we have 2 categorical columns both with values A,B,C, then the following rows will have following indexes:
   *      A,A - ncats = 0, we do not pass any categorical here
   *      A,B - ncats = 1, indexes = [2]
   *      B,B - ncats = 2, indexes = [0,2]
   *      and so on
   *
   * @param gid      - global id of this row, in [0,_adaptedFrame.numRows())
   */
  protected void processRow(long gid, DataInfo.Row r){throw new RuntimeException("should've been overridden!");}
  protected void processRow(long gid, DataInfo.Row r, NewChunk [] outputs){throw new RuntimeException("should've been overridden!");}

  // mini-batch version - for DL only for now
  protected void processRow(long gid, DataInfo.Row r, int mb){throw new RuntimeException("should've been overridden!");}

  protected boolean skipRow(long gid) { return false; }

  /**
   * Mini-Batch update of model parameters
   * @param seed
   * @param responses
   * @param offsets
   * @param n actual number of rows in this minibatch
   */
  protected void processMiniBatch(long seed, double[] responses, double[] offsets, int n){}

  /**
   * Note: If this is overridden, then applyMiniBatch must be overridden as well to perform the model/weight mini-batch update
   * @return Return the mini-batch size
   */
  protected int getMiniBatchSize(){ return 0; }

  /**
   * Override this to initialize at the beginning of chunk processing.
   * @return whether or not to process this chunk
   */
  protected boolean chunkInit(){ return true; }
  /**
   * Override this to do post-chunk processing work.
   * @param n Number of processed rows
   */
  protected void chunkDone(long n){}

  /**
   * Extracts the values, applies regularization to numerics, adds appropriate offsets to categoricals,
   * and adapts response according to the CaseMode/CaseValue if set.
   */
  @Override public void map(Chunk [] chunks, NewChunk [] outputs) {
    if(_jobKey != null && _jobKey.get() != null && _jobKey.get().stop_requested()) throw new Job.JobCancelledException();
    final int nrows = chunks[0]._len;
    final long offset = chunks[0].start();
    boolean doWork = chunkInit();
    if (!doWork) return;
    final boolean obs_weights = _dinfo._weights
            && !_fr.vecs()[_dinfo.weightChunkId()].isConst() //if all constant weights (such as 1) -> doesn't count as obs weights
            && !(_fr.vecs()[_dinfo.weightChunkId()].isBinary()); //special case for cross-val      -> doesn't count as obs weights
    final double global_weight_sum = obs_weights ? Math.round(_fr.vecs()[_dinfo.weightChunkId()].mean() * _fr.numRows()) : 0;

    DataInfo.Row row = null;
    DataInfo.Row[] rows = null;
    if (_sparse)
      rows = _dinfo.extractSparseRows(chunks);
    else
      row = _dinfo.newDenseRow();
    double[] weight_map = null;
    double relative_chunk_weight = 1;
    //TODO: store node-local helper arrays in _dinfo -> avoid re-allocation and construction
    if (obs_weights) {
      weight_map = new double[nrows];
      double weight_sum = 0;
      for (int i = 0; i < nrows; ++i) {
        row = _sparse ? rows[i] : _dinfo.extractDenseRow(chunks, i, row);
        weight_sum += row.weight;
        weight_map[i] = weight_sum;
        assert (i == 0 || row.weight == 0 || weight_map[i] > weight_map[i - 1]);
      }
      if (weight_sum > 0) {
        ArrayUtils.div(weight_map, weight_sum); //normalize to 0...1
        relative_chunk_weight = global_weight_sum * nrows / _fr.numRows() / weight_sum;
      } else return; //nothing to do here - all rows have 0 weight
    }

    //Example:
    // _useFraction = 0.8 -> 1 repeat with fraction = 0.8
    // _useFraction = 1.0 -> 1 repeat with fraction = 1.0
    // _useFraction = 1.1 -> 2 repeats with fraction = 0.55
    // _useFraction = 2.1 -> 3 repeats with fraction = 0.7
    // _useFraction = 3.0 -> 3 repeats with fraction = 1.0
    final int repeats = (int) Math.ceil(_useFraction * relative_chunk_weight);
    final float fraction = (float) (_useFraction * relative_chunk_weight) / repeats;
    assert (fraction <= 1.0);

    final boolean sample = (fraction < 0.999 || obs_weights || _shuffle);
    final long chunkSeed = (0x8734093502429734L + _seed + offset) * (_iteration + 0x9823423497823423L);
    final Random skip_rng = sample ? RandomUtils.getRNG(chunkSeed) : null;
    int[] shufIdx = skip_rng == null ? null : new int[nrows];
    if (skip_rng != null) {
      for (int i = 0; i < nrows; ++i) shufIdx[i] = i;
      ArrayUtils.shuffleArray(shufIdx, skip_rng);
    }

    double[] responses = new double[getMiniBatchSize()];
    double[] offsets   = new double[getMiniBatchSize()];

    long seed = 0;
    final int miniBatchSize = getMiniBatchSize();
    long num_processed_rows = 0;
    long num_skipped_rows = 0;
    int miniBatchCounter = 0;
    for(int rep = 0; rep < repeats; ++rep) {
      for(int row_idx = 0; row_idx < nrows; ++row_idx){
        int r = sample ? -1 : 0;
        // only train with a given number of training samples (fraction*nrows)
        if (sample && !obs_weights && skip_rng.nextDouble() > fraction) continue;
        if (obs_weights && num_processed_rows % 2 == 0) { //every second row is randomly sampled -> that way we won't "forget" rare rows
          // importance sampling based on inverse of cumulative distribution
          double key = skip_rng.nextDouble();
          r = Arrays.binarySearch(weight_map, 0, nrows, key);
//          Log.info(Arrays.toString(weight_map));
//          Log.info("key: " + key + " idx: " + (r >= 0 ? r : (-r-1)));
          if (r<0) r=-r-1;
          assert(r == 0 || weight_map[r] > weight_map[r-1]);
        } else if (r == -1){
          r = shufIdx[row_idx];
          // if we have weights, and we did the %2 skipping above, then we need to find an alternate row with non-zero weight
          while (obs_weights && ((r == 0 && weight_map[r] == 0) || (r > 0 && weight_map[r] == weight_map[r-1]))) {
            r = skip_rng.nextInt(nrows); //random sampling with replacement
          }
        } else {
          assert(!obs_weights);
          r = row_idx; //linear scan - slightly faster
        }
        assert(r >= 0 && r<=nrows);

        seed = offset + rep * nrows + r;
        if (skipRow(seed)) {
          num_skipped_rows++;
          continue;
        }
        row = _sparse ? rows[r] : _dinfo.extractDenseRow(chunks, r, row);
        if(row.isBad() || row.weight == 0) {
          num_skipped_rows++;
          continue;
        } else {
          assert(row.weight > 0); //check that we never process a row that was held out via row.weight = 0
          if (outputs != null && outputs.length > 0) {
            assert(miniBatchSize==0);
            processRow(seed, row, outputs);
          }
          else {
            if (miniBatchSize > 0) { //DL
              processRow(seed, row, miniBatchCounter);
              responses[miniBatchCounter] = row.response != null && row.response.length > 0 ? row.response(0) : 0 /*autoencoder dummy*/;
              offsets[miniBatchCounter] = row.offset;
              miniBatchCounter++;
            }
            else //all other algos
              processRow(seed, row);
          }
        }
        num_processed_rows++;
        if (miniBatchCounter > 0 && miniBatchCounter % miniBatchSize == 0) {
          processMiniBatch(seed, responses, offsets, miniBatchCounter);
          miniBatchCounter = 0;
        }
      }
    }
    if (miniBatchCounter>0) {
      processMiniBatch(seed, responses, offsets, miniBatchCounter); //last bit
    }

    assert(fraction != 1 || num_processed_rows + num_skipped_rows == repeats * nrows);
    chunkDone(num_processed_rows);
  }

  public static class ExtractDenseRow extends MRTask {
    final private DataInfo _di; //INPUT
    final private long _gid; //INPUT
    public DataInfo.Row _row; //OUTPUT
    public ExtractDenseRow(DataInfo di, long globalRowId) { _di = di;  _gid = globalRowId; }

    @Override
    public void map(Chunk[] cs) {
      // fill up _row with the data of row with global id _gid
      long start = cs[0].start();
      if (start <= _gid && cs[0].start()+cs[0].len() > _gid) {
        _row = _di.newDenseRow();
        _di.extractDenseRow(cs, (int)(_gid-cs[0].start()), _row);
      }
    }

    @Override
    public void reduce(ExtractDenseRow mrt) {
      if (mrt._row != null) {
        assert(this._row == null); //only one thread actually filled the output _row
        _row = mrt._row;
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy