All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.finmath.jcuda.JCudaUtils Maven / Gradle / Ivy

package net.finmath.jcuda;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.Arrays;

/**
 * Adapted from JCuda examples: utility for turning a CUDA source file (.cu)
 * into a PTX file by invoking the NVCC compiler as an external process.
 */
public class JCudaUtils
{
	/**
	 * Returns the name of the PTX file belonging to the given CUDA source file,
	 * compiling it with NVCC if it does not exist yet.
	 *
	 * The extension of the given file name is replaced with "ptx" (if the file
	 * name has no extension, ".ptx" is appended). If a file with the resulting
	 * name already exists, its name is returned without invoking the compiler.
	 * Note that an existing PTX file is not checked against the modification
	 * time of the source file. Otherwise the source file is compiled using the
	 * "nvcc" executable found on the system path and the name of the newly
	 * created PTX file is returned.
	 *
	 * @param cuFileURL The URL of the .CU file. It has to be convertible to a path of the file system.
	 * @return The absolute name of the PTX file
	 * @throws IOException If the input file does not exist, if NVCC cannot be started,
	 * 		if NVCC exits with a non-zero exit value, or if the calling thread is
	 * 		interrupted while waiting for NVCC (the interrupt flag is restored in that case).
	 * @throws URISyntaxException If the given URL cannot be converted to a URI
	 */
	public static String preparePtxFile(URL cuFileURL) throws IOException, URISyntaxException
	{
		String cuFileName = Paths.get(cuFileURL.toURI()).toFile().getAbsolutePath();
		String ptxFileName = toPtxFileName(cuFileName);
		if (new File(ptxFileName).exists())
		{
			return ptxFileName;
		}

		File cuFile = new File(cuFileName);
		if (!cuFile.exists())
		{
			throw new IOException("Input file not found: "+cuFileName);
		}

		String[] command = {
				"nvcc",
				"-arch",
				"sm_30",
				"-fmad",
				"false",
				dataModelOption(),
				"-ptx",
				cuFile.getPath(),
				"-o",
				ptxFileName };

		System.out.println("Executing\n"+Arrays.toString(command));
		Process process = Runtime.getRuntime().exec(command);

		String errorMessage;
		String outputMessage;
		int exitValue;
		try
		{
			/*
			 * Both output streams of the process have to be consumed concurrently:
			 * reading one of them to its end while the process is blocked writing
			 * to the (full) pipe of the other one would never return.
			 */
			StreamCollector outputCollector = new StreamCollector(process.getInputStream());
			outputCollector.start();

			// The output of nvcc is decoded using the platform default charset.
			errorMessage = new String(toByteArray(process.getErrorStream()));
			exitValue = process.waitFor();
			outputCollector.join();
			outputMessage = new String(outputCollector.getBytes());
		}
		catch (InterruptedException e)
		{
			Thread.currentThread().interrupt();
			throw new IOException(
					"Interrupted while waiting for nvcc output", e);
		}
		finally
		{
			// Has no effect on a process that already terminated, but ensures that
			// nvcc does not keep running if we leave with an exception.
			process.destroy();
		}

		if (exitValue != 0)
		{
			System.out.println("nvcc process exitValue "+exitValue);
			System.out.println("errorMessage:\n"+errorMessage);
			System.out.println("outputMessage:\n"+outputMessage);
			throw new IOException(
					"Could not create .ptx file: "+errorMessage);
		}

		System.out.println("Finished creating PTX file");
		return ptxFileName;
	}

	/**
	 * Derives the name of the PTX file from the name of the CUDA source file.
	 *
	 * Only a dot within the last path component is treated as the start of the
	 * extension, such that a dot in a directory name does not truncate the path.
	 *
	 * @param cuFileName The (absolute) name of the .CU file
	 * @return The given name with its extension replaced by "ptx", or with ".ptx" appended if it has no extension
	 */
	private static String toPtxFileName(String cuFileName)
	{
		int nameStartIndex = cuFileName.lastIndexOf(File.separatorChar)+1;
		int extensionIndex = cuFileName.lastIndexOf('.');
		if (extensionIndex < nameStartIndex)
		{
			return cuFileName+".ptx";
		}
		return cuFileName.substring(0, extensionIndex)+".ptx";
	}

	/**
	 * Creates the NVCC option selecting a 64 bit or 32 bit build, matching the running JVM.
	 *
	 * The property "sun.arch.data.model" is JVM specific. If it is not set to
	 * "32" or "64", the data model is guessed from the property "os.arch".
	 *
	 * @return Either "-m64" or "-m32"
	 */
	private static String dataModelOption()
	{
		String dataModel = System.getProperty("sun.arch.data.model");
		if (!"32".equals(dataModel) && !"64".equals(dataModel))
		{
			dataModel = System.getProperty("os.arch", "").contains("64") ? "64" : "32";
		}
		return "-m"+dataModel;
	}

	/**
	 * Fully reads the given InputStream and returns it as a byte array
	 *
	 * @param inputStream The input stream to read
	 * @return The byte array containing the data from the input stream
	 * @throws IOException If an I/O error occurs
	 */
	private static byte[] toByteArray(InputStream inputStream)
			throws IOException
	{
		ByteArrayOutputStream baos = new ByteArrayOutputStream();
		byte[] buffer = new byte[8192];
		int read;
		while ((read = inputStream.read(buffer)) != -1)
		{
			baos.write(buffer, 0, read);
		}
		return baos.toByteArray();
	}

	/**
	 * A thread fully reading an InputStream in the background.
	 * The result must only be requested after the thread has been joined.
	 */
	private static final class StreamCollector extends Thread
	{
		private final InputStream inputStream;
		private volatile byte[] bytes = new byte[0];
		private volatile IOException failure;

		StreamCollector(InputStream inputStream)
		{
			this.inputStream = inputStream;
			setDaemon(true);
		}

		@Override
		public void run()
		{
			try
			{
				bytes = toByteArray(inputStream);
			}
			catch (IOException e)
			{
				// Reported to the thread requesting the result, see getBytes().
				failure = e;
			}
		}

		/**
		 * @return The data read from the stream
		 * @throws IOException If reading the stream failed
		 */
		byte[] getBytes() throws IOException
		{
			if (failure != null)
			{
				throw failure;
			}
			return bytes;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy