org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.cpu.nativecpu.buffer;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.AllocUtil;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import java.nio.ByteBuffer;
import static org.nd4j.linalg.api.buffer.DataType.INT8;
public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallocatable {

    /**
     * Native buffer handle backing this Java object. It stays {@code null} when the no-arg
     * constructor is used, until {@link #pointerIndexerByCurrentType(DataType)} allocates it.
     */
    protected transient OpaqueDataBuffer ptrDataBuffer;

    /** Per-instance value drawn from the deallocator service; used by {@link #getUniqueId()}. */
    private transient final long instanceId = Nd4j.getDeallocatorService().nextValue();

    /**
     * Creates an empty instance: no native memory is allocated and {@link #ptrDataBuffer} stays
     * {@code null}.
     */
    protected BaseCpuDataBuffer() {
    }

    /**
     * Returns an identifier of the form {@code "BCDB_<instanceId>"}, built from this instance's id.
     */
    @Override
    public String getUniqueId() {
        return "BCDB_" + instanceId;
    }

    /** Returns a new {@link CpuDeallocator} bound to this buffer. */
    @Override
    public Deallocator deallocator() {
        return new CpuDeallocator(this);
    }

    /**
     * Returns the native buffer handle backing this buffer.
     *
     * @return the native {@link OpaqueDataBuffer}
     * @throws IllegalStateException if this buffer has already been released
     */
    public OpaqueDataBuffer getOpaqueDataBuffer() {
        if (released)
            throw new IllegalStateException("You can't use DataBuffer once it was released");
        return ptrDataBuffer;
    }

    /** CPU buffers always report device 0. */
    @Override
    public int targetDevice() {
        // TODO: once we add NUMA support this might change. Or might not.
        return 0;
    }

    /**
     * Allocates a native buffer of {@code length} elements of this buffer's data type and
     * builds the matching typed pointer and indexer over it.
     *
     * <p>This constructor does not call {@code fillPointerWithZero()}.
     *
     * @param length      number of elements; must be {@code >= 1}
     * @param elementSize size of a single element in bytes
     * @throws IllegalArgumentException if {@code length < 1}
     */
    public BaseCpuDataBuffer(long length, int elementSize) {
        if (length < 1)
            throw new IllegalArgumentException("Length must be >= 1");
        initTypeAndSize();
        allocationMode = AllocUtil.getAllocationModeFromContext();
        this.length = length;
        this.underlyingLength = length;
        this.elementSize = (byte) elementSize;

        // UTF8 is allocated separately below, as raw INT8 storage
        if (dataType() != DataType.UTF8)
            ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false);

        if (dataType() == DataType.DOUBLE) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer();
            indexer = DoubleIndexer.create((DoublePointer) pointer);
        } else if (dataType() == DataType.FLOAT) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer();
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
        } else if (dataType() == DataType.INT32) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
            setIndexer(IntIndexer.create((IntPointer) pointer));
        } else if (dataType() == DataType.LONG) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
        } else if (dataType() == DataType.SHORT) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(ShortIndexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.BYTE) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
        } else if (dataType() == DataType.UBYTE) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
            setIndexer(UByteIndexer.create((BytePointer) pointer));
        } else if (dataType() == DataType.UTF8) {
            ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, INT8, false);
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
        } else if (dataType() == DataType.FLOAT16) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(HalfIndexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.BFLOAT16) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.BOOL) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
            setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
        } else if (dataType() == DataType.UINT16) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(UShortIndexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.UINT32) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
            setIndexer(UIntIndexer.create((IntPointer) pointer));
        } else if (dataType() == DataType.UINT64) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
        }

        Nd4j.getDeallocatorService().pickObject(this);
    }

    /**
     * Allocates {@code length} elements, then marks the first {@code offset} elements as skipped.
     * The visible length becomes {@code length - offset}.
     *
     * @param length      total number of elements to allocate
     * @param elementSize size of a single element in bytes
     * @param offset      element offset at which the visible data starts
     */
    public BaseCpuDataBuffer(int length, int elementSize, long offset) {
        this(length, elementSize);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = length - offset;
        this.underlyingLength = length;
    }

    /**
     * Creates a view of {@code length} elements over {@code underlyingBuffer}, starting at element
     * {@code offset}.
     *
     * <p>The native side gets a view created via {@code createView} with byte-based length and
     * offset. This object is then registered with the deallocator service.
     *
     * @param underlyingBuffer buffer being viewed; must be a {@code BaseCpuDataBuffer}
     * @param length           number of elements in the view
     * @param offset           element offset into {@code underlyingBuffer}
     */
    protected BaseCpuDataBuffer(DataBuffer underlyingBuffer, long length, long offset) {
        super(underlyingBuffer, length, offset);

        // for view we need "externally managed" pointer and deallocator registration
        ptrDataBuffer = ((BaseCpuDataBuffer) underlyingBuffer).ptrDataBuffer.createView(
                length * underlyingBuffer.getElementSize(), offset * underlyingBuffer.getElementSize());
        Nd4j.getDeallocatorService().pickObject(this);

        // update pointer now
        actualizePointerAndIndexer();
    }

    /**
     * Allocates a native buffer of {@code length} elements and copies data from {@code buffer}.
     *
     * <p>The copy reads {@code length * sizeOf(dtype)} bytes from the source. The read starts
     * {@code offset} elements into the source when {@code offset > 0}.
     *
     * @param buffer source data
     * @param dtype  data type used to size the allocation and the copy
     * @param length number of elements to allocate and copy
     * @param offset element offset into {@code buffer} at which copying starts
     * @throws IllegalArgumentException if this buffer's data type cannot be read from a ByteBuffer
     */
    protected BaseCpuDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) {
        this(length, Nd4j.sizeOfDataType(dtype));

        // Wrap the caller's ByteBuffer as a pointer of the matching element width. Every type
        // must wrap `buffer`; allocating a fresh pointer here would copy uninitialized memory.
        Pointer temp;
        switch (dataType()) {
            case DOUBLE:
                temp = new DoublePointer(buffer.asDoubleBuffer());
                break;
            case FLOAT:
                temp = new FloatPointer(buffer.asFloatBuffer());
                break;
            case HALF:
            case BFLOAT16:
            case UINT16:
            case SHORT:
                temp = new ShortPointer(buffer.asShortBuffer());
                break;
            case UINT64:
            case LONG:
                temp = new LongPointer(buffer.asLongBuffer());
                break;
            case UINT32:
            case INT:
                temp = new IntPointer(buffer.asIntBuffer());
                break;
            case UBYTE:
            case BYTE:
            case BOOL:
            case UTF8:
                temp = new BytePointer(buffer);
                break;
            default:
                throw new IllegalArgumentException("Unsupported data type for ByteBuffer source: " + dataType());
        }

        val ptr = ptrDataBuffer.primaryBuffer();

        if (offset > 0)
            temp = new PagedPointer(temp.address() + offset * getElementSize());

        Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype));
    }

    /** Delegates to {@code super.getDouble(index)}. */
    @Override
    protected double getDoubleUnsynced(long index) {
        return super.getDouble(index);
    }

    /** Delegates to {@code super.getFloat(index)}. */
    @Override
    protected float getFloatUnsynced(long index) {
        return super.getFloat(index);
    }

    /** Delegates to {@code super.getLong(index)}. */
    @Override
    protected long getLongUnsynced(long index) {
        return super.getLong(index);
    }

    /** Delegates to {@code super.getInt(index)}. */
    @Override
    protected int getIntUnsynced(long index) {
        return super.getInt(index);
    }

    /**
     * Sets this buffer's type to {@code currentType} and refreshes the pointer and indexer.
     *
     * <p>If no native buffer exists yet, one of {@code length()} elements is allocated first and
     * this object is registered with the deallocator service.
     *
     * @param currentType the data type to adopt
     */
    @Override
    public void pointerIndexerByCurrentType(DataType currentType) {
        type = currentType;

        if (ptrDataBuffer == null) {
            ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length(), type, false);
            Nd4j.getDeallocatorService().pickObject(this);
        }

        actualizePointerAndIndexer();
    }

    /**
     * Instantiate a zero-filled buffer with the given length.
     *
     * @param length the length of the buffer
     */
    protected BaseCpuDataBuffer(long length) {
        this(length, true);
    }

    /**
     * Allocates a native buffer of {@code length} elements of this buffer's data type.
     *
     * @param length     number of elements; must be {@code >= 0}
     * @param initialize when {@code true}, the memory is zero-filled via {@code fillPointerWithZero()}
     * @throws IllegalArgumentException if {@code length < 0}
     */
    protected BaseCpuDataBuffer(long length, boolean initialize) {
        if (length < 0)
            throw new IllegalArgumentException("Length must be >= 0");
        initTypeAndSize();
        this.length = length;
        this.underlyingLength = length;
        allocationMode = AllocUtil.getAllocationModeFromContext();

        // UTF8 is allocated separately below, as raw INT8 storage
        if (dataType() != DataType.UTF8)
            ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false);

        if (dataType() == DataType.DOUBLE) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer();
            indexer = DoubleIndexer.create((DoublePointer) pointer);
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.FLOAT) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer();
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.HALF) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(HalfIndexer.create((ShortPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.BFLOAT16) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.INT) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
            setIndexer(IntIndexer.create((IntPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.LONG) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.BYTE) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.SHORT) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(ShortIndexer.create((ShortPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.UBYTE) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
            setIndexer(UByteIndexer.create((BytePointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.UINT16) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
            setIndexer(UShortIndexer.create((ShortPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.UINT32) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
            setIndexer(UIntIndexer.create((IntPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.UINT64) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.BOOL) {
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
            setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
            if (initialize)
                fillPointerWithZero();
        } else if (dataType() == DataType.UTF8) {
            // we are allocating buffer as INT8 intentionally
            ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length(), INT8, false);
            pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length()).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
            if (initialize)
                fillPointerWithZero();
        }

        Nd4j.getDeallocatorService().pickObject(this);
    }

    /**
     * Rebuilds {@code pointer} and the indexer from the native primary buffer, according to
     * {@link #dataType()}.
     *
     * <p>The method returns early when the cached pointer already has the same address as the
     * native primary buffer. That check compares addresses only, so a change of
     * {@code dataType()} alone does not rebuild the indexer.
     *
     * @throws IllegalArgumentException if the data type is not supported
     */
    public void actualizePointerAndIndexer() {
        val cptr = ptrDataBuffer.primaryBuffer();

        // skip update if pointers are equal
        if (cptr != null && pointer != null && cptr.address() == pointer.address())
            return;

        val t = dataType();
        if (t == DataType.BOOL) {
            pointer = new PagedPointer(cptr, length).asBoolPointer();
            setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
        } else if (t == DataType.UBYTE) {
            pointer = new PagedPointer(cptr, length).asBytePointer();
            setIndexer(UByteIndexer.create((BytePointer) pointer));
        } else if (t == DataType.BYTE) {
            pointer = new PagedPointer(cptr, length).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
        } else if (t == DataType.UINT16) {
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(UShortIndexer.create((ShortPointer) pointer));
        } else if (t == DataType.SHORT) {
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(ShortIndexer.create((ShortPointer) pointer));
        } else if (t == DataType.UINT32) {
            pointer = new PagedPointer(cptr, length).asIntPointer();
            setIndexer(UIntIndexer.create((IntPointer) pointer));
        } else if (t == DataType.INT) {
            pointer = new PagedPointer(cptr, length).asIntPointer();
            setIndexer(IntIndexer.create((IntPointer) pointer));
        } else if (t == DataType.UINT64) {
            pointer = new PagedPointer(cptr, length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
        } else if (t == DataType.LONG) {
            pointer = new PagedPointer(cptr, length).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
        } else if (t == DataType.BFLOAT16) {
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
        } else if (t == DataType.HALF) {
            pointer = new PagedPointer(cptr, length).asShortPointer();
            setIndexer(HalfIndexer.create((ShortPointer) pointer));
        } else if (t == DataType.FLOAT) {
            pointer = new PagedPointer(cptr, length).asFloatPointer();
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
        } else if (t == DataType.DOUBLE) {
            pointer = new PagedPointer(cptr, length).asDoublePointer();
            setIndexer(DoubleIndexer.create((DoublePointer) pointer));
        } else if (t == DataType.UTF8) {
            pointer = new PagedPointer(cptr, length()).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
        } else
            throw new IllegalArgumentException("Unknown datatype: " + dataType());
    }

    /**
     * Returns a new typed pointer to the native primary buffer.
     *
     * <p>The address is fetched from the native side on every call, not taken from the cached
     * {@code pointer} field. Types without a dedicated case fall back to a {@code BytePointer}.
     */
    @Override
    public Pointer addressPointer() {
        // we're fetching actual pointer right from C++
        val tempPtr = new PagedPointer(ptrDataBuffer.primaryBuffer());

        switch (this.type) {
            case DOUBLE: return tempPtr.asDoublePointer();
            case FLOAT: return tempPtr.asFloatPointer();
            case UINT16:
            case SHORT:
            case BFLOAT16:
            case HALF: return tempPtr.asShortPointer();
            case UINT32:
            case INT: return tempPtr.asIntPointer();
            case UBYTE:
            case BYTE: return tempPtr.asBytePointer();
            case UINT64:
            case LONG: return tempPtr.asLongPointer();
            case BOOL: return tempPtr.asBoolPointer();
            default: return tempPtr.asBytePointer();
        }
    }

    /**
     * Allocates {@code length} elements from {@code workspace} and wraps them in an externalized
     * native buffer.
     *
     * @param length     number of elements; must be {@code >= 1}
     * @param initialize passed through to {@code workspace.alloc(...)}
     * @param workspace  workspace that provides the memory; recorded as the parent workspace
     * @throws IllegalArgumentException if {@code length < 1}
     */
    protected BaseCpuDataBuffer(long length, boolean initialize, MemoryWorkspace workspace) {
        if (length < 1)
            throw new IllegalArgumentException("Length must be >= 1");
        initTypeAndSize();
        this.length = length;
        this.underlyingLength = length;
        allocationMode = AllocUtil.getAllocationModeFromContext();

        if (dataType() == DataType.DOUBLE) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asDoublePointer();
            indexer = DoubleIndexer.create((DoublePointer) pointer);
        } else if (dataType() == DataType.FLOAT) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asFloatPointer();
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
        } else if (dataType() == DataType.HALF) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer();
            setIndexer(HalfIndexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.BFLOAT16) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer();
            setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.INT) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer();
            setIndexer(IntIndexer.create((IntPointer) pointer));
        } else if (dataType() == DataType.UINT32) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer();
            setIndexer(UIntIndexer.create((IntPointer) pointer));
        } else if (dataType() == DataType.UINT64) {
            attached = true;
            parentWorkspace = workspace;
            // FIXME: need unsigned indexer here
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
        } else if (dataType() == DataType.LONG) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer();
            setIndexer(LongIndexer.create((LongPointer) pointer));
        } else if (dataType() == DataType.BYTE) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
        } else if (dataType() == DataType.UBYTE) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer();
            setIndexer(UByteIndexer.create((BytePointer) pointer));
        } else if (dataType() == DataType.UINT16) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer();
            setIndexer(UShortIndexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.SHORT) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer();
            setIndexer(ShortIndexer.create((ShortPointer) pointer));
        } else if (dataType() == DataType.BOOL) {
            attached = true;
            parentWorkspace = workspace;
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBoolPointer();
            setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
        } else if (dataType() == DataType.UTF8) {
            attached = true;
            parentWorkspace = workspace;
            // UTF8 is byte storage everywhere else in this class (BytePointer + ByteIndexer);
            // viewing it through a LongIndexer would misread the length*elementSize-byte allocation
            pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer();
            setIndexer(ByteIndexer.create((BytePointer) pointer));
        }

        // storing pointer into native DataBuffer
        ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);

        // adding deallocator reference
        Nd4j.getDeallocatorService().pickObject(this);

        workspaceGenerationId = workspace.getGenerationId();
    }

    /**
     * Wraps an existing pointer/indexer pair. The pointer is registered with the native side as
     * an externalized buffer.
     *
     * @param pointer memory to wrap
     * @param indexer indexer over {@code pointer}
     * @param length  number of elements
     */
    public BaseCpuDataBuffer(Pointer pointer, Indexer indexer, long length) {
        super(pointer, indexer, length);

        ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
        Nd4j.getDeallocatorService().pickObject(this);
    }

    /**
     * Creates a buffer from {@code data} whose visible part starts at element {@code offset}.
     *
     * @param data   source values
     * @param copy   passed to {@link #BaseCpuDataBuffer(float[], boolean)}
     * @param offset element offset at which the visible data starts
     */
    public BaseCpuDataBuffer(float[] data, boolean copy, long offset) {
        this(data, copy);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = data.length - offset;
        this.underlyingLength = data.length;
    }

    /**
     * Workspace-backed variant of {@link #BaseCpuDataBuffer(float[], boolean, long)}.
     *
     * @param data      source values
     * @param copy      passed to {@link #BaseCpuDataBuffer(float[], boolean, MemoryWorkspace)}
     * @param offset    element offset at which the visible data starts
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) {
        this(data, copy, workspace);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = data.length - offset;
        this.underlyingLength = data.length;
    }

    /**
     * Copies {@code data} into a new {@code FloatPointer}, which becomes the primary buffer of a
     * native FLOAT buffer.
     *
     * @param data source values
     * @param copy not consulted: the data is always copied
     */
    public BaseCpuDataBuffer(float[] data, boolean copy) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        initTypeAndSize();

        pointer = new FloatPointer(data);

        // creating & registering native DataBuffer
        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.FLOAT, false);
        ptrDataBuffer.setPrimaryBuffer(pointer, data.length);
        Nd4j.getDeallocatorService().pickObject(this);

        setIndexer(FloatIndexer.create((FloatPointer) pointer));

        length = data.length;
        underlyingLength = data.length;
    }

    /**
     * Copies {@code data} into memory allocated from {@code workspace}.
     *
     * @param data      source values
     * @param copy      not consulted: the data is always copied
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        length = data.length;
        underlyingLength = data.length;
        attached = true;
        parentWorkspace = workspace;

        initTypeAndSize();

        pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data);

        ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
        Nd4j.getDeallocatorService().pickObject(this);

        workspaceGenerationId = workspace.getGenerationId();
        setIndexer(FloatIndexer.create((FloatPointer) pointer));
    }

    /**
     * Copies {@code data} into memory allocated from {@code workspace}.
     *
     * @param data      source values
     * @param copy      not consulted: the data is always copied
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(double[] data, boolean copy, MemoryWorkspace workspace) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        length = data.length;
        underlyingLength = data.length;
        attached = true;
        parentWorkspace = workspace;

        initTypeAndSize();

        pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data);

        ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
        Nd4j.getDeallocatorService().pickObject(this);

        workspaceGenerationId = workspace.getGenerationId();
        indexer = DoubleIndexer.create((DoublePointer) pointer);
    }

    /**
     * Copies {@code data} into memory allocated from {@code workspace}.
     *
     * @param data      source values
     * @param copy      not consulted: the data is always copied
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        length = data.length;
        underlyingLength = data.length;
        attached = true;
        parentWorkspace = workspace;

        initTypeAndSize();

        pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data);

        ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
        Nd4j.getDeallocatorService().pickObject(this);

        workspaceGenerationId = workspace.getGenerationId();
        indexer = IntIndexer.create((IntPointer) pointer);
    }

    /**
     * Copies {@code data} into memory allocated from {@code workspace}.
     *
     * @param data      source values
     * @param copy      not consulted: the data is always copied
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(long[] data, boolean copy, MemoryWorkspace workspace) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        length = data.length;
        underlyingLength = data.length;
        attached = true;
        parentWorkspace = workspace;

        initTypeAndSize();

        pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data);

        ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
        Nd4j.getDeallocatorService().pickObject(this);

        workspaceGenerationId = workspace.getGenerationId();
        indexer = LongIndexer.create((LongPointer) pointer);
    }

    /**
     * Creates a buffer from {@code data} whose visible part starts at element {@code offset}.
     *
     * @param data   source values
     * @param copy   passed to {@link #BaseCpuDataBuffer(double[], boolean)}
     * @param offset element offset at which the visible data starts
     */
    public BaseCpuDataBuffer(double[] data, boolean copy, long offset) {
        this(data, copy);
        this.offset = offset;
        this.originalOffset = offset;
        this.underlyingLength = data.length;
        this.length = underlyingLength - offset;
    }

    /**
     * Workspace-backed variant of {@link #BaseCpuDataBuffer(double[], boolean, long)}.
     *
     * @param data      source values
     * @param copy      passed to {@link #BaseCpuDataBuffer(double[], boolean, MemoryWorkspace)}
     * @param offset    element offset at which the visible data starts
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) {
        this(data, copy, workspace);
        this.offset = offset;
        this.originalOffset = offset;
        this.underlyingLength = data.length;
        this.length = underlyingLength - offset;
    }

    /**
     * Copies {@code data} into a new {@code DoublePointer}, which becomes the primary buffer of a
     * native DOUBLE buffer.
     *
     * @param data source values
     * @param copy not consulted: the data is always copied
     */
    public BaseCpuDataBuffer(double[] data, boolean copy) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        initTypeAndSize();

        pointer = new DoublePointer(data);
        indexer = DoubleIndexer.create((DoublePointer) pointer);

        // creating & registering native DataBuffer
        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.DOUBLE, false);
        ptrDataBuffer.setPrimaryBuffer(pointer, data.length);
        Nd4j.getDeallocatorService().pickObject(this);

        length = data.length;
        underlyingLength = data.length;
    }

    /**
     * Creates a buffer from {@code data} whose visible part starts at element {@code offset}.
     *
     * @param data   source values
     * @param copy   passed to {@link #BaseCpuDataBuffer(int[], boolean)}
     * @param offset element offset at which the visible data starts
     */
    public BaseCpuDataBuffer(int[] data, boolean copy, long offset) {
        this(data, copy);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = data.length - offset;
        this.underlyingLength = data.length;
    }

    /**
     * Copies {@code data} into a new {@code IntPointer}, which becomes the primary buffer of a
     * native INT32 buffer.
     *
     * @param data source values
     * @param copy not consulted: the data is always copied
     */
    public BaseCpuDataBuffer(int[] data, boolean copy) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        initTypeAndSize();

        pointer = new IntPointer(data);
        setIndexer(IntIndexer.create((IntPointer) pointer));

        // creating & registering native DataBuffer
        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.INT32, false);
        ptrDataBuffer.setPrimaryBuffer(pointer, data.length);
        Nd4j.getDeallocatorService().pickObject(this);

        length = data.length;
        underlyingLength = data.length;
    }

    /**
     * Copies {@code data} into a new {@code LongPointer}, which becomes the primary buffer of a
     * native INT64 buffer.
     *
     * @param data source values
     * @param copy not consulted: the data is always copied
     */
    public BaseCpuDataBuffer(long[] data, boolean copy) {
        allocationMode = AllocUtil.getAllocationModeFromContext();
        initTypeAndSize();

        pointer = new LongPointer(data);
        setIndexer(LongIndexer.create((LongPointer) pointer));

        // creating & registering native DataBuffer
        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.INT64, false);
        ptrDataBuffer.setPrimaryBuffer(pointer, data.length);
        Nd4j.getDeallocatorService().pickObject(this);

        length = data.length;
        underlyingLength = data.length;
    }

    /**
     * Equivalent to {@code this(data, true)}.
     *
     * @param data source values
     */
    public BaseCpuDataBuffer(double[] data) {
        this(data, true);
    }

    /**
     * Equivalent to {@code this(data, true)}.
     *
     * @param data source values
     */
    public BaseCpuDataBuffer(int[] data) {
        this(data, true);
    }

    /**
     * Equivalent to {@code this(data, true)}.
     *
     * @param data source values
     */
    public BaseCpuDataBuffer(float[] data) {
        this(data, true);
    }

    /**
     * Equivalent to {@code this(data, true, workspace)}.
     *
     * @param data      source values
     * @param workspace workspace that provides the memory
     */
    public BaseCpuDataBuffer(float[] data, MemoryWorkspace workspace) {
        this(data, true, workspace);
    }

    /**
     * Closes the native buffer, if one was ever created, and then releases the Java-side state
     * via {@code super.release()}.
     */
    @Override
    protected void release() {
        // the no-arg constructor leaves ptrDataBuffer null; don't NPE while releasing such a buffer
        if (ptrDataBuffer != null)
            ptrDataBuffer.closeBuffer();

        super.release();
    }

    /**
     * Reallocate the native memory of the buffer.
     *
     * <p>For workspace-attached buffers, a new block is allocated from the parent workspace and
     * set as the native primary buffer. The old contents are then copied, up to the smaller of
     * the old and new lengths. For other buffers, {@code OpaqueDataBuffer.expand(length)} is
     * called and the pointer and indexer are rebuilt over the resulting primary buffer.
     *
     * @param length the new length of the buffer, in elements
     * @return this databuffer
     */
    @Override
    public DataBuffer reallocate(long length) {
        val oldPointer = ptrDataBuffer.primaryBuffer();

        if (isAttached()) {
            val capacity = length * getElementSize();
            val nPtr = getParentWorkspace().alloc(capacity, dataType(), false);
            this.ptrDataBuffer.setPrimaryBuffer(nPtr, length);

            rebuildPointerAndIndexer(nPtr);

            // never copy more than the new allocation can hold (shrinking) or the old one had (growing)
            Pointer.memcpy(pointer, oldPointer, Math.min(this.length(), length) * getElementSize());

            workspaceGenerationId = getParentWorkspace().getGenerationId();
        } else {
            this.ptrDataBuffer.expand(length);
            rebuildPointerAndIndexer(new PagedPointer(this.ptrDataBuffer.primaryBuffer(), length));
        }

        this.underlyingLength = length;
        this.length = length;
        return this;
    }

    /**
     * Points {@code pointer} and {@code indexer} at {@code nPtr}, using the indexer type that the
     * constructors use for the current data type. Types without a case leave both fields unchanged.
     */
    private void rebuildPointerAndIndexer(PagedPointer nPtr) {
        switch (dataType()) {
            case BOOL:
                pointer = nPtr.asBoolPointer();
                indexer = BooleanIndexer.create((BooleanPointer) pointer);
                break;
            case UTF8:
            case BYTE:
                pointer = nPtr.asBytePointer();
                indexer = ByteIndexer.create((BytePointer) pointer);
                break;
            case UBYTE:
                pointer = nPtr.asBytePointer();
                indexer = UByteIndexer.create((BytePointer) pointer);
                break;
            case UINT16:
                pointer = nPtr.asShortPointer();
                indexer = UShortIndexer.create((ShortPointer) pointer);
                break;
            case SHORT:
                pointer = nPtr.asShortPointer();
                indexer = ShortIndexer.create((ShortPointer) pointer);
                break;
            case UINT32:
                pointer = nPtr.asIntPointer();
                indexer = UIntIndexer.create((IntPointer) pointer);
                break;
            case INT:
                pointer = nPtr.asIntPointer();
                indexer = IntIndexer.create((IntPointer) pointer);
                break;
            case DOUBLE:
                pointer = nPtr.asDoublePointer();
                indexer = DoubleIndexer.create((DoublePointer) pointer);
                break;
            case FLOAT:
                pointer = nPtr.asFloatPointer();
                indexer = FloatIndexer.create((FloatPointer) pointer);
                break;
            case HALF:
                pointer = nPtr.asShortPointer();
                indexer = HalfIndexer.create((ShortPointer) pointer);
                break;
            case BFLOAT16:
                pointer = nPtr.asShortPointer();
                indexer = Bfloat16Indexer.create((ShortPointer) pointer);
                break;
            case UINT64:
            case LONG:
                pointer = nPtr.asLongPointer();
                indexer = LongIndexer.create((LongPointer) pointer);
                break;
        }
    }

    /** Delegates to the native buffer's {@code syncToPrimary()}. */
    @Override
    public void syncToPrimary(){
        ptrDataBuffer.syncToPrimary();
    }

    /** Delegates to the native buffer's {@code syncToSpecial()}. */
    @Override
    public void syncToSpecial(){
        ptrDataBuffer.syncToSpecial();
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy