All downloads are free. Search and download functionality uses the official Maven repository.

com.nativelibs4java.opencl.blas.CLDefaultMatrix2D Maven / Gradle / Ivy

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package com.nativelibs4java.opencl.blas;

import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLMem.Usage;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.util.Primitive;

import org.bridj.Pointer;

/**
 * Default {@link CLMatrix2D} implementation backed by a single OpenCL buffer.
 * <p>
 * Both dimensions are padded with {@code CLMatrixUtils.roundUp} using the matrix
 * block size: {@link #getStride()} is the padded column count, and the backing
 * buffer must hold at least {@code stride * roundUp(rows, blockSize)} elements.
 * NOTE(review): the layout is presumably row-major with {@code stride} elements
 * per row — confirm against the kernels in {@link CLKernels}.
 *
 * @author ochafik
 */
public class CLDefaultMatrix2D implements CLMatrix2D {
    protected final Primitive primitive;
    protected final Class primitiveClass;
    protected final long rows, columns, stride, length;
    protected final int blockSize;

    protected final CLKernels kernels;
    protected final CLBuffer buffer;
    protected final CLQueue queue;
    protected final CLContext context;
    protected CLEvents _events = new CLEvents();

    /** Block size used by the constructor that takes no explicit block size. */
    public static final int DEFAULT_BLOCK_SIZE = 16;

    /**
     * Creates a matrix with {@link #DEFAULT_BLOCK_SIZE} padding.
     *
     * @param primitive element type of the matrix
     * @param buffer    existing buffer to wrap, or {@code null} to allocate a new one
     * @param rows      logical row count
     * @param columns   logical column count
     * @param kernels   kernels provider; also supplies the context and queue
     */
    public CLDefaultMatrix2D(Primitive primitive, CLBuffer buffer, long rows, long columns, CLKernels kernels) {
        this(primitive, buffer, rows, columns, DEFAULT_BLOCK_SIZE, kernels);
    }

    /**
     * Creates a matrix whose storage is padded to multiples of {@code blockSize}.
     *
     * @param primitive element type of the matrix
     * @param buffer    existing buffer to wrap, or {@code null} to allocate a new
     *                  {@code Usage.InputOutput} buffer from the kernels' context
     * @param rows      logical row count (must be non-negative)
     * @param columns   logical column count (must be non-negative)
     * @param blockSize padding granularity for both dimensions (must be positive)
     * @param kernels   kernels provider; also supplies the context and queue
     * @throws IllegalArgumentException if a dimension is negative, {@code blockSize}
     *         is not positive, or {@code buffer} holds fewer elements than the padded
     *         storage requires
     */
    public CLDefaultMatrix2D(Primitive primitive, CLBuffer buffer, long rows, long columns, int blockSize, CLKernels kernels) {
        // Reject inputs that would make the padding arithmetic meaningless.
        if (blockSize <= 0) {
            throw new IllegalArgumentException("Block size must be positive, got " + blockSize);
        }
        if (rows < 0 || columns < 0) {
            throw new IllegalArgumentException("Matrix dimensions must be non-negative, got " + rows + " x " + columns);
        }
        this.primitive = primitive;
        this.primitiveClass = (Class)primitive.primitiveType;
        this.stride = CLMatrixUtils.roundUp(columns, blockSize);
        this.length = this.stride * CLMatrixUtils.roundUp(rows, blockSize);
        if (buffer != null) {
            // Both sides of the comparison (and of the message) are element counts, not bytes.
            if (buffer.getElementCount() < this.length) {
                throw new IllegalArgumentException("Buffer size too small; buffer of " + this.length + " elements expected, " + buffer.getElementCount() + " elements were given");
            }
            this.buffer = buffer;
        } else {
            this.buffer = (CLBuffer)kernels.getContext().createBuffer(Usage.InputOutput, primitive.primitiveType, length);
        }
        this.kernels = kernels;
        this.rows = rows;
        this.columns = columns;
        this.queue = kernels.getQueue();
        this.context = kernels.getContext();
        this.blockSize = blockSize;

        // Only a lower bound is asserted: the check above deliberately accepts
        // caller-supplied buffers larger than the padded storage.
        assert getBuffer().getElementCount() >= stride * rows;
    }

    /**
     * Returns a new, freshly allocated matrix with the same dimensions, element
     * type, block size and kernels as this one.
     */
    public CLMatrix2D blankClone() {
        return blankMatrix(getRowCount(), getColumnCount());
    }

    /**
     * Returns a new, freshly allocated matrix of the given size that shares this
     * matrix's element type, block size and kernels.
     */
    public CLMatrix2D blankMatrix(long rows, long columns) {
        return new CLDefaultMatrix2D(primitive, null, rows, columns, blockSize, kernels);
    }

    /** Returns the logical (unpadded) number of rows. */
    public long getRowCount() {
        return rows;
    }

    /** Returns the logical (unpadded) number of columns. */
    public long getColumnCount() {
        return columns;
    }

    /** Returns the column count rounded up to the block size. */
    public long getStride() {
        return stride;
    }

    /** Returns the padding granularity used for both dimensions. */
    public int getBlockSize() {
        return blockSize;
    }

    /** Returns the event tracker used to order reads and writes on the buffer. */
    public CLEvents getEvents() {
        return _events;
    }

    /**
     * Enqueues a write of {@code b} into the backing buffer, passing
     * {@code blocking = false} to {@code CLBuffer.write}.
     * NOTE(review): with a non-blocking write the caller presumably must keep
     * {@code b} alive and unmodified until the write completes — confirm.
     */
    public void write(final Pointer b) {
        getEvents().performWrite(new CLEvents.Action() {
            public CLEvent perform(CLEvent[] events) {
                return buffer.write(queue, b, false, events);
            }
        });
    }

    /**
     * Reads the backing buffer into {@code b}, passing {@code blocking = true}
     * to {@code CLBuffer.read}.
     */
    public void read(final Pointer b) {
        getEvents().performRead(new CLEvents.Action() {
            public CLEvent perform(CLEvent[] events) {
                return buffer.read(queue, b, true, events);
            }
        });
    }

    /**
     * Allocates a pointer of {@code length} elements (the padded storage size)
     * and reads the buffer into it.
     */
    public Pointer read() {
        Pointer out = Pointer.allocateArray(primitiveClass, length);
        read(out);
        return out;
    }

    /** Returns the OpenCL buffer holding the (padded) matrix data. */
    public CLBuffer getBuffer() {
        return buffer;
    }

    /** Returns the context obtained from the kernels at construction time. */
    public CLContext getContext() {
        return context;
    }

    /**
     * Returns the queue obtained from the kernels at construction time. The
     * field is final, so no locking is needed.
     */
    public CLQueue getQueue() {
        return queue;
    }

    /** Returns the element type descriptor. */
    public Primitive getPrimitive() {
        return primitive;
    }

    /** Returns the Java class of the element type ({@code primitive.primitiveType}). */
    public Class getPrimitiveClass() {
        return primitiveClass;
    }

    /** Returns the kernels provider this matrix was created with. */
    public CLKernels getKernels() {
        return kernels;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy