All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.xgboost.util.BoosterHelper Maven / Gradle / Ivy

package hex.tree.xgboost.util;

import ai.h2o.xgboost4j.java.*;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

/**
 * Utility to access package private Booster methods.
 */
public class BoosterHelper {

  public static Booster loadModel(InputStream in) {
    try {
      return XGBoost.loadModel(in);
    } catch (XGBoostError | IOException e) {
      throw new IllegalStateException("Failed to load booster.", e);
    }
  }
  
  public static Booster loadModel(byte[] boosterBytes) {
    if (boosterBytes == null) {
      throw new IllegalArgumentException("Booster not initialized!");
    }
    return loadModel(new ByteArrayInputStream(boosterBytes));
  }

  /**
   * Invalidates XGBoost objects (Booster, DMatrix) and frees up their memory
   * @param xgbObjects list of XGBoost objects
   * @throws IllegalStateException when object invalidation fails, only the first exception will be reported (as the
   * exception cause), we assume the other ones will have a same reason
   */
  public static void dispose(Object... xgbObjects) throws IllegalStateException {
    Exception firstException = null;
    for (Object xgbObject : xgbObjects) {
      if (xgbObject == null)
        continue;
      if (xgbObject instanceof Booster) {
        try {
          ((Booster) xgbObject).dispose();
        } catch (Exception e) {
          if (firstException == null)
            firstException = e;
        }
      } else if (xgbObject instanceof DMatrix) {
        try {
          ((DMatrix) xgbObject).dispose();
        } catch (Exception e) {
          if (firstException == null)
            firstException = e;
        }
      } else
        assert false : "Unsupported XGBoost object type: " + xgbObject.getClass();
    }
    if (firstException != null) {
      throw new IllegalStateException("We were unable to free-up xgboost memory. " +
              "This could indicate a memory leak and it can lead to H2O instability.", firstException);
    }
  }

  public interface BoosterOp {
    X apply(Booster booster) throws XGBoostError;
  }

  public static  X doWithLocalRabit(BoosterOp op, Booster booster) throws XGBoostError {
    boolean shutdownRabit = true;
    try {
      Map rabitEnv = new HashMap<>();
      rabitEnv.put("DMLC_TASK_ID", "0");
      Rabit.init(rabitEnv);
      shutdownRabit = true;
      X result = op.apply(booster);
      Rabit.shutdown();
      shutdownRabit = false;
      return result;
    } finally {
      if (shutdownRabit)
        try {
          Rabit.shutdown();
        } catch (XGBoostError e) {
          e.printStackTrace(); // don't rely on logging support in genmodel
        }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy