All downloads are free. Search and download functionality uses the official Maven repository.

ml.dmlc.xgboost4j.java.DataBatch Maven / Gradle / Ivy

The newest version!
package ml.dmlc.xgboost4j.java;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import ml.dmlc.xgboost4j.LabeledPoint;

/**
 * A mini-batch of data that can be converted to DMatrix.
 * The data is in sparse matrix CSR format.
 *
 * This class is used to support advanced creation of DMatrix from Iterator of DataBatch.
 */
class DataBatch {
  private static final Log logger = LogFactory.getLog(DataBatch.class);
  /** The offset of each row in the sparse matrix */
  final long[] rowOffset;
  /** weight of each data point, can be null */
  final float[] weight;
  /** label of each data point, can be null */
  final float[] label;
  /** index of each feature(column) in the sparse matrix */
  final int[] featureIndex;
  /** value of each non-missing entry in the sparse matrix */
  final float[] featureValue;

  /**
   * Creates a batch from already-assembled CSR arrays. The arrays are stored as given,
   * without copying or validation.
   *
   * @param rowOffset    offset of each row in the sparse matrix
   * @param weight       weight of each data point, can be null
   * @param label        label of each data point, can be null
   * @param featureIndex column index of each stored entry
   * @param featureValue value of each stored entry
   */
  DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex,
            float[] featureValue) {
    this.rowOffset = rowOffset;
    this.weight = weight;
    this.label = label;
    this.featureIndex = featureIndex;
    this.featureValue = featureValue;
  }

  /**
   * Groups the points of an underlying {@link LabeledPoint} iterator into
   * {@link DataBatch} instances holding at most {@code batchSize} rows each.
   */
  static class BatchIterator implements Iterator<DataBatch> {
    private final Iterator<LabeledPoint> base;
    private final int batchSize;

    /**
     * @param base      source of labeled points; consumed lazily, one batch per call to
     *                  {@link #next()}
     * @param batchSize maximum number of rows placed in a single batch
     */
    BatchIterator(Iterator<LabeledPoint> base, int batchSize) {
      this.base = base;
      this.batchSize = batchSize;
    }

    /** Returns {@code true} while the underlying iterator still has points to read. */
    @Override
    public boolean hasNext() {
      return base.hasNext();
    }

    /**
     * Reads up to {@code batchSize} points from the underlying iterator and packs them
     * into a single CSR-formatted batch.
     *
     * @return the assembled batch (with zero rows if the underlying iterator is already
     *         exhausted), or {@code null} if a {@link RuntimeException} occurred while
     *         reading or packing the points; the exception is logged, not rethrown
     */
    @Override
    public DataBatch next() {
      try {
        // First pass: collect the rows and count the total number of stored entries so
        // the flat arrays can be allocated once.
        int numElem = 0;
        List<LabeledPoint> batch = new ArrayList<>(batchSize);
        while (base.hasNext() && batch.size() < batchSize) {
          LabeledPoint labeledPoint = base.next();
          batch.add(labeledPoint);
          numElem += labeledPoint.values().length;
        }
        int numRows = batch.size();

        // rowOffset needs one extra slot for the end offset of the last row.
        long[] rowOffset = new long[numRows + 1];
        float[] label = new float[numRows];
        int[] featureIndex = new int[numElem];
        float[] featureValue = new float[numElem];
        float[] weight = new float[numRows];

        // Second pass: copy each row's entries into the flat arrays.
        int offset = 0;
        for (int i = 0; i < numRows; i++) {
          LabeledPoint labeledPoint = batch.get(i);
          float[] values = labeledPoint.values();
          int[] indices = labeledPoint.indices();

          rowOffset[i] = offset;
          label[i] = labeledPoint.label();
          weight[i] = labeledPoint.weight();
          if (indices != null) {
            System.arraycopy(indices, 0, featureIndex, offset, indices.length);
          } else {
            // No explicit indices: the values are at consecutive columns 0..length-1.
            for (int j = 0; j < values.length; j++) {
              featureIndex[offset + j] = j;
            }
          }

          System.arraycopy(values, 0, featureValue, offset, values.length);
          offset += values.length;
        }

        rowOffset[numRows] = offset;
        return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
      } catch (RuntimeException runtimeError) {
        // Pass the exception as the throwable argument so the stack trace is logged,
        // not just its toString().
        logger.error("Failed to build DataBatch from LabeledPoint iterator", runtimeError);
        return null;
      }
    }

    /**
     * Removal is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
      throw new UnsupportedOperationException("DataBatch.BatchIterator.remove");
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy