All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.deepwater.DeepWaterDatasetIterator Maven / Gradle / Ivy

package hex.deepwater;

import hex.DataInfo;
import water.*;
import water.fvec.C4FChunk;
import water.fvec.C8DChunk;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.UnsafeUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;


class DeepWaterDatasetIterator extends DeepWaterIterator {

  /**
   * Creates an iterator over selected rows of the frame described by {@code dinfo}.
   *
   * <p>The element types are spelled out because {@code Next} hands each entry of
   * {@code rows} to a {@code FrameDataConverter} as an {@code int} row index and each
   * entry of {@code labels} as a {@code float} label; raw lists would not type-check there.
   *
   * @param rows       row indices to iterate over, in iteration order
   * @param labels     labels aligned position-by-position with {@code rows}, or {@code null}
   *                   (in which case {@code Next} uses {@code -1} as the label)
   * @param dinfo      data layout; {@code dinfo.fullN()} is forwarded to the superclass constructor
   * @param batch_size forwarded to the superclass constructor
   * @param cache      forwarded to the superclass constructor
   * @throws IOException declared by this constructor; nothing in its own body throws it
   *                     (NOTE(review): presumably propagated from the superclass constructor, confirm)
   */
  DeepWaterDatasetIterator(ArrayList<Integer> rows, ArrayList<Float> labels, DataInfo dinfo, int batch_size, boolean cache) throws IOException {
    super(batch_size, dinfo.fullN(), cache);
    _rows_lst = rows;
    _label_lst = labels;
    _dinfo = dinfo;
  }

  /**
   * Row-based, dense storage: the values of one row held directly in a {@code float[]}.
   */
  static class IcedRow extends Iced {
    /** @return number of values stored in this row */
    int size() { return _data.length; }

    /** @return the value at position {@code i} of this row */
    float getVal(int i) { return _data[i]; }

    /**
     * Copies every value of this row into {@code vals}, starting at {@code vals[offset]}.
     *
     * @param vals   destination array; needs room for {@code size()} values from {@code offset}
     * @param offset index in {@code vals} that receives the first value of this row
     */
    void insertValuesIntoArray(float[] vals, int offset) {
      // Bulk copy instead of an element-by-element loop.
      System.arraycopy(_data, 0, vals, offset, _data.length);
    }

    /**
     * Leaves the backing array unset.
     * NOTE(review): presumably required so the Iced machinery can instantiate the class - confirm.
     */
    public IcedRow() {}

    /**
     * Wraps {@code fs} as the row values. The array is stored by reference, not copied.
     *
     * @param fs values of the row
     */
    IcedRow(float[] fs) {
      _data = fs;
    }

    /** Backing values of the row; {@code null} until assigned. */
    private float[] _data;
  }

  /**
   * Task that fills one slot of a batch: it writes the label for that slot into the
   * label buffer and the row's values into the data buffer, optionally going through
   * a per-row cache stored in the DKV.
   */
  static class FrameDataConverter extends H2O.H2OCountedCompleter {
    // Slot within the batch; selects _destLabel[_index] and the data slice starting at _index * fullN().
    int _index;
    // Row identifier used to build the cache key for this row.
    int _globalIndex;
    // Data layout; fullN() gives the number of floats per row slice.
    DataInfo _dinfo;
    // Label value written for this slot.
    float _label;
    // Destination buffer for row values (shared by all slots of the batch).
    float[] _destData;
    // Destination buffer for labels (one entry per slot).
    float[] _destLabel;
    // Whether to look up / store the converted row in the DKV.
    boolean _cache;

    /**
     * @param index       slot within the batch
     * @param globalIndex row identifier, used in the cache key
     * @param dinfo       data layout of the source frame
     * @param label       label to store for this slot
     * @param destData    buffer receiving the row values
     * @param destLabel   buffer receiving the label
     * @param cache       enable the DKV row cache
     */
    FrameDataConverter(int index, int globalIndex, DataInfo dinfo, float label, float[] destData, float[] destLabel, boolean cache) {
      _index=index;
      _globalIndex=globalIndex;
      _dinfo = dinfo;
      _label = label;
      _destData=destData;
      _destLabel=destLabel;
      _cache = cache;
    }

    /**
     * Writes the label, then fills the data slice for this slot, from the cache when
     * a cached row exists and otherwise from the frame (see the review note below).
     */
    @Override
    public void compute2() {
      _destLabel[_index] = _label;
      // First position of this slot's slice inside _destData.
      final int start=_index*_dinfo.fullN();
      // Cache key built from the frame key, the row width, the row identifier and the cache marker.
      Key rowKey = Key.make(_dinfo._adaptedFrame._key + "_" + _dinfo.fullN() + "_row_" + Integer.toString(_globalIndex) + "_" + DeepWaterModel.CACHE_MARKER);
      // Becomes true once the slice has been filled from the cache.
      boolean status = false;
      if (_cache) {
        IcedRow icedRow = DKV.getGet(rowKey);
        if (icedRow != null) {
          icedRow.insertValuesIntoArray(_destData, start);
          status = true;
        }
      }
      if (!status) { //only do this the first time
        DataInfo.Row row = _dinfo.newDenseRow();
        Chunk[] chks = new Chunk[_dinfo._adaptedFrame.numCols()];
        // NOTE(review): the next line is truncated - the text between the '<' of the loop
        // condition and a later '>' is missing, so the statements that used chks and row
        // to fill _destData[start .. start + fullN()) are lost.
        // TODO: restore this section from the upstream source before compiling.
        for (int i=0;i " + Arrays.toString(_destData));
//        System.err.println(Arrays.toString(Arrays.copyOfRange(_destData, start, start + _dinfo.fullN())));
        if (_cache) {
          // Store a copy of just this slot's slice under the row key.
          Value v = new Value(rowKey, new IcedRow(Arrays.copyOfRange(_destData, start, start + _dinfo.fullN())));
          DKV.put(rowKey, v);
          // NOTE(review): presumably drops the in-heap copy once the value has been put - confirm.
          v.freeMem();
        }
      }
      tryComplete();
    }
  }

  /**
   * Prepares the next batch.
   *
   * <p>One {@code FrameDataConverter} task is submitted per batch slot, the call blocks
   * until all of them have finished, then {@code flip()} is invoked and the start index
   * advances by one batch.
   *
   * <p>When fewer than {@code _batch_size} rows remain, the start index is moved back so
   * that the batch ends exactly at the last row.
   * NOTE(review): if the row list is shorter than {@code _batch_size}, that adjustment
   * yields a negative start index - confirm callers rule this out.
   *
   * @param fs futures collection used to track and await the submitted tasks
   * @return {@code true} if a batch was prepared, {@code false} once the start index has
   *         reached the end of the row list
   * @throws IOException declared by the signature; nothing in this body throws it directly
   */
  public boolean Next(Futures fs) throws IOException {
    final int totalRows = _rows_lst.size();
    if (_start_index >= totalRows) {
      return false;
    }

    // Not enough rows left for a full batch: slide the window back to end at the last row.
    if (_start_index + _batch_size > totalRows) {
      _start_index = totalRows - _batch_size;
    }

    // Multi-threaded data preparation: one task per slot of the batch.
    for (int slot = 0; slot < _batch_size; slot++) {
      fs.add(H2O.submitTask(new FrameDataConverter(
          slot,
          _rows_lst.get(_start_index + slot),
          _dinfo,
          _label_lst == null ? -1 : _label_lst.get(_start_index + slot),
          _data[which()],
          _label[which()],
          _cache)));
    }
    fs.blockForPending();

    flip();
    _start_index += _batch_size;
    return true;
  }

  // Row indices to iterate over; Next() passes each entry to a FrameDataConverter as its int row index.
  final private ArrayList<Integer> _rows_lst;
  // Labels aligned with _rows_lst; may be null, in which case Next() uses -1 as the label.
  final private ArrayList<Float> _label_lst;
  // Data layout of the source frame, handed to every FrameDataConverter.
  final private DataInfo _dinfo;
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy