All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.xgboost.XGBoostExtension Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.xgboost;

import ai.h2o.xgboost4j.java.INativeLibLoader;
import ai.h2o.xgboost4j.java.NativeLibLoader;
import hex.tree.xgboost.util.NativeLibrary;
import hex.tree.xgboost.util.NativeLibraryLoaderChain;
import org.apache.log4j.Logger;
import water.AbstractH2OExtension;
import water.H2O;

import java.io.IOException;
import java.util.Arrays;

/**
 * XGBoost Extension
 *
 * This is responsible for early initialization of
 * XGBoost per cluster node. The registration
 * of XGBoost REST API requires this extension
 * to be enabled.
 *
 * Initialization is performed lazily on the first call to {@link #isEnabled()}
 * and is guarded so that the native-library lookup runs at most once per instance.
 */
public class XGBoostExtension extends AbstractH2OExtension {

  private static final Logger LOG = Logger.getLogger(XGBoostExtension.class);

  private static final String XGBOOST_MIN_REQUIREMENTS =
          "Xgboost (enabled GPUs) needs: \n"
                  + "  - CUDA 8.0\n"
                  + "XGboost (minimal version) needs: \n"
                  + "  - GCC 4.7+\n"
                  + "Apple silicon is not supported yet\n"
                  + "For more details, run in debug mode: `java -Dlog4j.configuration=file:///tmp/log4j.properties -jar h2o.jar`\n";

  // XGBoost initialization sequence was called flag.
  // Must be volatile: it is read outside the lock in isEnabled() (double-checked locking),
  // and the volatile write publishes isXgboostPresent and nativeLibInfo to other threads.
  private volatile boolean isInitCalled = false;
  // XGBoost binary presence on the system; written before isInitCalled is set to true
  // and read only after isInitCalled was observed as true.
  private boolean isXgboostPresent = false;
  // Description of the loaded native library (null when none was loaded).
  // Volatile because logNativeLibInfo() reads it without going through isEnabled().
  private volatile NativeLibInfo nativeLibInfo = null;

  // Kept non-final to stay source/binary compatible with existing users of this field.
  public static String NAME = "XGBoost";

  /**
   * @return the name of this extension, i.e. the current value of {@link #NAME}
   */
  @Override
  public String getExtensionName() {
    return NAME;
  }

  /**
   * Tells whether the XGBoost backend is usable on this node.
   *
   * The first call runs the initialization sequence ({@code initXgboost()});
   * subsequent calls return the cached result.
   *
   * @return {@code true} if the native library description could be obtained, {@code false} otherwise
   */
  @Override
  public boolean isEnabled() {
    // Check if some native library was loaded
    if (!isInitCalled) {
      synchronized (this) {
        if (!isInitCalled) {
          isXgboostPresent = initXgboost();
          // Volatile write goes last so that everything written above is visible
          // to any thread that subsequently observes isInitCalled == true.
          isInitCalled = true;
        }
      }
    }
    return isXgboostPresent;
  }

  /**
   * Logs the name and compilation flags of the native library found during initialization,
   * or a warning if no library information is available.
   */
  public void logNativeLibInfo() {
    // Read the volatile field once so the null check and the uses see the same object.
    final NativeLibInfo info = nativeLibInfo;
    if (info == null) {
      LOG.warn("No native XGBoost library found.");
      return;
    }
    LOG.info("Found XGBoost backend with library: " + info.name);
    if (info.flags.length == 0) {
      LOG.warn("Your system supports only minimal version of XGBoost (no GPUs, no multithreading)!");
    } else {
      LOG.info("XGBoost supported backends: " + Arrays.toString(info.flags));
    }
  }

  /**
   * Obtains the native library loader and narrows it to {@link NativeLibraryLoaderChain}.
   *
   * @return the loader chain, or {@code null} (after logging a warning) if the loader returned by
   *         {@code NativeLibLoader.getLoader()} is not a {@link NativeLibraryLoaderChain}
   * @throws IOException declared by this method; presumably propagated from
   *         {@code NativeLibLoader.getLoader()} - confirm against that API
   */
  public static NativeLibraryLoaderChain getLoader() throws IOException {
    INativeLibLoader loader = NativeLibLoader.getLoader();
    if (!(loader instanceof NativeLibraryLoaderChain)) {
      LOG.warn("Unexpected XGBoost library loader found. Custom loaders are not supported in this version. " +
              "XGBoost extension will be disabled.");
      return null;
    }
    return (NativeLibraryLoaderChain) loader;
  }

  /**
   * When the extension is enabled and a positive {@code off_heap_memory_ratio} is configured,
   * runs the memory check and logs a warning if the off-heap requirement is not met.
   */
  @Override
  public void onLocalNodeStarted() {
    if (!isEnabled()) {
      return;
    }
    final double ratio = H2O.ARGS.off_heap_memory_ratio;
    // Use the captured value consistently for both the guard and the check itself.
    if (ratio > 0) {
      MemoryCheck.Report report = MemoryCheck.runCheck(ratio);
      if (!report.isOffHeapRequirementMet()) {
        LOG.warn("There doesn't seem to be enough memory available for XGBoost model training (off_heap_memory_ratio=" + ratio + "), " +
                "training XGBoost models is not advised. Details: " + report);
      }
    }
  }

  /**
   * Runs the initialization sequence: obtains the loader chain, asks it for the loaded library
   * and records its description in {@code nativeLibInfo}.
   *
   * Called only from {@link #isEnabled()} while holding the instance lock.
   *
   * @return {@code true} if the library description was recorded; {@code false} if no supported
   *         loader is available or an {@link IOException} was thrown
   */
  private boolean initXgboost() {
    try {
      NativeLibraryLoaderChain chainLoader = getLoader();
      if (chainLoader == null) {
        return false;
      }
      NativeLibrary lib = chainLoader.getLoadedLibrary();
      nativeLibInfo = new NativeLibInfo(lib);
      return true;
    } catch (IOException e) {
      // Ups no lib loaded or load failed
      LOG.debug("Cause of the xgboost unsuccessful load", e);
      LOG.warn("Cannot initialize XGBoost backend! " + XGBOOST_MIN_REQUIREMENTS);
      return false;
    }
  }

  /**
   * Immutable snapshot of the name and compilation flags of a {@link NativeLibrary}.
   */
  private static class NativeLibInfo {
    final String name;
    final NativeLibrary.CompilationFlags[] flags;

    private NativeLibInfo(NativeLibrary nl) {
      name = nl.getName();
      flags = nl.getCompilationFlags();
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy