![JAR search and dependency download from the Maven repository](/logo.png)
com.nativelibs4java.opencl.blas.CLDefaultMatrix2D Maven / Gradle / Ivy
/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.blas;
import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLMem.Usage;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.util.Primitive;
import org.bridj.Pointer;
/**
 * Default {@link CLMatrix2D} implementation that keeps its elements in a single
 * {@link CLBuffer}.
 *
 * <p>The storage size is derived from the logical dimensions and the block size:
 * {@code stride = roundUp(columns, blockSize)} and
 * {@code length = stride * roundUp(rows, blockSize)}. The buffer is therefore
 * expected to hold at least {@code length} elements, which may be more than
 * {@code rows * columns}.
 *
 * <p>Reads and writes of the buffer are routed through a {@link CLEvents} instance
 * (see {@link #getEvents()}).
 *
 * @author ochafik
 */
public class CLDefaultMatrix2D implements CLMatrix2D {

    /** Block size used by the constructor that does not take one explicitly. */
    public static final int DEFAULT_BLOCK_SIZE = 16;

    /** Element type descriptor of this matrix. */
    protected final Primitive primitive;

    /** Java class taken from {@code primitive.primitiveType}. */
    protected final Class primitiveClass;

    /** Logical number of rows. */
    protected final long rows;

    /** Logical number of columns. */
    protected final long columns;

    /** Result of {@code CLMatrixUtils.roundUp(columns, blockSize)}. */
    protected final long stride;

    /** Number of elements of the padded storage: {@code stride * roundUp(rows, blockSize)}. */
    protected final long length;

    /** Block size the dimensions are rounded up to. */
    protected final int blockSize;

    /** Kernels object this matrix was created with; also the source of the queue and context. */
    protected final CLKernels kernels;

    /** Buffer holding the matrix elements (supplied by the caller or created by the constructor). */
    protected final CLBuffer buffer;

    /** Queue obtained from {@code kernels.getQueue()}. */
    protected final CLQueue queue;

    /** Context obtained from {@code kernels.getContext()}. */
    protected final CLContext context;

    /** Event tracker through which {@link #write(Pointer)} and {@link #read(Pointer)} are performed. */
    protected CLEvents _events = new CLEvents();

    /**
     * Creates a matrix using {@link #DEFAULT_BLOCK_SIZE}.
     *
     * @param primitive element type descriptor
     * @param buffer    existing storage, or {@code null} to have a buffer created
     * @param rows      logical number of rows
     * @param columns   logical number of columns
     * @param kernels   kernels object providing the context and queue
     * @throws IllegalArgumentException if {@code buffer} is non-null and holds too few elements
     */
    public CLDefaultMatrix2D(Primitive primitive, CLBuffer buffer, long rows, long columns, CLKernels kernels) {
        this(primitive, buffer, rows, columns, DEFAULT_BLOCK_SIZE, kernels);
    }

    /**
     * Creates a matrix with an explicit block size.
     *
     * @param primitive element type descriptor
     * @param buffer    existing storage, or {@code null} to have an input/output buffer of
     *                  {@code length} elements created on {@code kernels.getContext()}
     * @param rows      logical number of rows
     * @param columns   logical number of columns
     * @param blockSize value both dimensions are rounded up to when sizing the storage
     * @param kernels   kernels object providing the context and queue
     * @throws IllegalArgumentException if {@code buffer} is non-null and its element count is
     *                                  smaller than the required padded length
     */
    public CLDefaultMatrix2D(Primitive primitive, CLBuffer buffer, long rows, long columns, int blockSize, CLKernels kernels) {
        this.primitive = primitive;
        this.primitiveClass = (Class) primitive.primitiveType;

        // Padded dimensions; roundUp presumably rounds up to a multiple of blockSize -- confirm in CLMatrixUtils.
        this.stride = CLMatrixUtils.roundUp(columns, blockSize);
        this.length = this.stride * CLMatrixUtils.roundUp(rows, blockSize);

        if (buffer != null) {
            long availableElements = buffer.getElementCount();
            if (availableElements < this.length) {
                // Report the same unit (elements) that the check above compares; the previous
                // message printed the byte count next to an expected element count.
                throw new IllegalArgumentException("Buffer size too small; buffer of at least " + this.length
                        + " elements expected, buffer of " + availableElements + " elements was given");
            }
            this.buffer = buffer;
        } else {
            this.buffer = (CLBuffer) kernels.getContext().createBuffer(Usage.InputOutput, primitive.primitiveType, length);
        }

        this.kernels = kernels;
        this.rows = rows;
        this.columns = columns;
        this.queue = kernels.getQueue();
        this.context = kernels.getContext();
        this.blockSize = blockSize;

        // NOTE(review): the upper bound below rejects (when assertions are enabled) a caller-supplied
        // buffer larger than the padded length, although the explicit check above accepts it.
        // Confirm which of the two contracts is intended.
        assert getBuffer().getElementCount() >= stride * rows &&
                getBuffer().getElementCount() <= stride * CLMatrixUtils.roundUp(rows, getBlockSize());
    }

    /**
     * Creates a new matrix with the same dimensions as this one and a freshly created buffer.
     *
     * @return a new matrix of {@code getRowCount()} x {@code getColumnCount()}
     */
    public CLMatrix2D blankClone() {
        return blankMatrix(getRowCount(), getColumnCount());
    }

    /**
     * Creates a new matrix of the given dimensions that shares this matrix's primitive,
     * block size and kernels, with a freshly created buffer.
     *
     * @param rows    logical number of rows of the new matrix
     * @param columns logical number of columns of the new matrix
     * @return the new matrix
     */
    public CLMatrix2D blankMatrix(long rows, long columns) {
        return new CLDefaultMatrix2D(primitive, null, rows, columns, blockSize, kernels);
    }

    /** @return the logical number of rows */
    public long getRowCount() {
        return rows;
    }

    /** @return the logical number of columns */
    public long getColumnCount() {
        return columns;
    }

    /** @return the column count rounded up with the block size */
    public long getStride() {
        return stride;
    }

    /** @return the block size used to size the storage */
    public int getBlockSize() {
        return blockSize;
    }

    /** @return the event tracker used for reads and writes of this matrix */
    public CLEvents getEvents() {
        return _events;
    }

    /**
     * Writes the content of {@code b} into the buffer through {@link #getEvents()}.
     *
     * @param b source data
     */
    public void write(final Pointer b) {
        getEvents().performWrite(new CLEvents.Action() {
            public CLEvent perform(CLEvent[] events) {
                // Third argument is false here and true in read(Pointer); presumably the blocking flag.
                return buffer.write(queue, b, false, events);
            }
        });
    }

    /**
     * Reads the buffer content into {@code b} through {@link #getEvents()}.
     *
     * @param b destination memory
     */
    public void read(final Pointer b) {
        getEvents().performRead(new CLEvents.Action() {
            public CLEvent perform(CLEvent[] events) {
                return buffer.read(queue, b, true, events);
            }
        });
    }

    /**
     * Allocates an array of {@code length} elements (the padded storage size, not
     * {@code rows * columns}) and reads the buffer into it.
     *
     * @return the newly allocated memory holding the data read
     */
    public Pointer read() {
        Pointer out = Pointer.allocateArray(primitiveClass, length);
        read(out);
        return out;
    }

    /** @return the buffer holding the matrix elements */
    public CLBuffer getBuffer() {
        return buffer;
    }

    /** @return the context obtained from the kernels object at construction */
    public CLContext getContext() {
        return context;
    }

    /** @return the queue obtained from the kernels object at construction */
    public synchronized CLQueue getQueue() {
        return queue;
    }

    /*
    Disabled setter kept from the original source; the queue field is final.

    public synchronized void setQueue(CLQueue queue) {
        if (this.queue != null && queue != null) {
            if (this.queue.equals(queue))
                return;
        }
        getEvents().waitFor();
        this.queue = queue;
    }
    * */

    /** @return the element type descriptor */
    public Primitive getPrimitive() {
        return primitive;
    }

    /** @return the Java class of the elements */
    public Class getPrimitiveClass() {
        return primitiveClass;
    }

    /** @return the kernels object this matrix was created with */
    public CLKernels getKernels() {
        return kernels;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy