All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.cuda.data.impl.FloatVector Maven / Gradle / Ivy

package com.expleague.ml.cuda.data.impl;

import org.jetbrains.annotations.NotNull;

import com.expleague.ml.cuda.JCudaMemory;
import jcuda.driver.CUdeviceptr;

import com.expleague.ml.cuda.data.ArrayBased;

/**
 * Project jmll
 *
 * <p>A float vector addressed through a CUDA device pointer. The instance keeps the pointer
 * together with the element count it was created (or last {@link #reset(float[]) reset}) with.
 *
 * <p>The buffer is not released automatically by this class: call {@link #destroy()} when the
 * vector is no longer needed. Nothing here guards against using the vector after
 * {@link #destroy()}, or against calling {@link #destroy()} twice.
 *
 * @author Ksen
 */
public class FloatVector implements ArrayBased {

  /** Number of float elements recorded for the buffer behind {@link #devicePointer}. */
  public int length;

  /** Device pointer of the buffer currently backing this vector. */
  public CUdeviceptr devicePointer;

  /**
   * Creates a vector from a host array via {@code JCudaMemory.alloCopy}.
   *
   * @param hostArray host data; its length becomes the length of this vector
   */
  public FloatVector(final @NotNull float[] hostArray) {
    length = hostArray.length;
    devicePointer = JCudaMemory.alloCopy(hostArray);
  }

  /**
   * Requests a fresh float buffer of {@link #length} elements from
   * {@code JCudaMemory.allocFloat}. This vector's own pointer is not touched; the caller
   * owns the returned pointer.
   *
   * @return pointer to the newly requested buffer
   */
  @NotNull
  @Override
  public CUdeviceptr reproduce() {
    return JCudaMemory.allocFloat(length);
  }

  /**
   * Copies {@code hostArray} into the existing buffer without reallocating it.
   *
   * <p>Arrays shorter than the vector are passed through unchanged; how many elements are
   * then overwritten is decided by {@code JCudaMemory.copy} (presumably
   * {@code hostArray.length} - TODO confirm). Use {@link #reset(float[])} to change the size.
   *
   * @param hostArray host data to upload; must not be longer than {@link #length()}
   * @return this vector
   * @throws IllegalArgumentException if {@code hostArray} has more elements than this vector
   */
  @NotNull
  @Override
  public FloatVector set(final @NotNull float[] hostArray) {
    // The buffer was sized for `length` floats; a longer source would presumably be written
    // past the end of that allocation, so reject it before touching the device.
    if (hostArray.length > length) {
      throw new IllegalArgumentException(
          "hostArray length " + hostArray.length + " exceeds vector length " + length);
    }
    JCudaMemory.copy(hostArray, devicePointer);
    return this;
  }

  /**
   * Replaces the backing buffer with a new one created from {@code hostArray}, which may
   * have a different length than the current vector.
   *
   * @param hostArray host data for the new buffer; its length becomes the new vector length
   * @return this vector
   */
  @NotNull
  @Override
  public FloatVector reset(final @NotNull float[] hostArray) {
    // Dereference the argument before releasing anything, so that a null array fails while
    // this vector still refers to a live buffer.
    final int newLength = hostArray.length;
    // The old buffer is released before the new one is requested, so both never have to
    // coexist on the device.
    JCudaMemory.destroy(devicePointer);
    devicePointer = JCudaMemory.alloCopy(hostArray);
    length = newLength;
    return this;
  }

  /**
   * Reads the vector back through {@code JCudaMemory.copy(devicePointer, length)}.
   *
   * @return host array produced by the copy (presumably {@link #length} elements - TODO confirm)
   */
  @NotNull
  @Override
  public float[] get() {
    return JCudaMemory.copy(devicePointer, length);
  }

  /**
   * Points this vector at another device buffer. The previous buffer is not released and
   * {@link #length} is not updated; both remain the caller's responsibility.
   *
   * @param devicePointer pointer to use from now on
   */
  @Override
  public void setPointer(final @NotNull CUdeviceptr devicePointer) {
    this.devicePointer = devicePointer;
  }

  /**
   * @return the device pointer currently backing this vector
   */
  @NotNull
  @Override
  public CUdeviceptr getPointer() {
    return devicePointer;
  }

  /**
   * @return the recorded number of elements, widened to {@code long}
   */
  @Override
  public long length() {
    return length;
  }

  /**
   * Releases the backing buffer through {@code JCudaMemory.destroy}. The pointer field is
   * left as is, so the vector must not be used afterwards.
   */
  @Override
  public void destroy() {
    JCudaMemory.destroy(devicePointer);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy