All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.jcublas.context.ContextHolder Maven / Gradle / Ivy

/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.jcublas.context;

import lombok.Data;
import org.nd4j.linalg.io.ClassPathResource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;


/**
 * A multithreaded version derived
 * from the cuda launcher util
 * by the authors of jcuda.
 * 

* This class handles managing cuda contexts * across multiple devices and threads. * * @author Adam Gibson */ @Data public class ContextHolder { private Map threadNameToDeviceNumber = new ConcurrentHashMap<>(); private Map threads = new ConcurrentHashMap<>(); private List bannedDevices; private int numDevices = 0; private static ContextHolder INSTANCE; public final static String DEVICES_TO_BAN = "org.nd4j.linalg.jcuda.jcublas.ban_devices"; private static AtomicBoolean deviceSetup = new AtomicBoolean(false); private boolean confCalled = false; private static Logger log = LoggerFactory.getLogger(ContextHolder.class); private AtomicBoolean shutdown = new AtomicBoolean(false); // holder for memory strategies override /** * Singleton pattern * * @return the instance for the context holder. */ public static synchronized ContextHolder getInstance() { if (INSTANCE == null) { Properties props = new Properties(); try { props.load(new ClassPathResource("/cudafunctions.properties", ContextHolder.class.getClassLoader()).getInputStream()); } catch (IOException e) { throw new RuntimeException(e); } INSTANCE = new ContextHolder(); INSTANCE.configure(); //set the properties to be accessible globally for (String pair : props.stringPropertyNames()) System.getProperties().put(pair, props.getProperty(pair)); } return INSTANCE; } public Map getThreads() { return threads; } /** * Get the number of devices * * @return the number of devices */ public int deviceNum() { return numDevices; } /** * Configure the given information * based on the device */ public void configure() { if (confCalled) return; /* setContext(); // Check if the device supports mapped host memory cudaDeviceProp deviceProperties = new cudaDeviceProp(); JCuda.cudaGetDeviceProperties(deviceProperties, 0); if (deviceProperties.canMapHostMemory == 0) { System.err.println("This device can not map host memory"); System.err.println(deviceProperties.toFormattedString()); return; } */ /* // if we'll need stack initialization, here's the code int numberOfCores = CudaArgs.convertMPtoCores(deviceProperties.major, deviceProperties.minor, deviceProperties.multiProcessorCount) * deviceProperties.multiProcessorCount; int maxThreadsPerCore = deviceProperties.maxThreadsPerMultiProcessor / CudaArgs.convertMPtoCores(deviceProperties.major, deviceProperties.minor, deviceProperties.multiProcessorCount); long stackSize = Math.min(512*1024, deviceProperties.totalGlobalMem / numberOfCores / (maxThreadsPerCore + 8) ); JCuda.cudaDeviceSetLimit(0,stackSize); */ //force certain ops to have a certain number of threads /* Properties threadProps = new Properties(); try { InputStream is = ContextHolder.class.getResourceAsStream("/function_threads.properties"); threadProps.load(is); } catch (IOException e) { e.printStackTrace(); } for(String prop : threadProps.stringPropertyNames()) { threads.put(prop,Integer.parseInt(threadProps.getProperty(prop))); } */ try { /* GenericObjectPoolConfig config = new GenericObjectPoolConfig(); config.setJmxEnabled(true); config.setBlockWhenExhausted(false); config.setMaxIdle(Runtime.getRuntime().availableProcessors()); config.setMaxTotal(Runtime.getRuntime().availableProcessors()); config.setMinIdle(Runtime.getRuntime().availableProcessors()); config.setJmxNameBase("handles"); handlePool = new CublasHandlePool(new CublasHandlePooledItemFactory(),config); GenericObjectPoolConfig confClone = config.clone(); confClone.setMaxTotal(Runtime.getRuntime().availableProcessors() * 10); confClone.setMaxIdle(Runtime.getRuntime().availableProcessors() * 10); confClone.setMinIdle(Runtime.getRuntime().availableProcessors() * 10); GenericObjectPoolConfig streamConf = confClone.clone(); streamConf.setJmxNameBase("streams"); streamPool = new StreamPool(new StreamItemFactory(),streamConf); GenericObjectPoolConfig oldStreamConf = streamConf.clone(); oldStreamConf.setJmxNameBase("oldstream"); oldStreamPool = new OldStreamPool(new OldStreamItemFactory(),oldStreamConf); setContext(); //seed with multiple streams to encourage parallelism for(int i = 0; i < Runtime.getRuntime().availableProcessors(); i++) { streamPool.addObject(); oldStreamPool.addObject(); } */ } catch (Exception e) { log.warn("Unable to initialize cuda", e); } /* for(int i = 0; i < numDevices; i++) { ClassPathResource confFile = new ClassPathResource("devices/" + i, ContextHolder.class.getClassLoader()); if(confFile.exists()) { Properties props2 = new Properties(); try { props2.load(confFile.getInputStream()); confs.put(i,new DeviceConfiguration(i,props2)); } catch (IOException e) { throw new RuntimeException(e); } } else confs.put(i,new DeviceConfiguration(i)); } */ confCalled = true; } public void setNumDevices(int numDevices) { this.numDevices = numDevices; } /** * Get the device number for a particular host thread * * @return the device for the given host thread */ public int getDeviceForThread() { /* if(numDevices > 1) { Integer device = threadNameToDeviceNumber.get(Thread.currentThread().getName()); if(device == null) { org.nd4j.linalg.api.rng.Random random = Nd4j.getRandom(); if(random == null) throw new IllegalStateException("Unable to load random class"); device = Nd4j.getRandom().nextInt(numDevices); //reroute banned devices while(bannedDevices != null && bannedDevices.contains(device)) device = Nd4j.getRandom().nextInt(numDevices); threadNameToDeviceNumber.put(Thread.currentThread().getName(),device); return device; } } */ return 0; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy