All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.xgboost.util.GpuUtils Maven / Gradle / Ivy

package hex.tree.xgboost.util;

import ai.h2o.xgboost4j.java.*;
import org.apache.log4j.Logger;
import water.DTask;
import water.H2O;
import water.H2ONode;
import water.RPC;

import java.io.IOException;
import java.util.*;

/**
 * Helpers for finding out which GPU devices XGBoost can use, either on the local node or on a
 * remote {@link H2ONode} (queried through an {@link RPC}).
 *
 * <p>A GPU id is considered usable when a tiny one-round {@code gpu_hist} training run on that
 * device succeeds. Ids that passed this probe are remembered in a static cache, so each device
 * is probed successfully at most once; failed probes are not cached and are repeated on demand.
 */
public class GpuUtils {

    private static final Logger LOG = Logger.getLogger(GpuUtils.class);

    /** GPU ids that are checked when a caller passes {@code null} instead of an explicit id list. */
    public static final int[] DEFAULT_GPU_ID = new int[] { 0 };

    /** Set once a probe of {@link #DEFAULT_GPU_ID} failed; it is never reset. */
    private static volatile boolean defaultGpuIdNotValid = false;
    /** Set once {@link #allGPUs()} finished enumerating the local devices. */
    private static volatile boolean gpuSearchPerformed = false;

    /** Ids of GPUs whose probe succeeded. Only modified inside the synchronized {@link #hasGPU_impl(int)}. */
    private static final Set<Integer> GPUS = new HashSet<>();

    /**
     * Tells whether the loaded XGBoost native library was compiled with GPU support.
     *
     * @return {@code true} only if the active loader is a {@link NativeLibraryLoaderChain} and its
     *         loaded library reports the {@code WITH_GPU} compilation flag; {@code false} otherwise,
     *         including when the library lookup fails with an {@link IOException}
     */
    static boolean isGpuSupportEnabled() {
        try {
            INativeLibLoader loader = NativeLibLoader.getLoader();
            if (!(loader instanceof NativeLibraryLoaderChain))
                return false;
            NativeLibraryLoaderChain chainLoader = (NativeLibraryLoaderChain) loader;
            NativeLibrary lib = chainLoader.getLoadedLibrary();
            return lib.hasCompilationFlag(NativeLibrary.CompilationFlags.WITH_GPU);
        } catch (IOException e) {
            LOG.debug(e);
            return false;
        }
    }

    /**
     * Reads the {@code xgboost.gpu.check.enabled} switch (defaults to {@code true}).
     * When it is off, {@link #hasGPU(int[])} reports every GPU as available without probing.
     */
    private static boolean gpuCheckEnabled() {
        return H2O.getSysBoolProperty("xgboost.gpu.check.enabled", true);
    }

    /**
     * @param node node to query
     * @return number of GPUs detected on {@code node}
     */
    public static int numGPUs(H2ONode node) {
        return allGPUs(node).size();
    }

    /**
     * Returns the ids of all GPUs detected on the given node.
     *
     * @param node node to query; the local node is answered directly, any other node via RPC
     * @return set of usable GPU ids on {@code node}
     */
    public static Set<Integer> allGPUs(H2ONode node) {
        if (H2O.SELF.equals(node)) {
            return allGPUs();
        } else {
            AllGPUsTask t = new AllGPUsTask();
            new RPC<>(node, t).call().get();
            return new HashSet<>(Arrays.asList(t.gpuIds));
        }
    }

    /** Remote task that runs {@link #allGPUs()} on the target node and carries the ids back. */
    private static class AllGPUsTask extends DTask<AllGPUsTask> {
        // OUT
        private Integer[] gpuIds;

        private AllGPUsTask() {
        }

        @Override
        public void compute2() {
            gpuIds = allGPUs().toArray(new Integer[0]);
            tryComplete();
        }
    }

    /**
     * Enumerates the local GPUs by probing ids 0, 1, 2, ... until the first id that is not usable.
     * The enumeration runs once; later calls return the cached result.
     *
     * <p>The devices are always probed for real here, even when the GPU check is disabled:
     * with the check disabled {@link #hasGPU(int[])} answers {@code true} for every id, which
     * would make this loop run forever.
     *
     * @return unmodifiable view of the ids of the usable local GPUs
     */
    public static Set<Integer> allGPUs() {
        if (!gpuSearchPerformed) {
            int nextGpuId = 0;
            while (hasGPU_impl(nextGpuId)) {
                nextGpuId++;
            }
            if (nextGpuId == 0) {
                // not even the default GPU (id 0) is usable - remember it for the quick path in hasGPU(int[])
                defaultGpuIdNotValid = true;
            }
            gpuSearchPerformed = true;
        }
        return Collections.unmodifiableSet(GPUS);
    }

    /**
     * Checks whether all the given GPUs are usable on the given node and logs the outcome at debug level.
     *
     * @param node   node to query; the local node is answered directly, any other node via RPC
     * @param gpu_id GPU ids to check, {@code null} means {@link #DEFAULT_GPU_ID}
     * @return {@code true} if every requested GPU is usable on {@code node}
     */
    public static boolean hasGPU(H2ONode node, int[] gpu_id) {
        final boolean hasGPU;
        if (H2O.SELF.equals(node)) {
            hasGPU = hasGPU(gpu_id);
        } else {
            HasGPUTask t = new HasGPUTask(gpu_id);
            new RPC<>(node, t).call().get();
            hasGPU = t._hasGPU;
        }
        LOG.debug("Availability of GPU (id=" + Arrays.toString(gpu_id) + ") on node " + node + ": " + hasGPU);
        return hasGPU;
    }

    /** Remote task that runs {@link #hasGPU(int[])} on the target node and carries the answer back. */
    private static class HasGPUTask extends DTask<HasGPUTask> {
        private final int[] _gpu_id;
        // OUT
        private boolean _hasGPU;

        private HasGPUTask(int[] gpu_id) {
            _gpu_id = gpu_id;
        }

        @Override
        public void compute2() {
            _hasGPU = hasGPU(_gpu_id);
            tryComplete();
        }
    }

    /**
     * Checks whether all the given GPUs are usable on the local node.
     *
     * @param gpu_id GPU ids to check, {@code null} means {@link #DEFAULT_GPU_ID}; an empty array
     *               yields {@code true} because there is nothing to check
     * @return {@code true} if the GPU check is disabled or every requested GPU passed the probe
     */
    public static boolean hasGPU(int[] gpu_id) {
        if (!gpuCheckEnabled()) {
            return true;
        }
        // quick default path without synchronization - if we already know the default GPU
        // is missing, there is no need to find out again
        if (gpu_id == null && defaultGpuIdNotValid)
            return false;
        boolean hasGPU = true;
        if (gpu_id == null) gpu_id = DEFAULT_GPU_ID;
        // stops at the first GPU that is not usable
        for (int i = 0; hasGPU && i < gpu_id.length; i++) {
            hasGPU = hasGPU_impl(gpu_id[i]);
        }
        if (Arrays.equals(gpu_id, DEFAULT_GPU_ID) && !hasGPU) {
            defaultGpuIdNotValid = true; // this can never change back
        }
        return hasGPU;
    }

    /**
     * Checks whether the default GPU is usable on the local node.
     *
     * @return result of {@link #hasGPU(int[])} for the default GPU id
     */
    public static boolean hasGPU() {
        return hasGPU(null);
    }

    /**
     * Probes a single GPU by training one round of {@code gpu_hist} on a 2x2 matrix.
     * Successful ids are cached in {@link #GPUS}; the method is synchronized so that probes
     * and cache updates never run concurrently.
     *
     * @param gpu_id id of the GPU to probe
     * @return {@code true} if the native library has GPU support and the probe training succeeded
     * @throws IllegalStateException if the probe training matrix cannot be created
     */
    private static synchronized boolean hasGPU_impl(int gpu_id) {
        if (!isGpuSupportEnabled()) {
            return false;
        }

        if (GPUS.contains(gpu_id)) {
            return true;
        }

        DMatrix trainMat;
        try {
            trainMat = new DMatrix(new float[]{1, 2, 1, 2}, 2, 2);
            trainMat.setLabel(new float[]{1, 0});
        } catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Couldn't prepare training matrix for XGBoost.", xgBoostError);
        }

        Map<String, Object> params = new HashMap<>();
        params.put("tree_method", "gpu_hist");
        params.put("silent", 1);
        params.put("fail_on_invalid_gpu_id", true);
        params.put("gpu_id", gpu_id);

        Map<String, DMatrix> watches = new HashMap<>();
        watches.put("train", trainMat);

        Booster booster = null;
        try {
            Map<String, String> localRabitEnv = new HashMap<>();
            Rabit.init(localRabitEnv);
            booster = ai.h2o.xgboost4j.java.XGBoost.train(trainMat, params, 1, watches, null, null);
            GPUS.add(gpu_id);
            return true;
        } catch (XGBoostError xgBoostError) {
            // a failed training run is the expected signal that this GPU is not usable
            LOG.debug("GPU probe failed for gpu_id=" + gpu_id, xgBoostError);
            return false;
        } finally {
            try {
                Rabit.shutdown();
            } catch (XGBoostError e) {
                LOG.warn("Cannot shutdown XGBoost Rabit for current thread.", e);
            }
            // release the native memory of the probe objects right away instead of waiting for finalizers
            if (booster != null) {
                booster.dispose();
            }
            trainMat.dispose();
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy