All downloads are free. The search and download features use the official Maven repository.

hex.tree.xgboost.task.XGBoostSetupTask Maven / Gradle / Ivy

package hex.tree.xgboost.task;

import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.matrix.MatrixLoader;
import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.XGBoostError;
import org.apache.log4j.Logger;
import water.H2O;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMapGeneric;

import java.io.File;
import java.util.Map;

/**
 * Initializes XGBoost training (converts Frame to set of node-local DMatrices).
 *
 * <p>When executed, the task builds the local training matrix (and the validation matrix, if the
 * loader reports a validation frame), optionally saves both matrices into a directory, and then
 * creates and starts an {@link XGBoostUpdater} for the model.
 */
public class XGBoostSetupTask extends AbstractXGBoostTask {

  private static final Logger LOG = Logger.getLogger(XGBoostSetupTask.class);

  private final BoosterParms _boosterParms;
  private final byte[] _checkpoint;
  private final IcedHashMapGeneric.IcedHashMapStringString _rabitEnv;
  private final MatrixLoader _matrixLoader;
  private final String _saveMatrixDirectory;

  /**
   * @param modelKey            key of the model, passed on to the parent task and to the updater
   * @param saveMatrixDirectory directory to save the node-local matrices into; {@code null} disables saving
   * @param boosterParms        booster parameters handed to the updater
   * @param checkpointToResume  checkpoint bytes handed to the updater
   * @param rabitEnv            Rabit environment entries; copied, the caller's map is not modified
   * @param nodes               node flags passed on to the parent task
   * @param matrixLoader        loader used to build the node-local matrices
   */
  public XGBoostSetupTask(
      Key modelKey, String saveMatrixDirectory, BoosterParms boosterParms,
      byte[] checkpointToResume, Map<String, String> rabitEnv, boolean[] nodes,
      MatrixLoader matrixLoader
  ) {
    super(modelKey, nodes);
    _boosterParms = boosterParms;
    _checkpoint = checkpointToResume;
    _matrixLoader = matrixLoader;
    _saveMatrixDirectory = saveMatrixDirectory;
    // Keep a private copy: execute() adds a node-specific entry to this map.
    _rabitEnv = new IcedHashMapGeneric.IcedHashMapStringString();
    _rabitEnv.putAll(rabitEnv);
  }

  @Override
  protected void execute() {
    final DMatrix trainMatrix = makeTrainMatrix();
    final DMatrix validMatrix = makeValidMatrix(trainMatrix);
    if (_saveMatrixDirectory != null) {
      saveMatrices(trainMatrix, validMatrix);
    }
    _rabitEnv.put("DMLC_TASK_ID", String.valueOf(H2O.SELF.index()));

    XGBoostUpdater thread = XGBoostUpdater.make(_modelKey, trainMatrix, validMatrix, _boosterParms, _checkpoint, _rabitEnv);
    thread.start(); // we do not need to wait for the Updater to init Rabit - subsequent tasks will wait
  }

  /** Builds the node-local training matrix. */
  private DMatrix makeTrainMatrix() {
    try {
      return _matrixLoader.makeLocalTrainMatrix().get();
    } catch (XGBoostError e) {
      throw new IllegalStateException("Failed to create XGBoost DMatrix for training dataset", e);
    }
  }

  /**
   * Builds the node-local validation matrix, or returns {@code null} when the loader has no
   * validation frame. On failure the already created training matrix is disposed before the
   * exception is rethrown, because nothing else would release it.
   */
  private DMatrix makeValidMatrix(DMatrix trainMatrix) {
    if (!_matrixLoader.hasValidationFrame()) {
      return null;
    }
    try {
      return _matrixLoader.makeLocalValidMatrix().get();
    } catch (XGBoostError e) {
      if (trainMatrix != null) {
        trainMatrix.dispose();
      }
      throw new IllegalStateException("Failed to create XGBoost DMatrix for validation dataset", e);
    }
  }

  /** Saves the node-local matrices into {@code _saveMatrixDirectory}, one file per matrix and node. */
  private void saveMatrices(DMatrix trainMatrix, DMatrix validMatrix) {
    File directory = new File(_saveMatrixDirectory);
    if (directory.mkdirs()) {
      LOG.debug("Created directory for matrix export: " + directory.getAbsolutePath());
    } else if (!directory.isDirectory()) {
      // mkdirs() also returns false when the directory already exists; only warn on a real failure.
      LOG.warn("Directory for matrix export could not be created: " + directory.getAbsolutePath());
    }
    File trainPath = new File(directory, "train_matrix.part" + H2O.SELF.index());
    LOG.info("Saving node-local portion of XGBoost training dataset to " + trainPath.getAbsolutePath() + ".");
    trainMatrix.saveBinary(trainPath.getAbsolutePath());
    if (validMatrix != null) {
      File validPath = new File(directory, "valid_matrix.part" + H2O.SELF.index());
      LOG.info("Saving node-local portion of XGBoost validation dataset to " + validPath.getAbsolutePath() + ".");
      validMatrix.saveBinary(validPath.getAbsolutePath());
    }
  }

  /**
   * Finds what nodes actually do carry some of data of a given Frame.
   *
   * <p>A node is marked when it is the home node of at least one chunk of {@code fr.anyVec()}.
   *
   * @param fr frame to find nodes for
   * @return FrameNodes holding the frame and one flag per node of the cloud
   */
  public static FrameNodes findFrameNodes(Frame fr) {
    // Mark every node that is home to at least one chunk
    boolean[] nodesHoldingFrame = new boolean[H2O.CLOUD.size()];
    Vec vec = fr.anyVec();
    for (int chunkNr = 0; chunkNr < vec.nChunks(); chunkNr++) {
      int home = vec.chunkKey(chunkNr).home_node().index();
      nodesHoldingFrame[home] = true;
    }
    return new FrameNodes(fr, nodesHoldingFrame);
  }

  /**
   * A frame together with per-node flags telling which nodes hold some of its data.
   */
  public static class FrameNodes {
    public final Frame _fr;
    /** One flag per node index; {@code true} when the node holds data of the frame. */
    public final boolean[] _nodes;
    /** Number of {@code true} flags in {@link #_nodes}, counted at construction time. */
    public final int _numNodes;

    private FrameNodes(Frame fr, boolean[] nodes) {
      _fr = fr;
      _nodes = nodes;
      int n = 0;
      for (boolean f : _nodes) {
        if (f) {
          n++;
        }
      }
      _numNodes = n;
    }

    public int getNumNodes() { return _numNodes; }

    /**
     * @param otherNodes node flags to compare against
     * @return {@code true} when every node flagged here is also flagged in {@code otherNodes}
     */
    public boolean isSubsetOf(FrameNodes otherNodes) {
      for (int i = 0; i < _nodes.length; i++) {
        if (_nodes[i] && !otherNodes._nodes[i]) {
          return false;
        }
      }
      return true;
    }

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy