All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.jcublas.kernel.KernelFunctionLoader Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.nd4j.linalg.jcublas.kernel;



import jcuda.utils.KernelLauncher;

import org.apache.commons.io.IOUtils;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Properties;


/**
 * Kernel function loader:
 *
 * @author Adam Gibson
 */
public class KernelFunctionLoader {

    public final static String NAME_SPACE = "org.nd4j.linalg.jcublas";
    public final static String DOUBLE = NAME_SPACE + ".double.functions";
    public final static String FLOAT = NAME_SPACE + ".float.functions";
    public final static String IMPORTS_FLOAT = NAME_SPACE + ".float.imports";
    public final static String IMPORTS_DOUBLE = NAME_SPACE + ".double.imports";
    public final static String CACHE_COMPILED = NAME_SPACE + ".cache_compiled";
    private Map paths = new HashMap<>();
    private static Logger log = LoggerFactory.getLogger(KernelFunctionLoader.class);
    private static KernelFunctionLoader INSTANCE;
    private static Map launchers = new HashMap<>();
    private boolean init = false;

    private KernelFunctionLoader() {}

    /**
     * Singleton pattern
     *
     * @return
     */
    public static synchronized KernelFunctionLoader getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new KernelFunctionLoader();
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    INSTANCE.unload();
                }
            }));
            try {
                INSTANCE.load();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }


        return INSTANCE;
    }

    private static String dataFolder(DataBuffer.Type type) {
        return "/kernels/" + (type == DataBuffer.Type.FLOAT ? "float" : "double");
    }


    public  static KernelLauncher launcher(String functionName,String dataType) {
        return KernelFunctionLoader.getInstance().get(functionName,dataType);
    }
    

    public KernelLauncher get(String functionName,String dataType) {
        String name = functionName + "_" + dataType;
        
        KernelLauncher launcher = launchers.get(name);
        if(launcher == null) {
            throw new IllegalArgumentException("Unable to find " + name);
        }
        return launcher;
    }


    /**
     * Clean up all the modules
     */
    public void unload() {
        init = false;
    }



    /**
     * Load the appropriate functions from the class
     * path in to one module
     *
     * @return the module associated with this
     * @throws Exception
     */
    public void load() throws Exception {
        if (init)
            return;
        StringBuffer sb = new StringBuffer();
        sb.append("nvcc -g -G -ptx");

        ClassPathResource res = new ClassPathResource("/cudafunctions.properties");
        if (!res.exists())
            throw new IllegalStateException("Please put a cudafunctions.properties in your class path");
        Properties props = new Properties();
        props.load(res.getInputStream());
        log.info("Registering cuda functions...");
        //ensure imports for each file before compiling
        ensureImports(props, "float");
        ensureImports(props, "double");
        compileAndLoad(props, FLOAT, "float");
        compileAndLoad(props, DOUBLE, "double");

        init = true;
    }

    /**
     * The extension of the given file name is replaced with "ptx".
     * If the file with the resulting name does not exist, it is
     * compiled from the given file using NVCC. The name of the
     * PTX file is returned.
     * 

*

* Note that you may run in to an error akin to: * Unsupported GCC version *

* At your own risk, comment the lines under: * /usr/local/cuda-$VERSION/include/host_config.h *

* #if defined(__GNUC__) *

* if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 8) * #error -- unsupported GNU version! gcc 4.9 and up are not supported! *

* #endif /* __GNUC__> 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 8) *

* #endif __GNUC__ *

* This will allow you to bypass the compiler restrictions. Again, do so at your own risk. * * @return The name of the PTX file * @throws java.io.IOException If an I/O error occurs */ private void compileAndLoad(Properties props, String key, String dataType) throws IOException { compileAndLoad(props, key, dataType,0); } /** * The extension of the given file name is replaced with "ptx". * If the file with the resulting name does not exist, it is * compiled from the given file using NVCC. The name of the * PTX file is returned. *

*

* Note that you may run in to an error akin to: * Unsupported GCC version *

* At your own risk, comment the lines under: * /usr/local/cuda-$VERSION/include/host_config.h *

* #if defined(__GNUC__) *

* if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 8) * #error -- unsupported GNU version! gcc 4.9 and up are not supported! *

* #endif /* __GNUC__> 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 8) *

* #endif __GNUC__ *

* This will allow you to bypass the compiler restrictions. Again, do so at your own risk. * * @return The name of the PTX file * @throws java.io.IOException If an I/O error occurs */ private void compileAndLoad(Properties props, String key, String dataType,int compiledAttempts) throws IOException { String f = props.getProperty(key); StringBuffer sb = new StringBuffer(); sb.append("nvcc -g -G "); sb.append(" --include-path "); String tmpDir = System.getProperty("java.io.tmpdir"); StringBuffer dir = new StringBuffer(); boolean cache = Boolean.parseBoolean(props.getProperty(CACHE_COMPILED, String.valueOf("true"))); sb.append(tmpDir) .append(File.separator) .append("kernels") .append(File.separator).append(dataType).append(File.separator) .toString(); String kernelPath = dir.append(tmpDir) .append(File.separator) .append("kernels") .append(File.separator).append(dataType).append(File.separator) .toString(); boolean shouldCompile = cache; if(cache) { File tmpDir2 = new File(tmpDir + File.separator + "kernels"); if(tmpDir2.exists()) { shouldCompile = cache && !tmpDir2.exists() || compiledAttempts > 0; } } String[] split = f.split(","); if(shouldCompile) { sb.append(" ").append(" -ptx "); log.info("Loading " + dataType + " cuda functions"); if (f != null) { for (String s : split) { String loaded = extract("/kernels/" + dataType + "/" + s + ".cu", dataType.equals("float") ? 
DataBuffer.Type.FLOAT : DataBuffer.Type.DOUBLE); sb.append(" " + loaded); } sb.append(" --output-directory " + kernelPath); Process process = Runtime.getRuntime().exec(sb.toString()); String errorMessage = new String(IOUtils.toByteArray(process.getErrorStream())); String outputMessage = new String(IOUtils.toByteArray(process.getInputStream())); int exitValue = 0; try { exitValue = process.waitFor(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException( "Interrupted while waiting for nvcc output", e); } if (exitValue != 0) { log.info("nvcc process exitValue " + exitValue); log.info("errorMessage:\n" + errorMessage); log.info("outputMessage:\n" + outputMessage); throw new IOException( "Could not create .ptx file: " + errorMessage); } } } else { log.info("Modules appear to already be compiled..attempting to use cache"); for (String module : split) { log.info("Is cached: " + module); String path = kernelPath + module + ".ptx"; String name = module + "_" + dataType; paths.put(name,path); } } try { for (String module : split) { log.info("Loading " + module); String path = kernelPath + module + ".ptx"; String name = module + "_" + dataType; paths.put(name,path); KernelLauncher launch = KernelLauncher.load(path, name); launchers.put(name, launch); } }catch (Exception e) { if(!shouldCompile && compiledAttempts < 3) { log.warn("Error loading modules...attempting recompile"); props.setProperty(CACHE_COMPILED,String.valueOf(true)); compileAndLoad(props, key, dataType,compiledAttempts + 1); } else throw new RuntimeException(e); } } //extract the source file /** * Extract the resource ot the specified * absolute path * * @param file the file * @param dataType the data type to extract for * @return * @throws IOException */ public String extract(String file, DataBuffer.Type dataType) throws IOException { String path = dataFolder(dataType); String tmpDir = System.getProperty("java.io.tmpdir"); File dataDir = new File(tmpDir, path); if 
(!dataDir.exists()) dataDir.mkdirs(); return loadFile(file); } private void ensureImports(Properties props, String dataType) throws IOException { if (dataType.equals("float")) { String[] imports = props.getProperty(IMPORTS_FLOAT).split(","); for (String import1 : imports) { loadFile("/kernels/" + dataType + "/" + import1); } } else { String[] imports = props.getProperty(IMPORTS_DOUBLE).split(","); for (String import1 : imports) { loadFile("/kernels/" + dataType + "/" + import1); } } } private String loadFile(String file) throws IOException { ClassPathResource resource = new ClassPathResource(file); String tmpDir = System.getProperty("java.io.tmpdir"); if (!resource.exists()) throw new IllegalStateException("Unable to find file " + resource); File out = new File(tmpDir, file); if (!out.getParentFile().exists()) out.getParentFile().mkdirs(); if (out.exists()) out.delete(); out.createNewFile(); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(out)); IOUtils.copy(resource.getInputStream(), bos); bos.flush(); bos.close(); out.deleteOnExit(); return out.getAbsolutePath(); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy