// org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer (Maven / Gradle / Ivy)
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.jcublas.buffer;
import lombok.Getter;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.CudaDeallocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.util.LongUtils;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.*;
import java.util.Collection;
/**
* Base class for a data buffer
*
* CUDA implementation for DataBuffer always uses JavaCPP
* as allocationMode, and device access is masked by
* appropriate allocator mover implementation.
*
* Memory allocation/deallocation is strictly handled by allocator,
* since JavaCPP alloc/dealloc has nothing to do with CUDA.
* But besides that, host pointers obtained from CUDA are 100%
* compatible with CPU
*
* @author Adam Gibson
* @author [email protected]
*/
public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer, Deallocatable {
// Handle to the native DataBuffer object that backs this Java-side buffer
protected OpaqueDataBuffer ptrDataBuffer;
// Bookkeeping wrapper built around ptrDataBuffer by the allocating constructors; transient, so it is not serialized
@Getter
protected transient volatile AllocationPoint allocationPoint;
// Allocator instance shared by all buffers; used below for memcpy and host synchronization
private static AtomicAllocator allocator = AtomicAllocator.getInstance();
private static Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
// Data type read from the current context at instance creation time
protected DataType globalType = DataTypeUtil.getDtypeFromContext();
/**
* No-arg constructor. Does nothing beyond the implicit superclass constructor:
* no native buffer or allocation point is created here.
*/
public BaseCudaDataBuffer() {
}
/**
 * Returns the native buffer handle backing this buffer.
 *
 * @return the {@link OpaqueDataBuffer} held by this instance
 * @throws IllegalStateException if this buffer has already been released
 */
public OpaqueDataBuffer getOpaqueDataBuffer() {
    if (released) {
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    return ptrDataBuffer;
}
public BaseCudaDataBuffer(@NonNull Pointer pointer, @NonNull Pointer specialPointer, @NonNull Indexer indexer, long length) {
this.allocationMode = AllocationMode.MIXED_DATA_TYPES;
this.indexer = indexer;
this.offset = 0;
this.originalOffset = 0;
this.underlyingLength = length;
this.length = length;
initTypeAndSize();
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, this.type, pointer, specialPointer);
this.allocationPoint = new AllocationPoint(ptrDataBuffer, this.type.width() * length);
Nd4j.getDeallocatorService().pickObject(this);if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
}
/**
 * Meant for creating another view of a buffer.
 * <p>
 * A native DataBuffer is allocated, the given pointer is installed as its primary buffer,
 * and the contents are then copied host-to-device on the special stream, which is
 * synchronized before the constructor returns.
 *
 * @param pointer the underlying buffer to create a view from
 * @param indexer the indexer for the pointer
 * @param length  the length of the view
 */
public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) {
    super(pointer, indexer, length);

    // interop buffer on the native side, with the existing pointer as its primary buffer
    this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false);
    this.ptrDataBuffer.setPrimaryBuffer(pointer, length);

    // CUDA-specific bookkeeping
    this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize);
    Nd4j.getDeallocatorService().pickObject(this);

    // push the host data to the device
    val ctx = AtomicAllocator.getInstance().getDeviceContext();
    val txStart = PerformanceTracker.getInstance().helperStartTransaction();

    NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(
            allocationPoint.getDevicePointer(),
            pointer,
            length * getElementSize(),
            CudaConstants.cudaMemcpyHostToDevice,
            ctx.getSpecialStream());

    PerformanceTracker.getInstance().helperRegisterTransaction(
            allocationPoint.getDeviceId(),
            txStart / 2,
            allocationPoint.getNumberOfBytes(),
            MemcpyDirection.HOST_TO_DEVICE);

    ctx.getSpecialStream().synchronize();
}
/**
* Creates a buffer from a float array, delegating to the
* (float[], boolean, long) constructor with an offset of 0.
*
* @param data source values
* @param copy forwarded unchanged to the delegate constructor
*/
public BaseCudaDataBuffer(float[] data, boolean copy) {
// delegates instead of calling the superclass constructor
this(data, copy, 0);
}
/**
* Creates a workspace-attached buffer from a float array, delegating to the
* (float[], boolean, long, MemoryWorkspace) constructor with an offset of 0.
*
* @param data      source values
* @param copy      forwarded unchanged to the delegate constructor
* @param workspace workspace passed on to the delegate constructor
*/
public BaseCudaDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) {
// delegates instead of calling the superclass constructor
this(data, copy, 0, workspace);
}
public BaseCudaDataBuffer(float[] data, boolean copy, long offset) {
this(data.length, 4, false);
this.offset = offset;
this.originalOffset = offset;
this.length = data.length - offset;
this.underlyingLength = data.length;
set(data, this.length, offset, offset);
}
public BaseCudaDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) {
this(data.length, 8, false, workspace);
this.offset = offset;
this.originalOffset = offset;
this.length = data.length - offset;
this.underlyingLength = data.length;
set(data, this.length, offset, offset);
}
public BaseCudaDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) {
this(data.length, 4,false, workspace);
this.offset = offset;
this.originalOffset = offset;
this.length = data.length - offset;
this.underlyingLength = data.length;
set(data, this.length, offset, offset);
}
/**
* Creates a buffer from a double array, delegating to the
* (double[], boolean, long) constructor with an offset of 0.
*
* @param data source values
* @param copy forwarded unchanged to the delegate constructor
*/
public BaseCudaDataBuffer(double[] data, boolean copy) {
this(data, copy, 0);
}
public BaseCudaDataBuffer(double[] data, boolean copy, long offset) {
this(data.length, 8, false);
this.offset = offset;
this.originalOffset = offset;
this.length = data.length - offset;
this.underlyingLength = data.length;
set(data, this.length, offset, offset);
}
/**
* Creates a buffer from an int array, delegating to the
* (int[], boolean, long) constructor with an offset of 0.
*
* @param data source values
* @param copy forwarded unchanged to the delegate constructor
*/
public BaseCudaDataBuffer(int[] data, boolean copy) {
this(data, copy, 0);
}
/**
* Creates a workspace-attached buffer from an int array, delegating to the
* (int[], boolean, long, MemoryWorkspace) constructor with an offset of 0.
*
* @param data      source values
* @param copy      forwarded unchanged to the delegate constructor
* @param workspace workspace passed on to the delegate constructor
*/
public BaseCudaDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) {
this(data, copy, 0, workspace);
}
public BaseCudaDataBuffer(int[] data, boolean copy, long offset) {
this(data.length, 4, false);
this.offset = offset;
this.originalOffset = offset;
this.length = data.length - offset;
this.underlyingLength = data.length;
set(data, this.length, offset, offset);
}
public BaseCudaDataBuffer(int[] data, boolean copy, long offset, MemoryWorkspace workspace) {
this(data.length, 4, false, workspace);
this.offset = offset;
this.originalOffset = offset;
this.length = data.length - offset;
this.underlyingLength = data.length;
set(data, this.length, offset, offset);
}
/**
 * Convenience overload that resolves the element size from the data type and
 * forwards to {@code initPointers(long, int, boolean)}.
 *
 * @param length     number of elements
 * @param dtype      data type whose size (per {@code Nd4j.sizeOfDataType}) is used as element size
 * @param initialize forwarded unchanged
 */
protected void initPointers(long length, DataType dtype, boolean initialize) {
    final int bytesPerElement = Nd4j.sizeOfDataType(dtype);
    initPointers(length, bytesPerElement, initialize);
}
/**
 * (Re)initializes the Java-side host pointer and indexer when they are missing, or when
 * the allocation point reports a host pointer at a different address than the cached one.
 * Does nothing for zero-length buffers.
 */
public void lazyAllocateHostPointer() {
    if (length() == 0) {
        return;
    }

    // java side might be unaware of native-side buffer allocation
    final boolean hostViewMissing = this.indexer == null
            || this.pointer == null
            || this.pointer.address() == 0;

    if (hostViewMissing
            || (allocationPoint.getHostPointer() != null
                    && allocationPoint.getHostPointer().address() != this.pointer.address())) {
        initHostPointerAndIndexer();
    }
}
/**
 * Creates a buffer of {@code length} elements and uploads the contents of a
 * {@link ByteBuffer} to the device.
 * <p>
 * Every supported data type now wraps {@code buffer} as the copy source. Previously the
 * BOOL, UTF8, BFLOAT16, UINT16, UINT32 and UINT64 branches allocated fresh native memory
 * and never read {@code buffer}, so unrelated memory was uploaded. An unsupported type now
 * fails with an explicit exception instead of leaving the source pointer {@code null}.
 *
 * @param buffer source data
 * @param dtype  data type used to size the allocation and the copy
 * @param length number of elements to allocate and copy
 * @param offset element offset into {@code buffer} at which the copy starts
 * @throws UnsupportedOperationException if {@code dataType()} is not one of the handled types
 */
protected BaseCudaDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) {
    this(length, Nd4j.sizeOfDataType(dtype));

    // only the address of this pointer is used for the copy, so same-width types share a wrapper
    final Pointer hostData;
    switch (dataType()) {
        case DOUBLE:
            hostData = new DoublePointer(buffer.asDoubleBuffer());
            break;
        case FLOAT:
            hostData = new FloatPointer(buffer.asFloatBuffer());
            break;
        case UINT64: //Fall through
        case LONG:
            hostData = new LongPointer(buffer.asLongBuffer());
            break;
        case UINT32: //Fall through
        case INT:
            hostData = new IntPointer(buffer.asIntBuffer());
            break;
        case HALF: //Fall through
        case BFLOAT16: //Fall through
        case UINT16: //Fall through
        case SHORT:
            hostData = new ShortPointer(buffer.asShortBuffer());
            break;
        case BOOL: //Fall through
        case UTF8: //Fall through
        case UBYTE: //Fall through
        case BYTE:
            hostData = new BytePointer(buffer);
            break;
        default:
            throw new UnsupportedOperationException("Unsupported data type: " + dataType());
    }

    // copy data to device
    val stream = AtomicAllocator.getInstance().getDeviceContext().getSpecialStream();
    val ptr = ptrDataBuffer.specialBuffer();

    Pointer source = hostData;
    if (offset > 0) {
        source = new PagedPointer(hostData.address() + offset * getElementSize());
    }

    NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(ptr, source, length * Nd4j.sizeOfDataType(dtype), CudaConstants.cudaMemcpyHostToDevice, stream);
    stream.synchronize();

    // we're keeping pointer reference for JVM until the copy has completed
    hostData.address();

    // mark device buffer as updated
    allocationPoint.tickDeviceWrite();
}
/**
 * Makes sure a host (primary) buffer exists and rebuilds the typed Java pointer and
 * indexer on top of it.
 * <p>
 * When the allocation point has no host pointer yet, the primary buffer is either
 * allocated natively or, for workspace-attached buffers, taken from the parent workspace.
 * The allocation status read beforehand is written back afterwards and a device write is ticked.
 * Does nothing for zero-length buffers.
 *
 * @throws UnsupportedOperationException if the data type has no matching pointer/indexer pair
 */
protected void initHostPointerAndIndexer() {
    if (length() == 0) {
        return;
    }

    if (allocationPoint.getHostPointer() == null) {
        val status = allocationPoint.getAllocationStatus();

        if (parentWorkspace != null) {
            val wsPtr = parentWorkspace.alloc(this.length * this.elementSize, MemoryKind.HOST, this.dataType(), false);
            ptrDataBuffer.setPrimaryBuffer(wsPtr, this.length);
        } else {
            // let cpp allocate primary buffer
            NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer);
        }

        this.allocationPoint.setAllocationStatus(status);
        this.allocationPoint.tickDeviceWrite();
    }

    val host = allocationPoint.getHostPointer();
    assert host != null;

    switch (dataType()) {
        case DOUBLE:
            this.pointer = new CudaPointer(host, length, 0).asDoublePointer();
            indexer = DoubleIndexer.create((DoublePointer) pointer);
            break;
        case FLOAT:
            this.pointer = new CudaPointer(host, length, 0).asFloatPointer();
            indexer = FloatIndexer.create((FloatPointer) pointer);
            break;
        case UINT32:
            this.pointer = new CudaPointer(host, length, 0).asIntPointer();
            indexer = UIntIndexer.create((IntPointer) pointer);
            break;
        case INT:
            this.pointer = new CudaPointer(host, length, 0).asIntPointer();
            indexer = IntIndexer.create((IntPointer) pointer);
            break;
        case BFLOAT16:
            this.pointer = new CudaPointer(host, length, 0).asShortPointer();
            indexer = Bfloat16Indexer.create((ShortPointer) pointer);
            break;
        case HALF:
            this.pointer = new CudaPointer(host, length, 0).asShortPointer();
            indexer = HalfIndexer.create((ShortPointer) pointer);
            break;
        case UINT64: //Fall through
        case LONG:
            this.pointer = new CudaPointer(host, length, 0).asLongPointer();
            indexer = LongIndexer.create((LongPointer) pointer);
            break;
        case UINT16:
            this.pointer = new CudaPointer(host, length, 0).asShortPointer();
            indexer = UShortIndexer.create((ShortPointer) pointer);
            break;
        case SHORT:
            this.pointer = new CudaPointer(host, length, 0).asShortPointer();
            indexer = ShortIndexer.create((ShortPointer) pointer);
            break;
        case UBYTE:
            this.pointer = new CudaPointer(host, length, 0).asBytePointer();
            indexer = UByteIndexer.create((BytePointer) pointer);
            break;
        case BYTE:
            this.pointer = new CudaPointer(host, length, 0).asBytePointer();
            indexer = ByteIndexer.create((BytePointer) pointer);
            break;
        case BOOL:
            this.pointer = new CudaPointer(host, length, 0).asBooleanPointer();
            indexer = BooleanIndexer.create((BooleanPointer) pointer);
            break;
        case UTF8:
            this.pointer = new CudaPointer(host, length, 0).asBytePointer();
            indexer = ByteIndexer.create((BytePointer) pointer);
            break;
        default:
            throw new UnsupportedOperationException();
    }
}
/**
 * Allocates the native DataBuffer for this instance, wraps it in an allocation point and
 * registers the instance with the deallocator service.
 *
 * @param length      number of elements
 * @param elementSize size of one element in bytes; stored on the instance as a byte
 * @param initialize  when {@code true}, the device memory is zero-filled and the special stream synchronized
 */
protected void initPointers(long length, int elementSize, boolean initialize) {
    this.allocationMode = AllocationMode.MIXED_DATA_TYPES;
    this.elementSize = (byte) elementSize;
    this.length = length;
    this.originalOffset = 0;
    this.offset = 0;

    // we allocate native DataBuffer AND it will contain our device pointer
    ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false);

    final long sizeInBytes = length * type.width();
    this.allocationPoint = new AllocationPoint(ptrDataBuffer, sizeInBytes);

    if (initialize) {
        val deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        val target = allocationPoint.getDevicePointer();

        NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(
                target, 0, length * elementSize, 0, deviceContext.getSpecialStream());
        deviceContext.getSpecialStream().synchronize();
    }

    // let deallocator pick up this object
    Nd4j.getDeallocatorService().pickObject(this);
}
/**
* Allocates a buffer of the given length.
*
* @param length      number of elements
* @param elementSize size of one element in bytes
* @param initialize  whether the device memory should be zero-filled (see initPointers)
*/
public BaseCudaDataBuffer(long length, int elementSize, boolean initialize) {
// type information has to be set before initPointers reads it
initTypeAndSize();
initPointers(length, elementSize, initialize);
}
public BaseCudaDataBuffer(long length, int elementSize, boolean initialize, @NonNull MemoryWorkspace workspace) {
this.allocationMode = AllocationMode.MIXED_DATA_TYPES;
initTypeAndSize();
this.attached = true;
this.parentWorkspace = workspace;
this.length = length;
this.offset = 0;
this.originalOffset = 0;
if (workspace.getWorkspaceConfiguration().getPolicyMirroring() == MirroringPolicy.FULL) {
val devicePtr = workspace.alloc(length * elementSize, MemoryKind.DEVICE, type, initialize);
// allocate from workspace, and pass it to native DataBuffer
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, type, null, devicePtr);
if (initialize) {
val ctx = AtomicAllocator.getInstance().getDeviceContext();
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream());
ctx.getSpecialStream().synchronize();
}
} else {
// we can register this pointer as device, because it's pinned memory
val devicePtr = workspace.alloc(length * elementSize, MemoryKind.HOST, type, initialize);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, type, null, devicePtr);
if (initialize) {
val ctx = AtomicAllocator.getInstance().getDeviceContext();
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream());
ctx.getSpecialStream().synchronize();
}
}
this.allocationPoint = new AllocationPoint(ptrDataBuffer, elementSize * length);
// registering for deallocation
Nd4j.getDeallocatorService().pickObject(this);
workspaceGenerationId = workspace.getGenerationId();
this.attached = true;
this.parentWorkspace = workspace;
}
/**
* Replaces the indexer stored on this buffer.
*
* @param indexer the indexer to store
*/
@Override
protected void setIndexer(Indexer indexer) {
//TODO: to be abstracted
this.indexer = indexer;
}
/**
* Base constructor. It's used within all constructors internally.
* Delegates to the three-argument constructor with initialization enabled.
*
* @param length the length of the buffer
* @param elementSize the size of each element
*/
public BaseCudaDataBuffer(long length, int elementSize) {
this(length, elementSize, true);
}
/**
* Workspace-attached variant of the base constructor.
* Delegates to the four-argument constructor with initialization enabled.
*
* @param length the length of the buffer
* @param elementSize the size of each element
* @param workspace the workspace the memory is taken from
*/
public BaseCudaDataBuffer(long length, int elementSize, MemoryWorkspace workspace) {
this(length, elementSize, true, workspace);
}
/**
 * Allocates an initialized buffer and records the given offset on it.
 *
 * @param length      the length of the buffer
 * @param elementSize the size of each element
 * @param offset      value stored as both the current and the original offset
 */
public BaseCudaDataBuffer(long length, int elementSize, long offset) {
    this(length, elementSize);

    this.originalOffset = offset;
    this.offset = offset;
}
/**
 * Creates a view over a region of an existing buffer.
 * <p>
 * The underlying buffer must be a {@code BaseCudaDataBuffer}; its host pointer is
 * initialized first, then a native view is created and a typed Java pointer and indexer
 * are built on top of the view's host pointer.
 * <p>
 * The allocation point is sized in bytes ({@code length * elementSize}), as in every other
 * constructor of this class; the element count alone was passed here before.
 *
 * @param underlyingBuffer buffer to create the view from
 * @param length           number of elements in the view
 * @param offset           element offset of the view within {@code underlyingBuffer}
 * @throws IllegalStateException         if {@code underlyingBuffer} was already closed
 * @throws UnsupportedOperationException if the underlying data type is not supported
 */
public BaseCudaDataBuffer(@NonNull DataBuffer underlyingBuffer, long length, long offset) {
    if (underlyingBuffer.wasClosed()) {
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    this.allocationMode = AllocationMode.MIXED_DATA_TYPES;
    initTypeAndSize();

    this.wrappedDataBuffer = underlyingBuffer;
    this.originalBuffer = underlyingBuffer.originalDataBuffer() == null ? underlyingBuffer
            : underlyingBuffer.originalDataBuffer();
    this.length = length;
    this.offset = offset;
    this.originalOffset = offset;
    this.elementSize = (byte) underlyingBuffer.getElementSize();

    // in case of view creation, we initialize underlying buffer regardless of anything
    final BaseCudaDataBuffer parent = (BaseCudaDataBuffer) underlyingBuffer;
    parent.lazyAllocateHostPointer();

    // we're creating view of the native DataBuffer; both arguments are byte quantities
    final long viewBytes = length * underlyingBuffer.getElementSize();
    final long offsetBytes = offset * underlyingBuffer.getElementSize();
    ptrDataBuffer = parent.ptrDataBuffer.createView(viewBytes, offsetBytes);

    // allocation point tracks the size in bytes, not in elements
    this.allocationPoint = new AllocationPoint(ptrDataBuffer, viewBytes);

    val hostPointer = allocationPoint.getHostPointer();

    Nd4j.getDeallocatorService().pickObject(this);

    switch (underlyingBuffer.dataType()) {
        case DOUBLE:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asDoublePointer();
            indexer = DoubleIndexer.create((DoublePointer) pointer);
            break;
        case FLOAT:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asFloatPointer();
            indexer = FloatIndexer.create((FloatPointer) pointer);
            break;
        case UINT32:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
            indexer = UIntIndexer.create((IntPointer) pointer);
            break;
        case INT:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
            indexer = IntIndexer.create((IntPointer) pointer);
            break;
        case BFLOAT16:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
            indexer = Bfloat16Indexer.create((ShortPointer) pointer);
            break;
        case HALF:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
            indexer = HalfIndexer.create((ShortPointer) pointer);
            break;
        case UINT64: //Fall through
        case LONG:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
            indexer = LongIndexer.create((LongPointer) pointer);
            break;
        case UINT16:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
            indexer = UShortIndexer.create((ShortPointer) pointer);
            break;
        case SHORT:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
            indexer = ShortIndexer.create((ShortPointer) pointer);
            break;
        case BOOL:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBooleanPointer();
            indexer = BooleanIndexer.create((BooleanPointer) pointer);
            break;
        case BYTE:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer();
            indexer = ByteIndexer.create((BytePointer) pointer);
            break;
        case UBYTE:
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer();
            indexer = UByteIndexer.create((BytePointer) pointer);
            break;
        case UTF8:
            Preconditions.checkArgument(offset == 0, "String array can't be a view");
            this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer();
            indexer = ByteIndexer.create((BytePointer) pointer);
            break;
        default:
            throw new UnsupportedOperationException();
    }
}
/**
* Allocates a buffer of the given length, using the size of the
* data type returned by Nd4j.dataType() as the element size.
*
* @param length the length of the buffer
*/
public BaseCudaDataBuffer(long length) {
this(length, Nd4j.sizeOfDataType(Nd4j.dataType()));
}
/**
* Allocates a buffer with one FLOAT-sized slot per array element (not zero-filled)
* and copies the whole array into it.
*
* @param data source values
*/
public BaseCudaDataBuffer(float[] data) {
// allocation is done here rather than through the superclass constructor
this(data.length, Nd4j.sizeOfDataType(DataType.FLOAT), false);
set(data, data.length, 0, 0);
}
/**
* Allocates a buffer with one INT-sized slot per array element (not zero-filled)
* and copies the whole array into it.
*
* @param data source values
*/
public BaseCudaDataBuffer(int[] data) {
// allocation is done here rather than through the superclass constructor
this(data.length, Nd4j.sizeOfDataType(DataType.INT), false);
set(data, data.length, 0, 0);
}
/**
* Allocates a buffer with one LONG-sized slot per array element (not zero-filled)
* and copies the whole array into it.
*
* @param data source values
*/
public BaseCudaDataBuffer(long[] data) {
// allocation is done here rather than through the superclass constructor
this(data.length, Nd4j.sizeOfDataType(DataType.LONG), false);
set(data, data.length, 0, 0);
}
/**
 * Allocates a buffer with one LONG-sized slot per array element (not zero-filled) and,
 * if requested, copies the whole array into it.
 *
 * @param data source values; only its length is used when {@code copy} is {@code false}
 * @param copy whether the array contents are copied into the new buffer
 */
public BaseCudaDataBuffer(long[] data, boolean copy) {
    this(data.length, Nd4j.sizeOfDataType(DataType.LONG), false);

    if (copy) {
        set(data, data.length, 0, 0);
    }
}
/**
* Allocates a buffer with one DOUBLE-sized slot per array element (not zero-filled)
* and copies the whole array into it.
*
* @param data source values
*/
public BaseCudaDataBuffer(double[] data) {
// allocation is done here rather than through the superclass constructor
this(data.length, Nd4j.sizeOfDataType(DataType.DOUBLE), false);
set(data, data.length, 0, 0);
}
/**
 * This method always returns host pointer.
 *
 * @return the address of the host pointer reported by the allocation point
 * @throws IllegalStateException if this buffer has already been released
 */
@Override
public long address() {
    if (released) {
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    return allocationPoint.getHostPointer().address();
}
/**
 * Returns the address of the device pointer reported by the allocation point.
 * <p>
 * Guards against use after release in the same way as {@code address()},
 * {@code pointer()} and {@code getOpaqueDataBuffer()}.
 *
 * @return the device pointer address
 * @throws IllegalStateException if this buffer has already been released
 */
@Override
public long platformAddress() {
    if (released) {
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    return allocationPoint.getDevicePointer().address();
}
/**
 * Returns the host-side pointer, initializing it first if needed.
 *
 * @return the pointer returned by the superclass after lazy host allocation
 * @throws IllegalStateException if this buffer has already been released
 */
@Override
public Pointer pointer() {
    if (released) {
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    // FIXME: very bad thing,
    lazyAllocateHostPointer();

    return super.pointer();
}
/**
 * Copies values from an int array into this buffer, converting them to the buffer's data type.
 * <p>
 * PLEASE NOTE: length, srcOffset, dstOffset are considered numbers of elements, not byte offsets
 * <p>
 * For every type except UBYTE the array is converted into a native pointer and copied via the
 * allocator. UBYTE values are written element by element; that path now honours
 * {@code length}, {@code srcOffset} and {@code dstOffset} like the other types (it used to
 * copy the whole array to index 0 regardless of the arguments).
 *
 * @param data      source values
 * @param length    number of elements to copy
 * @param srcOffset element offset into the (converted) source
 * @param dstOffset element offset into this buffer
 * @throws UnsupportedOperationException if the buffer's data type is not handled
 */
public void set(int[] data, long length, long srcOffset, long dstOffset) {
    // TODO: make sure getPointer returns proper pointer
    final Pointer hostData;
    switch (dataType()) {
        case BOOL: //Fall through
        case BYTE:
            hostData = new BytePointer(ArrayUtil.toBytes(data));
            break;
        case UBYTE:
            for (long e = 0; e < length; e++) {
                put(dstOffset + e, data[(int) (srcOffset + e)]);
            }
            return;
        case SHORT:
            hostData = new ShortPointer(ArrayUtil.toShorts(data));
            break;
        case INT:
            hostData = new IntPointer(data);
            break;
        case LONG:
            hostData = new LongPointer(LongUtils.toLongs(data));
            break;
        case HALF:
            hostData = new ShortPointer(ArrayUtil.toHalfs(data));
            break;
        case FLOAT:
            hostData = new FloatPointer(ArrayUtil.toFloats(data));
            break;
        case DOUBLE:
            hostData = new DoublePointer(ArrayUtil.toDouble(data));
            break;
        default:
            throw new UnsupportedOperationException("Unsupported data type: " + dataType());
    }

    val srcPtr = new CudaPointer(hostData.address() + (srcOffset * elementSize));
    allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize);

    // we're keeping pointer reference for JVM
    hostData.address();
}
/**
 * Copies values from a long array into this buffer, converting them to the buffer's data type.
 * <p>
 * PLEASE NOTE: length, srcOffset, dstOffset are considered numbers of elements, not byte offsets
 * <p>
 * For the unsigned types the array is first passed through {@code ArrayUtil.cutBelowZero}.
 * For every type except UBYTE the array is converted into a native pointer and copied via the
 * allocator. UBYTE values are written element by element; that path now honours
 * {@code length}, {@code srcOffset} and {@code dstOffset} like the other types (it used to
 * copy the whole array to index 0 regardless of the arguments).
 *
 * @param data      source values
 * @param length    number of elements to copy
 * @param srcOffset element offset into the (converted) source
 * @param dstOffset element offset into this buffer
 * @throws UnsupportedOperationException if the buffer's data type is not handled
 */
public void set(long[] data, long length, long srcOffset, long dstOffset) {
    // TODO: make sure getPointer returns proper pointer
    final Pointer hostData;
    switch (dataType()) {
        case BOOL: //Fall through
        case BYTE:
            hostData = new BytePointer(ArrayUtil.toBytes(data));
            break;
        case UBYTE:
            data = ArrayUtil.cutBelowZero(data);
            for (long e = 0; e < length; e++) {
                put(dstOffset + e, data[(int) (srcOffset + e)]);
            }
            return;
        case UINT16:
            data = ArrayUtil.cutBelowZero(data);
            //Fall through
        case SHORT:
            hostData = new ShortPointer(ArrayUtil.toShorts(data));
            break;
        case UINT32:
            data = ArrayUtil.cutBelowZero(data);
            //Fall through
        case INT:
            hostData = new IntPointer(ArrayUtil.toInts(data));
            break;
        case UINT64:
            data = ArrayUtil.cutBelowZero(data);
            //Fall through
        case LONG:
            hostData = new LongPointer(data);
            break;
        case BFLOAT16:
            hostData = new ShortPointer(ArrayUtil.toBfloats(data));
            break;
        case HALF:
            hostData = new ShortPointer(ArrayUtil.toHalfs(data));
            break;
        case FLOAT:
            hostData = new FloatPointer(ArrayUtil.toFloats(data));
            break;
        case DOUBLE:
            hostData = new DoublePointer(ArrayUtil.toDouble(data));
            break;
        default:
            throw new UnsupportedOperationException("Unsupported data type: " + dataType());
    }

    val srcPtr = new CudaPointer(hostData.address() + (srcOffset * elementSize));
    allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize);

    // we're keeping pointer reference for JVM
    hostData.address();
}
/**
 * Copies values from a float array into this buffer, converting them to the buffer's data type.
 * <p>
 * PLEASE NOTE: length, srcOffset, dstOffset are considered numbers of elements, not byte offsets
 * <p>
 * For every type except UBYTE the array is converted into a native pointer and copied via the
 * allocator. UBYTE values are written element by element; that path now honours
 * {@code length}, {@code srcOffset} and {@code dstOffset} like the other types (it used to
 * copy the whole array to index 0 regardless of the arguments).
 *
 * @param data      source values
 * @param length    number of elements to copy
 * @param srcOffset element offset into the (converted) source
 * @param dstOffset element offset into this buffer
 * @throws UnsupportedOperationException if the buffer's data type is not handled
 */
public void set(float[] data, long length, long srcOffset, long dstOffset) {
    final Pointer hostData;
    switch (dataType()) {
        case BOOL: //Fall through
        case BYTE:
            hostData = new BytePointer(ArrayUtil.toBytes(data));
            break;
        case UBYTE:
            for (long e = 0; e < length; e++) {
                put(dstOffset + e, data[(int) (srcOffset + e)]);
            }
            return;
        case SHORT:
            hostData = new ShortPointer(ArrayUtil.toShorts(data));
            break;
        case INT:
            hostData = new IntPointer(ArrayUtil.toInts(data));
            break;
        case LONG:
            hostData = new LongPointer(ArrayUtil.toLongArray(data));
            break;
        case HALF:
            hostData = new ShortPointer(ArrayUtil.toHalfs(data));
            break;
        case FLOAT:
            hostData = new FloatPointer(data);
            break;
        case DOUBLE:
            hostData = new DoublePointer(ArrayUtil.toDoubles(data));
            break;
        default:
            throw new UnsupportedOperationException("Unsupported data type: " + dataType());
    }

    val srcPtr = new CudaPointer(hostData.address() + (srcOffset * elementSize));
    allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize);

    // we're keeping pointer reference for JVM
    hostData.address();
}
/**
 * Copies values from a double array into this buffer, converting them to the buffer's data type.
 * <p>
 * PLEASE NOTE: length, srcOffset, dstOffset are considered numbers of elements, not byte offsets
 * <p>
 * For every type except UBYTE the array is converted into a native pointer and copied via the
 * allocator. UBYTE values are written element by element; that path now honours
 * {@code length}, {@code srcOffset} and {@code dstOffset} like the other types (it used to
 * copy the whole array to index 0 regardless of the arguments).
 *
 * @param data      source values
 * @param length    number of elements to copy
 * @param srcOffset element offset into the (converted) source
 * @param dstOffset element offset into this buffer
 * @throws UnsupportedOperationException if the buffer's data type is not handled
 */
public void set(double[] data, long length, long srcOffset, long dstOffset) {
    final Pointer hostData;
    switch (dataType()) {
        case BOOL: //Fall through
        case BYTE:
            hostData = new BytePointer(ArrayUtil.toBytes(data));
            break;
        case UBYTE:
            for (long e = 0; e < length; e++) {
                put(dstOffset + e, data[(int) (srcOffset + e)]);
            }
            return;
        case SHORT:
            hostData = new ShortPointer(ArrayUtil.toShorts(data));
            break;
        case INT:
            hostData = new IntPointer(ArrayUtil.toInts(data));
            break;
        case LONG:
            hostData = new LongPointer(ArrayUtil.toLongs(data));
            break;
        case HALF:
            hostData = new ShortPointer(ArrayUtil.toHalfs(data));
            break;
        case FLOAT:
            hostData = new FloatPointer(ArrayUtil.toFloats(data));
            break;
        case DOUBLE:
            hostData = new DoublePointer(data);
            break;
        default:
            throw new UnsupportedOperationException("Unsupported data type: " + dataType());
    }

    val srcPtr = new CudaPointer(hostData.address() + (srcOffset * elementSize));
    allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize);

    // we're keeping pointer reference for JVM
    hostData.address();
}
/**
 * Copies the whole int array into this buffer by forwarding to
 * {@code set(data, data.length, 0, 0)}. An empty array is ignored.
 *
 * @param data values to copy
 */
@Override
public void setData(int[] data) {
    final int count = data.length;
    if (count != 0) {
        set(data, count, 0, 0);
    }
}
/**
 * Copies the whole long array into this buffer by forwarding to
 * {@code set(data, data.length, 0, 0)}. An empty array is ignored.
 *
 * @param data values to copy
 */
@Override
public void setData(long[] data) {
    final int count = data.length;
    if (count != 0) {
        set(data, count, 0, 0);
    }
}
/**
 * Copies the whole float array into this buffer by forwarding to
 * {@code set(data, data.length, 0, 0)}. An empty array is ignored.
 *
 * @param data values to copy
 */
@Override
public void setData(float[] data) {
    final int count = data.length;
    if (count != 0) {
        set(data, count, 0, 0);
    }
}
/**
 * Copies the whole double array into this buffer by forwarding to
 * {@code set(data, data.length, 0, 0)}. An empty array is ignored.
 *
 * @param data values to copy
 */
@Override
public void setData(double[] data) {
    final int count = data.length;
    if (count != 0) {
        set(data, count, 0, 0);
    }
}
/**
 * Not available on this backend.
 *
 * @throws UnsupportedOperationException always
 */
@Override
protected void setNioBuffer() {
    final String reason = "setNioBuffer() is not supported for CUDA backend";
    throw new UnsupportedOperationException(reason);
}
/**
 * Strided copy delegated to the superclass. Before delegating, the host pointer of this
 * buffer is lazily allocated and {@code allocator.synchronizeHostData} is invoked for both
 * this buffer and {@code buf}.
 *
 * @param buf     the other buffer taking part in the copy
 * @param n       passed through to the superclass
 * @param stride  passed through to the superclass
 * @param yStride passed through to the superclass
 * @param offset  passed through to the superclass
 * @param yOffset passed through to the superclass
 */
@Override
public void copyAtStride(DataBuffer buf, long n, long stride, long yStride, long offset, long yOffset) {
    lazyAllocateHostPointer();

    // both sides go through synchronizeHostData before the superclass copy runs
    allocator.synchronizeHostData(this);
    allocator.synchronizeHostData(buf);

    super.copyAtStride(buf, n, stride, yStride, offset, yOffset);
}
/**
 * @return the value of the {@code allocationMode} field
 */
@Override
public AllocationMode allocationMode() {
    return this.allocationMode;
}
/**
 * Exposes the current {@code pointer} as a {@link ByteBuffer} via {@code asByteBuffer()}.
 * No host allocation or synchronization call is made here.
 *
 * @return byte-buffer view produced by the pointer
 */
@Override
public ByteBuffer getHostBuffer() {
    final ByteBuffer view = pointer.asByteBuffer();
    return view;
}
/**
 * @return the host pointer that {@code AtomicAllocator} reports for this buffer
 */
@Override
public Pointer getHostPointer() {
    final AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
    return atomicAllocator.getHostPointer(this);
}
/**
 * Offset-based host pointer access is not implemented.
 *
 * @param offset ignored
 * @throws UnsupportedOperationException always
 */
@Override
public Pointer getHostPointer(long offset) {
    throw new UnsupportedOperationException();
}
/**
 * No-op: the call that used to remove {@code id} from the referencing collection is disabled.
 *
 * @param id ignored
 */
@Override
public void removeReferencing(String id) {
    // intentionally empty (previously: referencing.remove(id))
}
/**
 * Reference tracking is disabled for this buffer type.
 *
 * @return always {@code null} (the referencing collection is no longer returned)
 */
@Override
public Collection references() {
    // previously: return referencing;
    return null;
}
/**
 * @return the value of the {@code elementSize} field
 */
@Override
public int getElementSize() {
    return this.elementSize;
}
/**
 * No-op: the call that used to add {@code id} to the referencing collection is disabled.
 *
 * @param id ignored
 */
@Override
public void addReferencing(String id) {
    // intentionally empty (previously: referencing.add(id))
}
/**
 * Deprecated entry point that is no longer implemented.
 *
 * @param arr    ignored
 * @param stride ignored
 * @param offset ignored
 * @param length ignored
 * @throws UnsupportedOperationException always
 */
@Deprecated
public Pointer getHostPointer(INDArray arr, int stride, long offset, int length) {
    final String reason = "This method is deprecated";
    throw new UnsupportedOperationException(reason);
}
/**
 * Deprecated entry point that is not implemented.
 *
 * @param pointer ignored
 * @throws UnsupportedOperationException always
 */
@Deprecated
public void set(Pointer pointer) {
    final String reason = "set(Pointer) is not supported";
    throw new UnsupportedOperationException(reason);
}
/**
 * Stores a float at index {@code i} through the superclass implementation.
 * Beforehand it calls {@code lazyAllocateHostPointer()}, {@code allocator.synchronizeHostData(this)}
 * and {@code allocator.tickHostWrite(this)}, in that order.
 *
 * @param i       element index
 * @param element value to store
 */
@Override
public void put(long i, float element) {
    lazyAllocateHostPointer();

    allocator.synchronizeHostData(this);
    allocator.tickHostWrite(this);

    super.put(i, element);
}
/**
 * Stores a boolean at index {@code i} through the superclass implementation.
 * Beforehand it calls {@code lazyAllocateHostPointer()}, {@code allocator.synchronizeHostData(this)}
 * and {@code allocator.tickHostWrite(this)}, in that order.
 *
 * @param i       element index
 * @param element value to store
 */
@Override
public void put(long i, boolean element) {
    lazyAllocateHostPointer();

    allocator.synchronizeHostData(this);
    allocator.tickHostWrite(this);

    super.put(i, element);
}
/**
 * Stores a double at index {@code i} through the superclass implementation.
 * Beforehand it calls {@code lazyAllocateHostPointer()}, {@code allocator.synchronizeHostData(this)}
 * and {@code allocator.tickHostWrite(this)}, in that order.
 *
 * @param i       element index
 * @param element value to store
 */
@Override
public void put(long i, double element) {
    lazyAllocateHostPointer();

    allocator.synchronizeHostData(this);
    allocator.tickHostWrite(this);

    super.put(i, element);
}
/**
 * Stores an int at index {@code i} through the superclass implementation.
 * Beforehand it calls {@code lazyAllocateHostPointer()}, {@code allocator.synchronizeHostData(this)}
 * and {@code allocator.tickHostWrite(this)}, in that order.
 *
 * @param i       element index
 * @param element value to store
 */
@Override
public void put(long i, int element) {
    lazyAllocateHostPointer();

    allocator.synchronizeHostData(this);
    allocator.tickHostWrite(this);

    super.put(i, element);
}
/**
 * Stores a long at index {@code i} through the superclass implementation.
 * Beforehand it calls {@code lazyAllocateHostPointer()}, {@code allocator.synchronizeHostData(this)}
 * and {@code allocator.tickHostWrite(this)}, in that order.
 *
 * @param i       element index
 * @param element value to store
 */
@Override
public void put(long i, long element) {
    lazyAllocateHostPointer();

    allocator.synchronizeHostData(this);
    allocator.tickHostWrite(this);

    super.put(i, element);
}
/**
 * Returns the host pointer that {@code AtomicAllocator} reports for this buffer.
 *
 * @return host pointer for this buffer
 * @throws IllegalStateException if the buffer has already been released
 */
@Override
public Pointer addressPointer() {
    if (released) {
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    final AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
    return atomicAllocator.getHostPointer(this);
}
/**
 * Deprecated pointer-based setter. It validates the byte offset derived from {@code index}
 * and then always fails: no data is ever copied.
 *
 * @param index  element index, converted to a byte offset with {@code getElementSize()}
 * @param length unused
 * @param from   unused
 * @param inc    unused
 * @throws IllegalArgumentException      if the byte offset is at or beyond the buffer's byte size
 * @throws UnsupportedOperationException in every other case
 */
@Deprecated
protected void set(long index, long length, Pointer from, long inc) {
    final long offset = getElementSize() * index;
    final long byteSize = length() * getElementSize();

    if (offset >= byteSize) {
        throw new IllegalArgumentException(
                "Illegal offset " + offset + " with index of " + index + " and length " + length());
    }

    // TODO: fix this
    throw new UnsupportedOperationException("Deprecated set() call");
}
/**
 * Deprecated convenience overload: forwards to
 * {@code set(index, length, from, 1)}.
 *
 * @param index  element index
 * @param length forwarded unchanged
 * @param from   forwarded unchanged
 */
@Deprecated
protected void set(long index, long length, Pointer from) {
    final long defaultIncrement = 1;
    set(index, length, from, defaultIncrement);
}
/**
 * Assigns the contents of another buffer to this one by calling
 * {@code allocator.memcpy(this, data)}.
 *
 * @param data source buffer
 */
@Override
public void assign(DataBuffer data) {
    // the copy is handled entirely by the allocator
    allocator.memcpy(this, data);
}
/**
 * Writes {@code data[k]} at position {@code indices[k]} for every {@code k}, one
 * {@code put} call per element.
 *
 * @param indices    target positions
 * @param data       values to write; must have the same length as {@code indices}
 * @param contiguous unused
 * @param inc        unused
 * @throws IllegalArgumentException if the two arrays differ in length, or if there are more
 *                                  indices than elements in this buffer
 */
@Override
public void assign(long[] indices, float[] data, boolean contiguous, long inc) {
    final int count = indices.length;

    if (count != data.length) {
        throw new IllegalArgumentException("Indices and data length must be the same");
    }
    if (count > length()) {
        throw new IllegalArgumentException("More elements than space to assign. This buffer is of length "
                + length() + " where the indices are of length " + data.length);
    }

    // TODO: eventually consider memcpy here
    for (int k = 0; k < count; k++) {
        put(indices[k], data[k]);
    }
}
/**
 * Writes {@code data[k]} at position {@code indices[k]} for every {@code k}, one
 * {@code put} call per element.
 *
 * @param indices    target positions
 * @param data       values to write; must have the same length as {@code indices}
 * @param contiguous unused
 * @param inc        unused
 * @throws IllegalArgumentException if the two arrays differ in length, or if there are more
 *                                  indices than elements in this buffer
 */
@Override
public void assign(long[] indices, double[] data, boolean contiguous, long inc) {
    final int count = indices.length;

    if (count != data.length) {
        throw new IllegalArgumentException("Indices and data length must be the same");
    }
    if (count > length()) {
        throw new IllegalArgumentException("More elements than space to assign. This buffer is of length "
                + length() + " where the indices are of length " + data.length);
    }

    // TODO: eventually consider memcpy here
    for (int k = 0; k < count; k++) {
        put(indices[k], data[k]);
    }
}
/**
 * Deprecated single-element overload: forwards to {@code set(index, 1, from)}.
 *
 * @param index element index
 * @param from  forwarded unchanged
 */
@Deprecated
protected void set(long index, Pointer from) {
    final long singleElement = 1;
    set(index, singleElement, from);
}
/**
 * No-op in this implementation.
 */
@Override
public void flush() {
    // nothing to do
}
/**
 * No-op in this implementation.
 */
@Override
public void destroy() {
    // nothing to do
}
/**
 * Reads a double straight from the superclass {@code getDouble}, skipping the
 * host-allocation and synchronization calls this class's own {@code getDouble} performs.
 *
 * @param index element index
 * @return value returned by the superclass
 */
@Override
protected double getDoubleUnsynced(long index) {
    final double value = super.getDouble(index);
    return value;
}
/**
 * Reads a float straight from the superclass {@code getFloat}, skipping the
 * host-allocation and synchronization calls this class's own {@code getFloat} performs.
 *
 * @param index element index
 * @return value returned by the superclass
 */
@Override
protected float getFloatUnsynced(long index) {
    final float value = super.getFloat(index);
    return value;
}
/**
 * Reads a long straight from the superclass {@code getLong}, skipping the
 * host-allocation and synchronization calls this class's own {@code getLong} performs.
 *
 * @param index element index
 * @return value returned by the superclass
 */
@Override
protected long getLongUnsynced(long index) {
    final long value = super.getLong(index);
    return value;
}
/**
 * Reads an int straight from the superclass {@code getInt}, skipping the
 * host-allocation and synchronization calls this class's own {@code getInt} performs.
 *
 * @param index element index
 * @return value returned by the superclass
 */
@Override
protected int getIntUnsynced(long index) {
    final int value = super.getInt(index);
    return value;
}
/**
 * Writes this buffer to {@code out} using the superclass implementation, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param out destination stream
 * @throws IOException propagated from the superclass write
 */
@Override
public void write(DataOutputStream out) throws IOException {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    super.write(out);
}
/**
 * Writes this buffer to {@code dos} using the superclass implementation, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param dos destination stream
 */
@Override
public void write(OutputStream dos) {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    super.write(dos);
}
/**
 * Java serialization hook. Calls {@code lazyAllocateHostPointer()} and
 * {@code allocator.synchronizeHostData(this)}, writes the default serialized fields, then
 * appends the buffer through {@code write(stream)}.
 *
 * @param stream serialization target
 * @throws IOException propagated from the stream
 */
private void writeObject(java.io.ObjectOutputStream stream) throws IOException {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    // default fields first, buffer payload second
    stream.defaultWriteObject();
    write(stream);
}
/**
 * Java deserialization hook; all work is delegated to {@code doReadObject(stream)}.
 *
 * @param stream serialization source
 * @throws IOException            declared for the serialization contract
 * @throws ClassNotFoundException declared for the serialization contract
 */
private void readObject(java.io.ObjectInputStream stream) throws IOException, ClassNotFoundException {
    doReadObject(stream);
}
/**
 * Builds the string form via the superclass, after calling {@code lazyAllocateHostPointer()}
 * and {@code AtomicAllocator.getInstance().synchronizeHostData(this)}.
 *
 * @return string produced by the superclass
 */
@Override
public String toString() {
    lazyAllocateHostPointer();

    final AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
    atomicAllocator.synchronizeHostData(this);

    return super.toString();
}
/**
 * Tells whether {@code buffer} is backed by the same {@code ptrDataBuffer} address as this
 * buffer.
 *
 * @param buffer buffer to compare against; may be {@code null} or of another implementation
 * @return {@code true} only when {@code buffer} is a {@code BaseCudaDataBuffer} whose
 *         {@code ptrDataBuffer} address equals this one's; {@code false} otherwise
 */
@Override
public boolean sameUnderlyingData(DataBuffer buffer) {
    // A null or non-CUDA buffer has no ptrDataBuffer to compare; answer "no" instead of
    // failing with a NullPointerException / ClassCastException on the cast below.
    if (!(buffer instanceof BaseCudaDataBuffer))
        return false;

    final BaseCudaDataBuffer other = (BaseCudaDataBuffer) buffer;
    return ptrDataBuffer.address() == other.ptrDataBuffer.address();
}
/**
 * PLEASE NOTE: this method implies STRICT (identity) equality only,
 * i.e. the result is {@code this == o}.
 *
 * @param o object to compare with
 * @return {@code true} only when {@code o} is this very instance
 */
@Override
public boolean equals(Object o) {
    // a null argument can never be identical to this, so one comparison covers every case
    return this == o;
}
/**
 * Reads buffer content from {@code is} using the superclass implementation. When no
 * allocation point exists yet, {@code initPointers(length, dataType, false)} is called first;
 * afterwards {@code tickHostWrite()} is invoked on the allocation point.
 *
 * @param is             source stream
 * @param allocationMode forwarded to the superclass
 * @param length         element count, used for {@code initPointers} and forwarded to the superclass
 * @param dataType       element type, used for {@code initPointers} and forwarded to the superclass
 */
@Override
public void read(InputStream is, AllocationMode allocationMode, long length, DataType dataType) {
    final boolean needsPointers = this.allocationPoint == null;
    if (needsPointers) {
        initPointers(length, dataType, false);
    }

    super.read(is, allocationMode, length, dataType);

    this.allocationPoint.tickHostWrite();
}
/**
 * No-op in this implementation: the per-type pointer/indexer creation that used to live here
 * is disabled.
 *
 * @param currentType ignored
 */
@Override
public void pointerIndexerByCurrentType(DataType currentType) {
    // intentionally empty
}
//@Override
/**
 * Restores this buffer from {@code s}.
 * <p>
 * Consumed in order: a UTF allocation-mode name, the element count (read as an int when the
 * saved mode's ordinal is below 3, otherwise as a long), a UTF data-type name, and finally the
 * content handled by {@code readContent}. For {@code COMPRESSED} data only the type is recorded
 * and the method returns before anything is allocated.
 *
 * @param s stream to read from
 * @throws RuntimeException wrapping any exception raised while reading or allocating
 */
public void read(DataInputStream s) {
    try {
        val savedMode = AllocationMode.valueOf(s.readUTF());
        allocationMode = AllocationMode.MIXED_DATA_TYPES;

        // the width of the stored length depends on the saved allocation mode
        long locLength = 0;
        if (savedMode.ordinal() < 3)
            locLength = s.readInt();
        else
            locLength = s.readLong();
        length = locLength;

        val t = DataType.valueOf(s.readUTF());
        // log.info("Restoring buffer ["+t+"] of length ["+ length+"]");
        if (globalType == null && Nd4j.dataType() != null) {
            globalType = Nd4j.dataType();
        }

        if (t == DataType.COMPRESSED) {
            type = t;
            return;
        }

        this.elementSize = (byte) Nd4j.sizeOfDataType(t);
        this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, t), false);
        this.type = t;
        Nd4j.getDeallocatorService().pickObject(this);

        // wrap the freshly allocated host memory in a typed pointer plus matching indexer
        switch (type) {
            case DOUBLE:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asDoublePointer();
                indexer = DoubleIndexer.create((DoublePointer) pointer);
                break;
            case FLOAT:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asFloatPointer();
                indexer = FloatIndexer.create((FloatPointer) pointer);
                break;
            case HALF:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asShortPointer();
                indexer = HalfIndexer.create((ShortPointer) pointer);
                break;
            case LONG:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asLongPointer();
                indexer = LongIndexer.create((LongPointer) pointer);
                break;
            case INT:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asIntPointer();
                indexer = IntIndexer.create((IntPointer) pointer);
                break;
            case SHORT:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asShortPointer();
                indexer = ShortIndexer.create((ShortPointer) pointer);
                break;
            case UBYTE:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBytePointer();
                indexer = UByteIndexer.create((BytePointer) pointer);
                break;
            case BYTE:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBytePointer();
                indexer = ByteIndexer.create((BytePointer) pointer);
                break;
            case BOOL:
                this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBooleanPointer();
                indexer = BooleanIndexer.create((BooleanPointer) pointer);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + type);
        }

        readContent(s, t, t);
        allocationPoint.tickHostWrite();
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // the content was just written on the host side; hand the allocation point to the
    // flow controller's synchronizeToDevice
    AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(allocationPoint);
    //allocator.synchronizeHostData(this);
}
/**
 * Returns the superclass {@code asBytes()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return byte array produced by the superclass
 */
@Override
public byte[] asBytes() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final byte[] bytes = super.asBytes();
    return bytes;
}
/**
 * Returns the superclass {@code asDouble()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return double array produced by the superclass
 */
@Override
public double[] asDouble() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final double[] values = super.asDouble();
    return values;
}
/**
 * Returns the superclass {@code asFloat()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return float array produced by the superclass
 */
@Override
public float[] asFloat() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final float[] values = super.asFloat();
    return values;
}
/**
 * Returns the superclass {@code asInt()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return int array produced by the superclass
 */
@Override
public int[] asInt() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final int[] values = super.asInt();
    return values;
}
/**
 * Returns the superclass {@code asLong()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return long array produced by the superclass
 */
@Override
public long[] asLong() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final long[] values = super.asLong();
    return values;
}
/**
 * Returns the superclass {@code asNio()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return NIO byte buffer produced by the superclass
 */
@Override
public ByteBuffer asNio() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final ByteBuffer nio = super.asNio();
    return nio;
}
/**
 * Returns the superclass {@code asNioDouble()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return NIO double buffer produced by the superclass
 */
@Override
public DoubleBuffer asNioDouble() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final DoubleBuffer nio = super.asNioDouble();
    return nio;
}
/**
 * Returns the superclass {@code asNioFloat()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return NIO float buffer produced by the superclass
 */
@Override
public FloatBuffer asNioFloat() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final FloatBuffer nio = super.asNioFloat();
    return nio;
}
/**
 * Returns the superclass {@code asNioInt()} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @return NIO int buffer produced by the superclass
 */
@Override
public IntBuffer asNioInt() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final IntBuffer nio = super.asNioInt();
    return nio;
}
/**
 * Creates a new buffer of the same length through {@code create(length)} and fills it with a
 * blocking copy of {@code length * elementSize} bytes taken from this buffer's host pointer.
 * The host pointer is lazily allocated and {@code allocator.synchronizeHostData(this)} is
 * called before copying.
 *
 * @return the newly created copy
 */
@Override
public DataBuffer dup() {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final DataBuffer copy = create(this.length);

    // source is a fresh CudaPointer over this buffer's host address
    final CudaPointer hostSource = new CudaPointer(allocator.getHostPointer(this).address());
    final long byteCount = this.length * elementSize;
    allocator.memcpyBlocking(copy, hostSource, byteCount, 0);

    return copy;
}
/**
 * Returns the superclass {@code getNumber(i)} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param i element index
 * @return value returned by the superclass
 */
@Override
public Number getNumber(long i) {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final Number value = super.getNumber(i);
    return value;
}
/**
 * Returns the superclass {@code getDouble(i)} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param i element index
 * @return value returned by the superclass
 */
@Override
public double getDouble(long i) {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final double value = super.getDouble(i);
    return value;
}
/**
 * Returns the superclass {@code getLong(i)} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param i element index
 * @return value returned by the superclass
 */
@Override
public long getLong(long i) {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final long value = super.getLong(i);
    return value;
}
/**
 * Returns the superclass {@code getFloat(i)} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param i element index
 * @return value returned by the superclass
 */
@Override
public float getFloat(long i) {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final float value = super.getFloat(i);
    return value;
}
/**
 * Returns the superclass {@code getInt(ix)} result, after calling
 * {@code lazyAllocateHostPointer()} and {@code allocator.synchronizeHostData(this)}.
 *
 * @param ix element index
 * @return value returned by the superclass
 */
@Override
public int getInt(long ix) {
    lazyAllocateHostPointer();
    allocator.synchronizeHostData(this);

    final int value = super.getInt(ix);
    return value;
}
/**
 * Re-creates {@code pointer} and its indexer from {@code ptrDataBuffer.primaryBuffer()},
 * choosing the typed pointer and indexer according to {@code dataType()}.
 * Nothing happens when the primary buffer and the current pointer are both non-null and
 * already share the same address.
 *
 * @throws IllegalArgumentException if the data type is not one of the handled types
 */
public void actualizePointerAndIndexer() {
    val hostBuffer = ptrDataBuffer.primaryBuffer();

    // skip update if pointers are equal
    final boolean upToDate = hostBuffer != null && pointer != null && hostBuffer.address() == pointer.address();
    if (upToDate)
        return;

    val t = dataType();
    // a switch cannot take null, so route it to the same failure as any other unknown type
    if (t == null)
        throw new IllegalArgumentException("Unknown datatype: " + dataType());

    switch (t) {
        case BOOL:
            pointer = new PagedPointer(hostBuffer, length).asBoolPointer();
            setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
            break;
        case UBYTE:
            pointer = new PagedPointer(hostBuffer, length).asBytePointer();
            setIndexer(UByteIndexer.create((BytePointer) pointer));
            break;
        case BYTE:
            pointer = new PagedPointer(hostBuffer, length).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
            break;
        case UINT16:
            pointer = new PagedPointer(hostBuffer, length).asShortPointer();
            setIndexer(UShortIndexer.create((ShortPointer) pointer));
            break;
        case SHORT:
            pointer = new PagedPointer(hostBuffer, length).asShortPointer();
            setIndexer(ShortIndexer.create((ShortPointer) pointer));
            break;
        case UINT32:
            pointer = new PagedPointer(hostBuffer, length).asIntPointer();
            setIndexer(UIntIndexer.create((IntPointer) pointer));
            break;
        case INT:
            pointer = new PagedPointer(hostBuffer, length).asIntPointer();
            setIndexer(IntIndexer.create((IntPointer) pointer));
            break;
        case UINT64:
        case LONG:
            pointer = new PagedPointer(hostBuffer, length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
            break;
        case BFLOAT16:
            pointer = new PagedPointer(hostBuffer, length).asShortPointer();
            setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
            break;
        case HALF:
            pointer = new PagedPointer(hostBuffer, length).asShortPointer();
            setIndexer(HalfIndexer.create((ShortPointer) pointer));
            break;
        case FLOAT:
            pointer = new PagedPointer(hostBuffer, length).asFloatPointer();
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
            break;
        case DOUBLE:
            pointer = new PagedPointer(hostBuffer, length).asDoublePointer();
            setIndexer(DoubleIndexer.create((DoublePointer) pointer));
            break;
        case UTF8:
            pointer = new PagedPointer(hostBuffer, length()).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
            break;
        default:
            throw new IllegalArgumentException("Unknown datatype: " + dataType());
    }
}
/**
 * Resizes this buffer to {@code length} elements and rebinds the host {@code pointer} and
 * {@code indexer} to the new memory.
 * <p>
 * For an attached buffer, new memory is requested from the parent workspace for each of the
 * device and host buffers that currently exists (non-null, non-zero address), and the old
 * content is copied over. Otherwise {@code ptrDataBuffer.expand(length)} is used.
 * In both cases {@code underlyingLength} and {@code length} are updated afterwards.
 *
 * @param length new number of elements
 * @return this buffer
 */
@Override
public DataBuffer reallocate(long length) {
    val oldHostPointer = this.ptrDataBuffer.primaryBuffer();
    val oldDevicePointer = this.ptrDataBuffer.specialBuffer();

    if (isAttached()) {
        val capacity = length * getElementSize();

        // Copy only as many elements as BOTH allocations hold: using the new length would
        // read past the old buffer when growing, using the old length would write past the
        // new buffer when shrinking.
        final long bytesToCopy = Math.min(this.length(), length) * getElementSize();

        if (oldDevicePointer != null && oldDevicePointer.address() != 0) {
            val nPtr = getParentWorkspace().alloc(capacity, MemoryKind.DEVICE, dataType(), false);
            // flag 3 presumably is cudaMemcpyDeviceToDevice - confirm against CudaConstants
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(nPtr, oldDevicePointer, bytesToCopy, 3, null);
            // NOTE(review): a DEVICE allocation is registered through setPrimaryBuffer here, the
            // same call the host branch uses below; verify whether the special (device) buffer
            // should be updated instead.
            this.ptrDataBuffer.setPrimaryBuffer(nPtr, length);
            allocationPoint.tickDeviceRead();
        }

        if (oldHostPointer != null && oldHostPointer.address() != 0) {
            val nPtr = getParentWorkspace().alloc(capacity, MemoryKind.HOST, dataType(), false);
            Pointer.memcpy(nPtr, oldHostPointer, bytesToCopy);
            this.ptrDataBuffer.setPrimaryBuffer(nPtr, length);
            allocationPoint.tickHostRead();
            bindReallocatedHostPointer(nPtr);
        }

        workspaceGenerationId = getParentWorkspace().getGenerationId();
    } else {
        this.ptrDataBuffer.expand(length);
        bindReallocatedHostPointer(new PagedPointer(this.ptrDataBuffer.primaryBuffer(), length));
    }

    this.underlyingLength = length;
    this.length = length;
    return this;
}

/**
 * Points {@code pointer} and {@code indexer} at {@code hostPtr}, typed according to
 * {@code dataType()}. Unsigned byte/short types get the unsigned indexers, matching
 * {@code actualizePointerAndIndexer()}. Data types not listed leave both fields untouched.
 */
private void bindReallocatedHostPointer(PagedPointer hostPtr) {
    switch (dataType()) {
        case BOOL:
            pointer = hostPtr.asBoolPointer();
            indexer = BooleanIndexer.create((BooleanPointer) pointer);
            break;
        case UBYTE:
            pointer = hostPtr.asBytePointer();
            indexer = UByteIndexer.create((BytePointer) pointer);
            break;
        case UTF8:
        case BYTE:
            pointer = hostPtr.asBytePointer();
            indexer = ByteIndexer.create((BytePointer) pointer);
            break;
        case UINT16:
            pointer = hostPtr.asShortPointer();
            indexer = UShortIndexer.create((ShortPointer) pointer);
            break;
        case SHORT:
            pointer = hostPtr.asShortPointer();
            indexer = ShortIndexer.create((ShortPointer) pointer);
            break;
        case UINT32:
            pointer = hostPtr.asIntPointer();
            indexer = UIntIndexer.create((IntPointer) pointer);
            break;
        case INT:
            pointer = hostPtr.asIntPointer();
            indexer = IntIndexer.create((IntPointer) pointer);
            break;
        case DOUBLE:
            pointer = hostPtr.asDoublePointer();
            indexer = DoubleIndexer.create((DoublePointer) pointer);
            break;
        case FLOAT:
            pointer = hostPtr.asFloatPointer();
            indexer = FloatIndexer.create((FloatPointer) pointer);
            break;
        case HALF:
            pointer = hostPtr.asShortPointer();
            indexer = HalfIndexer.create((ShortPointer) pointer);
            break;
        case BFLOAT16:
            pointer = hostPtr.asShortPointer();
            indexer = Bfloat16Indexer.create((ShortPointer) pointer);
            break;
        case UINT64:
        case LONG:
            pointer = hostPtr.asLongPointer();
            indexer = LongIndexer.create((LongPointer) pointer);
            break;
        default:
            // unchanged from the previous behaviour: other types are silently skipped
            break;
    }
}
/**
 * @return {@code pointer.capacity()} when the allocation point has a host pointer,
 *         otherwise the {@code length} field
 */
@Override
public long capacity() {
    final boolean hasHostPointer = allocationPoint.getHostPointer() != null;
    return hasHostPointer ? pointer.capacity() : length;
}
/**
 * Releases this buffer. When it is not yet flagged as released, {@code ptrDataBuffer} is
 * closed and the allocation point is marked released; the superclass {@code release()} is
 * called in every case.
 */
@Override
protected void release() {
    final boolean firstRelease = !released;
    if (firstRelease) {
        ptrDataBuffer.closeBuffer();
        allocationPoint.setReleased(true);
    }

    super.release();
}
/*
protected short fromFloat( float fval ) {
int fbits = Float.floatToIntBits( fval );
int sign = fbits >>> 16 & 0x8000; // sign only
int val = ( fbits & 0x7fffffff ) + 0x1000; // rounded value
if( val >= 0x47800000 ) // might be or become NaN/Inf
{ // avoid Inf due to rounding
if( ( fbits & 0x7fffffff ) >= 0x47800000 )
{ // is or must become NaN/Inf
if( val < 0x7f800000 ) // was value but too large
return (short) (sign | 0x7c00); // make it +/-Inf
return (short) (sign | 0x7c00 | // remains +/-Inf or NaN
( fbits & 0x007fffff ) >>> 13); // keep NaN (and Inf) bits
}
return (short) (sign | 0x7bff); // unrounded not quite Inf
}
if( val >= 0x38800000 ) // remains normalized value
return (short) (sign | val - 0x38000000 >>> 13); // exp - 127 + 15
if( val < 0x33000000 ) // too small for subnormal
return (short) sign; // becomes +/-0
val = ( fbits & 0x7fffffff ) >>> 23; // tmp exp for subnormal calc
return (short) (sign | ( ( fbits & 0x7fffff | 0x800000 ) // add subnormal bit
+ ( 0x800000 >>> val - 102 ) // round depending on cut off
>>> 126 - val )); // div by 2^(1-(exp-127+15)) and >> 13 | exp=0
}
*/
/**
 * @return the prefix {@code "BCDB_"} followed by the allocation point's object id
 */
@Override
public String getUniqueId() {
    final String prefix = "BCDB_";
    return prefix + allocationPoint.getObjectId();
}
/**
 * Provides the deallocator associated with this instance.
 *
 * @return a new {@code CudaDeallocator} built for this buffer on every call
 */
@Override
public Deallocator deallocator() {
    final Deallocator cudaDeallocator = new CudaDeallocator(this);
    return cudaDeallocator;
}
/**
 * @return the device id of the allocation point that {@code AtomicAllocator} holds for this buffer
 */
@Override
public int targetDevice() {
    val point = AtomicAllocator.getInstance().getAllocationPoint(this);
    return point.getDeviceId();
}
/**
 * Delegates to {@code ptrDataBuffer.syncToPrimary()}.
 */
@Override
public void syncToPrimary() {
    this.ptrDataBuffer.syncToPrimary();
}
/**
 * Delegates to {@code ptrDataBuffer.syncToSpecial()}.
 */
@Override
public void syncToSpecial() {
    this.ptrDataBuffer.syncToSpecial();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy