All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.cuda.JCurandHelper Maven / Gradle / Ivy

package com.expleague.ml.cuda;

import org.jetbrains.annotations.NotNull;

import jcuda.driver.CUdeviceptr;
import jcuda.jcurand.JCurand;
import jcuda.jcurand.curandGenerator;
import jcuda.jcurand.curandRngType;

/**
 * Project jmll
 *
 * @author Ksen
 */
public class JCurandHelper {

  static {
    JCudaHelper.hook();
  }

  public static curandGenerator createDefault() {
    return createGenerator(curandRngType.CURAND_RNG_PSEUDO_DEFAULT);
  }

  public static curandGenerator createGenerator(final int type) {
    final curandGenerator generator = new curandGenerator();

    JCurand.curandCreateGenerator(generator, type);
    JCurand.curandSetPseudoRandomGeneratorSeed(generator, System.nanoTime());

    return generator;
  }

  public static void destroyGenerator(final @NotNull curandGenerator generator) {
    JCurand.curandDestroyGenerator(generator);
  }

  public static CUdeviceptr generateUniform(final int length, final @NotNull curandGenerator generator) {
    final CUdeviceptr devicePointer = JCudaMemory.allocFloat(length);

    return generateUniform(devicePointer, length, generator);
  }

  public static CUdeviceptr generateUniform(
      final @NotNull CUdeviceptr devicePointer,
      final long length,
      final @NotNull curandGenerator generator
  ) {
    JCurand.curandGenerateUniform(generator, devicePointer, length);

    return devicePointer;
  }

  public static float[] generateUniformHost(final int size, final @NotNull curandGenerator generator) {
    final CUdeviceptr devicePointer = JCudaMemory.allocFloat(size);

    JCurand.curandGenerateUniform(generator, devicePointer, size);

    return JCudaMemory.copyFloatsDestr(devicePointer, size);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy