All downloads are free. Search and download functionality uses the official Maven repository.

com.nativelibs4java.opencl.blas.ujmp.CLDenseFloatMatrix2D Maven / Gradle / Ivy

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */

package com.nativelibs4java.opencl.blas.ujmp;

import com.nativelibs4java.opencl.blas.CLMatrix2D;
import com.nativelibs4java.opencl.blas.CLMatrixUtils;
import com.nativelibs4java.opencl.blas.CLDefaultMatrix2D;
import com.nativelibs4java.opencl.blas.CLKernels;
import com.nativelibs4java.opencl.blas.CLEvents.Action;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation.Ret;
import org.ujmp.core.doublematrix.DoubleMatrix2D;
import org.ujmp.core.exceptions.MatrixException;

import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLMem.MapFlags;
import com.nativelibs4java.opencl.CLMem.Usage;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.util.LinearAlgebraUtils;
import com.nativelibs4java.opencl.util.Primitive;
import com.nativelibs4java.util.NIOUtils;
import java.nio.Buffer;
import org.bridj.Pointer;
import org.ujmp.core.doublematrix.stub.AbstractDenseDoubleMatrix2D;
import org.ujmp.core.floatmatrix.FloatMatrix2D;
import org.ujmp.core.floatmatrix.stub.AbstractDenseFloatMatrix2D;
import org.ujmp.core.matrix.Matrix2D;

/**
 *
 * @author ochafik
 */
/**
 * UJMP dense {@code float} 2D matrix whose storage and computations are delegated to a
 * {@link CLDenseMatrix2DImpl}.
 *
 * <p>Element-wise arithmetic against another {@code CLDenseFloatMatrix2D} is forwarded to
 * the backing implementation; arithmetic against any other {@link Matrix} type falls back
 * to the generic implementation inherited from {@link AbstractDenseFloatMatrix2D}.
 *
 * @author ochafik
 */
public class CLDenseFloatMatrix2D extends AbstractDenseFloatMatrix2D {

    /**
     * Size in bytes of one {@code float}. BridJ's {@code Pointer.get*AtOffset} accessors take
     * byte offsets, so element offsets must be scaled by this value.
     */
    private static final long FLOAT_BYTES = Float.SIZE / Byte.SIZE;

    /** Backing implementation that holds the matrix data and performs the computations. */
    protected final CLDenseMatrix2DImpl impl;

    /**
     * Returns the implementation backing this matrix.
     *
     * @return the backing implementation passed at construction time
     */
    public CLDenseMatrix2DImpl getImpl() {
        return impl;
    }

    /**
     * Wraps an existing backing implementation.
     *
     * @param impl backing implementation
     */
    public CLDenseFloatMatrix2D(CLDenseMatrix2DImpl impl) {
        this.impl = impl;
    }

    /**
     * Wraps an existing {@link CLMatrix2D} in a new {@link CLDenseMatrix2DImpl}.
     *
     * @param matrix matrix to wrap
     */
    public CLDenseFloatMatrix2D(CLMatrix2D matrix) {
        this(new CLDenseMatrix2DImpl(matrix));
    }

    /**
     * Creates a {@code float} matrix of the given size using the given kernels.
     *
     * @param rows number of rows
     * @param columns number of columns
     * @param clUJMP kernels used by the backing matrix
     */
    public CLDenseFloatMatrix2D(long rows, long columns, CLKernels clUJMP) {
        this(new CLDefaultMatrix2D(Primitive.Float, null, rows, columns, clUJMP));
    }

    /**
     * Creates a {@code float} matrix of the given size using the given kernels and block size.
     *
     * @param rows number of rows
     * @param columns number of columns
     * @param clUJMP kernels used by the backing matrix
     * @param blockSize block size forwarded to {@link CLDefaultMatrix2D}
     */
    public CLDenseFloatMatrix2D(long rows, long columns, CLKernels clUJMP, int blockSize) {
        this(new CLDefaultMatrix2D(Primitive.Float, null, rows, columns, blockSize, clUJMP));
    }

    /**
     * Creates a {@code float} matrix of the given size using {@link CLKernels#getInstance()}.
     *
     * @param rows number of rows
     * @param columns number of columns
     */
    public CLDenseFloatMatrix2D(long rows, long columns) {
        this(rows, columns, CLKernels.getInstance());
    }

    /**
     * Creates a square {@code float} matrix.
     *
     * @param size number of rows and of columns
     */
    public CLDenseFloatMatrix2D(long size) {
        this(size, size);
    }

    /**
     * Creates a {@code float} matrix from a UJMP-style size array.
     *
     * @param size {@code {rows, columns}}; entries beyond the second are ignored
     * @throws IllegalArgumentException if {@code size} is null or has fewer than 2 entries
     */
    public CLDenseFloatMatrix2D(long... size) {
        this(dimension(size, 0), dimension(size, 1), CLKernels.getInstance());
    }

    /**
     * Validates a 2D size array and returns one of its entries.
     * Static so it can be called from an explicit constructor invocation.
     */
    private static long dimension(long[] size, int index) {
        if (size == null || size.length < 2) {
            throw new IllegalArgumentException(
                    "Expected at least 2 dimensions (rows, columns), got "
                    + (size == null ? "null" : String.valueOf(size.length)));
        }
        return size[index];
    }

    /**
     * Forwards to {@link CLDenseMatrix2DImpl#write(Pointer)}.
     *
     * @param p pointer handed to the backing implementation
     */
    public void write(Pointer p) {
        getImpl().write(p);
    }

    /**
     * Forwards to {@link CLDenseMatrix2DImpl#read(Pointer)}.
     *
     * @param p pointer handed to the backing implementation
     */
    public void read(Pointer p) {
        getImpl().read(p);
    }

    /**
     * Returns the stride reported by the backing implementation.
     *
     * @return the backing implementation's stride
     */
    public long getStride() {
        return getImpl().getStride();
    }

    /**
     * Returns the data read back by the backing implementation.
     *
     * @return pointer returned by {@link CLDenseMatrix2DImpl#read()}
     */
    public Pointer read() {
        return getImpl().read();
    }

    /** Wraps a {@link CLMatrix2D} result as a UJMP matrix. */
    static CLDenseFloatMatrix2D inst(CLMatrix2D matrix) {
        return new CLDenseFloatMatrix2D(matrix);
    }

    /** Wraps a {@link CLDenseMatrix2DImpl} result as a UJMP matrix. */
    static CLDenseFloatMatrix2D inst(CLDenseMatrix2DImpl matrix) {
        return new CLDenseFloatMatrix2D(matrix);
    }

    /**
     * Matrix product. {@link Matrix2D} operands go to the backing implementation;
     * other matrices use the inherited implementation.
     */
    @Override
    public Matrix mtimes(Ret returnType, boolean ignoreNaN, Matrix matrix) throws MatrixException {
        if (matrix instanceof Matrix2D) {
            return inst(getImpl().multiply(returnType, ignoreNaN, (Matrix2D) matrix));
        }
        return super.mtimes(returnType, ignoreNaN, matrix);
    }

    /** Matrix product into a new matrix, ignoring NaN. */
    @Override
    public Matrix mtimes(Matrix matrix) throws MatrixException {
        return mtimes(Ret.NEW, true, matrix);
    }

    /** Returns the pointer read back from the backing implementation as the value iterable. */
    @Override
    public Iterable allValues() {
        return impl.read();
    }

    /**
     * Minimum. Only {@code ALL} is computed by the backing implementation; {@code ROW} and
     * {@code COLUMN} use the inherited implementation.
     *
     * @throws MatrixException if the backing kernel fails to build
     * @throws IllegalArgumentException for an unknown dimension
     */
    @Override
    public Matrix min(Ret returnType, int dimension) throws MatrixException {
        switch (dimension) {
            case ROW:
            case COLUMN:
                // TODO: compute per-row / per-column minimum in the backing implementation.
                return super.min(returnType, dimension);
            case ALL:
                try {
                    return inst(impl.min());
                } catch (CLBuildException ex) {
                    throw new MatrixException("Failed to compute min", ex);
                }
            default:
                throw new IllegalArgumentException("Invalid dimension : " + dimension);
        }
    }

    /**
     * Maximum. Only {@code ALL} is computed by the backing implementation; {@code ROW} and
     * {@code COLUMN} use the inherited implementation.
     *
     * @throws MatrixException if the backing kernel fails to build
     * @throws IllegalArgumentException for an unknown dimension
     */
    @Override
    public Matrix max(Ret returnType, int dimension) throws MatrixException {
        switch (dimension) {
            case ROW:
            case COLUMN:
                // TODO: compute per-row / per-column maximum in the backing implementation.
                return super.max(returnType, dimension);
            case ALL:
                try {
                    return inst(impl.max());
                } catch (CLBuildException ex) {
                    throw new MatrixException("Failed to compute max", ex);
                }
            default:
                throw new IllegalArgumentException("Invalid dimension : " + dimension);
        }
    }

    /** Mean; currently always the inherited implementation. */
    @Override
    public Matrix mean(Ret returnType, int dimension, boolean ignoreNaN) throws MatrixException {
        // TODO: compute the mean in the backing implementation.
        return super.mean(returnType, dimension, ignoreNaN);
    }

    /**
     * Centers the matrix. For {@code ALL}, subtracts the global mean from every element;
     * {@code ROW} and {@code COLUMN} use the inherited implementation.
     *
     * @throws IllegalArgumentException for an unknown dimension
     */
    @Override
    public Matrix center(Ret returnType, int dimension, boolean ignoreNaN) throws MatrixException {
        switch (dimension) {
            case ROW:
            case COLUMN:
                // TODO: per-row / per-column centering in the backing implementation.
                return super.center(returnType, dimension, ignoreNaN);
            case ALL:
                return minus(returnType, ignoreNaN, mean(Ret.NEW, dimension, ignoreNaN).getAsFloat(0, 0));
            default:
                throw new IllegalArgumentException("Invalid dimension : " + dimension);
        }
    }

    /** Returns a new matrix wrapping a clone of the backing implementation. */
    @Override
    public synchronized Matrix copy() throws MatrixException {
        return inst(impl.clone());
    }

    /** Transposes via the backing implementation. */
    @Override
    public Matrix transpose(Ret returnType) throws MatrixException {
        return inst(impl.transpose(returnType));
    }

    /** Transposes into a new matrix. */
    @Override
    public synchronized Matrix transpose() throws MatrixException {
        return transpose(Ret.NEW);
    }

    /**
     * Returns the size reported by the backing implementation.
     *
     * @return size array from {@link CLDenseMatrix2DImpl#getSize()}
     */
    public long[] getSize() {
        return impl.getSize();
    }

    /** Reads one element through the backing implementation. */
    public float getFloat(long row, long column) {
        return impl.get(row, column);
    }

    /** Writes one element through the backing implementation. */
    public void setFloat(float value, long row, long column) {
        impl.set(value, row, column);
    }

    /** {@code int}-indexed overload of {@link #getFloat(long, long)}. */
    public float getFloat(int row, int column) {
        return getFloat((long) row, (long) column);
    }

    /** {@code int}-indexed overload of {@link #setFloat(float, long, long)}. */
    public void setFloat(float value, int row, int column) {
        setFloat(value, (long) row, (long) column);
    }

    @Override
    public Matrix sin(Ret returnType) throws MatrixException {
        return inst(impl.sin(returnType));
    }

    @Override
    public Matrix cos(Ret returnType) throws MatrixException {
        return inst(impl.cos(returnType));
    }

    @Override
    public Matrix sinh(Ret returnType) throws MatrixException {
        return inst(impl.sinh(returnType));
    }

    @Override
    public Matrix cosh(Ret returnType) throws MatrixException {
        return inst(impl.cosh(returnType));
    }

    @Override
    public Matrix tan(Ret returnType) throws MatrixException {
        return inst(impl.tan(returnType));
    }

    @Override
    public Matrix tanh(Ret returnType) throws MatrixException {
        return inst(impl.tanh(returnType));
    }

    // The inverse functions below carry no @Override, matching the original, which had the
    // annotation commented out.

    public Matrix asin(Ret returnType) throws MatrixException {
        return inst(impl.asin(returnType));
    }

    public Matrix acos(Ret returnType) throws MatrixException {
        return inst(impl.acos(returnType));
    }

    public Matrix asinh(Ret returnType) throws MatrixException {
        return inst(impl.asinh(returnType));
    }

    public Matrix acosh(Ret returnType) throws MatrixException {
        return inst(impl.acosh(returnType));
    }

    public Matrix atan(Ret returnType) throws MatrixException {
        return inst(impl.atan(returnType));
    }

    public Matrix atanh(Ret returnType) throws MatrixException {
        return inst(impl.atanh(returnType));
    }

    @Override
    public Matrix minus(Matrix m) throws MatrixException {
        return minus(Ret.NEW, true, m);
    }

    @Override
    public Matrix minus(double m) throws MatrixException {
        return minus(Ret.NEW, true, m);
    }

    /**
     * Element-wise subtraction. Uses the backing implementation when {@code m} is a
     * {@code CLDenseFloatMatrix2D}, otherwise the inherited generic implementation
     * (previously any other type caused a {@code ClassCastException}).
     */
    @Override
    public Matrix minus(Ret returnType, boolean ignoreNaN, Matrix m) throws MatrixException {
        if (m instanceof CLDenseFloatMatrix2D) {
            return inst(impl.minus(returnType, ignoreNaN, ((CLDenseFloatMatrix2D) m).getImpl()));
        }
        return super.minus(returnType, ignoreNaN, m);
    }

    /** Subtracts a scalar, narrowed to {@code float}. */
    @Override
    public Matrix minus(Ret returnType, boolean ignoreNaN, double v) throws MatrixException {
        return inst(impl.minus(returnType, ignoreNaN, (float) v));
    }

    @Override
    public Matrix plus(Matrix m) throws MatrixException {
        return plus(Ret.NEW, true, m);
    }

    @Override
    public Matrix plus(double m) throws MatrixException {
        return plus(Ret.NEW, true, m);
    }

    /**
     * Element-wise addition. Uses the backing implementation when {@code m} is a
     * {@code CLDenseFloatMatrix2D}, otherwise the inherited generic implementation.
     */
    @Override
    public Matrix plus(Ret returnType, boolean ignoreNaN, Matrix m) throws MatrixException {
        if (m instanceof CLDenseFloatMatrix2D) {
            return inst(impl.plus(returnType, ignoreNaN, ((CLDenseFloatMatrix2D) m).getImpl()));
        }
        return super.plus(returnType, ignoreNaN, m);
    }

    /** Adds a scalar, narrowed to {@code float}. */
    @Override
    public Matrix plus(Ret returnType, boolean ignoreNaN, double v) throws MatrixException {
        return inst(impl.plus(returnType, ignoreNaN, (float) v));
    }

    @Override
    public Matrix times(Matrix m) throws MatrixException {
        return times(Ret.NEW, true, m);
    }

    @Override
    public Matrix times(double m) throws MatrixException {
        return times(Ret.NEW, true, m);
    }

    /**
     * Element-wise multiplication. Uses the backing implementation when {@code factor} is a
     * {@code CLDenseFloatMatrix2D}, otherwise the inherited generic implementation.
     */
    @Override
    public Matrix times(Ret returnType, boolean ignoreNaN, Matrix factor) throws MatrixException {
        if (factor instanceof CLDenseFloatMatrix2D) {
            return inst(impl.times(returnType, ignoreNaN, ((CLDenseFloatMatrix2D) factor).getImpl()));
        }
        return super.times(returnType, ignoreNaN, factor);
    }

    /** Multiplies by a scalar, narrowed to {@code float}. */
    @Override
    public Matrix times(Ret returnType, boolean ignoreNaN, double factor) throws MatrixException {
        return inst(impl.times(returnType, ignoreNaN, (float) factor));
    }

    @Override
    public Matrix divide(Matrix m) throws MatrixException {
        return divide(Ret.NEW, true, m);
    }

    @Override
    public Matrix divide(double m) throws MatrixException {
        return divide(Ret.NEW, true, m);
    }

    /**
     * Element-wise division. Uses the backing implementation when {@code factor} is a
     * {@code CLDenseFloatMatrix2D}, otherwise the inherited generic implementation.
     */
    @Override
    public Matrix divide(Ret returnType, boolean ignoreNaN, Matrix factor) throws MatrixException {
        if (factor instanceof CLDenseFloatMatrix2D) {
            return inst(impl.divide(returnType, ignoreNaN, ((CLDenseFloatMatrix2D) factor).getImpl()));
        }
        return super.divide(returnType, ignoreNaN, factor);
    }

    /** Divides by a scalar, narrowed to {@code float}. */
    @Override
    public Matrix divide(Ret returnType, boolean ignoreNaN, double factor) throws MatrixException {
        return inst(impl.divide(returnType, ignoreNaN, (float) factor));
    }

    /** Tests for {@code v} after narrowing it to {@code float}. */
    @Override
    public boolean containsDouble(double v) {
        return containsFloat((float) v);
    }

    /**
     * Tests whether the backing implementation contains {@code v}.
     *
     * @throws RuntimeException wrapping a {@link CLBuildException} from the backing kernel
     */
    @Override
    public boolean containsFloat(float v) {
        try {
            return impl.containsValue(v);
        } catch (CLBuildException ex) {
            throw new RuntimeException("Failed to test value presence", ex);
        }
    }

    /**
     * Clears the matrix through the backing implementation.
     *
     * @throws RuntimeException wrapping a {@link CLBuildException} from the backing kernel
     */
    @Override
    public void clear() {
        try {
            impl.clear();
        } catch (CLBuildException ex) {
            throw new RuntimeException("Failed to clear matrix", ex);
        }
    }

    /** Forwards to {@link CLDenseMatrix2DImpl#waitFor()}. */
    public void waitFor() {
        impl.waitFor();
    }

    /**
     * Byte offset of the first element of {@code row} in the buffer returned by
     * {@link CLDenseMatrix2DImpl#read()}.
     */
    private long rowByteOffset(int row) {
        // NOTE(review): assumes rows are packed with a pitch of impl.columns elements; if the
        // read buffer is padded to getStride() instead, that pitch should be used — confirm.
        return (long) row * impl.columns * FLOAT_BYTES;
    }

    /**
     * Reads the matrix back and copies it into a row-major {@code float[rows][columns]}.
     * Row offsets are scaled to bytes because {@code getFloatsAtOffset} takes a byte offset.
     */
    @Override
    public float[][] toFloatArray() throws MatrixException {
        Pointer data = impl.read();
        int rowCount = (int) impl.rows;
        int columnCount = (int) impl.columns;
        float[][] ret = new float[rowCount][];
        for (int i = 0; i < rowCount; i++) {
            ret[i] = data.getFloatsAtOffset(rowByteOffset(i), columnCount);
        }
        return ret;
    }

    /**
     * Reads the matrix back and copies it into a row-major {@code double[rows][columns]},
     * widening each element. Row offsets are scaled to bytes as in {@link #toFloatArray()}.
     */
    @Override
    public double[][] toDoubleArray() throws MatrixException {
        Pointer data = impl.read();
        int rowCount = (int) impl.rows;
        int columnCount = (int) impl.columns;
        double[][] ret = new double[rowCount][columnCount];
        for (int i = 0; i < rowCount; i++) {
            float[] floats = data.getFloatsAtOffset(rowByteOffset(i), columnCount);
            double[] doubles = ret[i];
            for (int j = 0; j < floats.length; j++) {
                doubles[j] = floats[j];
            }
        }
        return ret;
    }
}