All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.java.DataBatch Maven / Gradle / Ivy

There is a newer version: 0.6.0-incubating
Show newest version
package ml.dmlc.xgboost4j.java;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

import ml.dmlc.xgboost4j.LabeledPoint;

/**
 * A mini-batch of data that can be converted to DMatrix.
 * The data is in sparse matrix CSR format.
 *
 * This class is used to support advanced creation of DMatrix from an Iterator of DataBatch.
 */
class DataBatch {
  /** The offset of each row in the sparse matrix; holds one more entry than there are rows. */
  final long[] rowOffset;
  /** weight of each data point, can be null */
  final float[] weight;
  /** label of each data point, can be null */
  final float[] label;
  /** index of each feature(column) in the sparse matrix */
  final int[] featureIndex;
  /** value of each non-missing entry in the sparse matrix */
  final float[] featureValue;

  /**
   * Creates a batch that wraps the given arrays. The arrays are stored as-is (not copied).
   *
   * @param rowOffset    offset of each row in the sparse matrix
   * @param weight       weight of each data point, can be null
   * @param label        label of each data point, can be null
   * @param featureIndex index of each feature(column) in the sparse matrix
   * @param featureValue value of each non-missing entry in the sparse matrix
   */
  DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
            float[] featureValue) {
    this.rowOffset = rowOffset;
    this.weight = weight;
    this.label = label;
    this.featureIndex = featureIndex;
    this.featureValue = featureValue;
  }

  /**
   * Groups the labeled points of an underlying iterator into {@link DataBatch}es of at most
   * {@code batchSize} rows each. Removal is not supported.
   */
  static class BatchIterator implements Iterator<DataBatch> {
    private final Iterator<LabeledPoint> base;
    private final int batchSize;

    /**
     * @param base      source of labeled points; consumed lazily, one batch per {@link #next()}
     * @param batchSize maximum number of rows per batch, must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    BatchIterator(Iterator<LabeledPoint> base, int batchSize) {
      // A non-positive batch size could never consume the base iterator: hasNext() would stay
      // true while next() kept producing empty batches, so fail fast instead.
      if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
      }
      this.base = base;
      this.batchSize = batchSize;
    }

    /** @return true while the underlying iterator still has labeled points to batch */
    @Override
    public boolean hasNext() {
      return base.hasNext();
    }

    /**
     * Pulls up to {@code batchSize} labeled points from the underlying iterator and packs
     * them into a single CSR-formatted batch.
     *
     * @return the next batch, containing at least one row
     * @throws NoSuchElementException if the underlying iterator is exhausted
     */
    @Override
    public DataBatch next() {
      if (!base.hasNext()) {
        throw new NoSuchElementException("no more labeled points to batch");
      }

      // First pass: collect the rows and count the entries so the arrays can be sized exactly.
      int numElem = 0;
      List<LabeledPoint> batch = new ArrayList<>(batchSize);
      while (base.hasNext() && batch.size() < batchSize) {
        LabeledPoint labeledPoint = base.next();
        batch.add(labeledPoint);
        numElem += labeledPoint.values().length;
      }
      int numRows = batch.size();

      long[] rowOffset = new long[numRows + 1];
      float[] label = new float[numRows];
      int[] featureIndex = new int[numElem];
      float[] featureValue = new float[numElem];
      float[] weight = new float[numRows];

      // Second pass: lay the rows out back to back in the CSR arrays.
      int offset = 0;
      for (int i = 0; i < numRows; i++) {
        LabeledPoint labeledPoint = batch.get(i);
        float[] values = labeledPoint.values();
        int[] indices = labeledPoint.indices();

        rowOffset[i] = offset;
        label[i] = labeledPoint.label();
        weight[i] = labeledPoint.weight();
        if (indices != null) {
          // NOTE(review): assumes indices and values have the same length - confirm
          // against LabeledPoint's contract.
          System.arraycopy(indices, 0, featureIndex, offset, indices.length);
        } else {
          // No explicit indices: the j-th value belongs to column j.
          for (int j = 0; j < values.length; j++) {
            featureIndex[offset + j] = j;
          }
        }

        System.arraycopy(values, 0, featureValue, offset, values.length);
        offset += values.length;
      }

      // Closing sentinel: the end offset of the last row.
      rowOffset[numRows] = offset;
      return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
      throw new UnsupportedOperationException("DataBatch.BatchIterator.remove");
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy