All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.jcublas.SimpleJCublas Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.nd4j.linalg.jcublas;


import jcublas.cublasHandle;
import jcuda.CudaException;
import jcuda.LogLevel;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import jcuda.runtime.cudaError;
import jcuda.runtime.cudaMemcpyKind;
import jcuda.utils.KernelLauncher;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.kernel.KernelFunctionLoader;

/**
 * Simple abstraction for jcublas operations
 *
 * @author mjk
 * @author Adam Gibson
 */
public class SimpleJCublas {

    /** Set to true once {@link #init()} has run to completion. */
    private static boolean init = false;
    /**
     * cuBLAS handle returned by {@link #handle()}.
     * NOTE(review): the cublasCreate call for this handle is commented out in {@link #init()},
     * so this class never initializes it — confirm whether any caller relies on it.
     */
    private static cublasHandle handle = new cublasHandle();


    static {
        init();
    }


    /**
     * Asserts that every given array is backed by a {@link JCudaBuffer}.
     *
     * @param buffer the arrays to check
     * @throws IllegalArgumentException if some array's data buffer is not a {@link JCudaBuffer};
     *                                  the message names that buffer's class
     */
    public static void assertCudaBuffer(INDArray... buffer) {
        for (INDArray b1 : buffer) {
            DataBuffer data = b1.data();
            if (!(data instanceof JCudaBuffer))
                throw new IllegalArgumentException("Unable to allocate pointer for buffer of type "
                        + (data == null ? "null" : data.getClass().toString()));
        }
    }


    /**
     * Asserts that every given buffer is a {@link JCudaBuffer}.
     *
     * @param buffer the buffers to check
     * @throws IllegalArgumentException if some buffer is not a {@link JCudaBuffer};
     *                                  the message names that buffer's class
     */
    public static void assertCudaBuffer(DataBuffer... buffer) {
        for (DataBuffer b1 : buffer)
            if (!(b1 instanceof JCudaBuffer))
                throw new IllegalArgumentException("Unable to allocate pointer for buffer of type "
                        + (b1 == null ? "null" : b1.getClass().toString()));
    }

    /**
     * The cublas handle
     *
     * @return the handle used for cublas
     */
    public static cublasHandle handle() {
        return handle;
    }

    /**
     * Throws a {@link CudaException} if the given CUDA runtime result code is not success.
     *
     * @param result a CUDA runtime status code
     * @return the same result code when it is {@code cudaSuccess}
     */
    static int checkResult(int result)
    {
        if (result != cudaError.cudaSuccess)
        {
            throw new CudaException(cudaError.stringFor(result));
        }
        return result;
    }


    /**
     * Initializes jcublas. After the first successful run, later calls return immediately.
     * Enables JCublas exceptions and debug logging, loads the kernel functions, and queries
     * device 0's properties.
     */
    public static void init() {
        if (init)
            return;
//        JCublas2.initialize();
//        cublasHandle handle = new cublasHandle();
//        JCublas2.cublasCreate(handle);

        JCublas.setLogLevel(LogLevel.LOG_DEBUG);
        JCublas.setExceptionsEnabled(true);

        try {
            KernelFunctionLoader.getInstance().load();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // Check if the device supports mapped host memory
        cudaDeviceProp deviceProperties = new cudaDeviceProp();
        checkResult(JCuda.cudaGetDeviceProperties(deviceProperties, 0));
        if (deviceProperties.canMapHostMemory == 0) {
            System.err.println("This device can not map host memory");
            System.err.println(deviceProperties.toFormattedString());
            // NOTE(review): returning here leaves init == false, so a later init() call repeats this work
            return;
        }
        // Queries the API version of the context held by ContextHolder (the result is not used here)
        int[] version = new int[1];
        JCudaDriver.cuCtxGetApiVersion(ContextHolder.getInstance().getContext(), version);


        // Set the flag indicating that mapped memory will be used
        //checkResult(JCuda.cudaSetDeviceFlags(JCuda.cudaDeviceMapHost));

        init = true;
    }


    /**
     * Blocks until the device has finished all preceding work and checks the result.
     * Then calls {@link KernelLauncher#setContext()}; presumably this rebinds the kernel
     * launcher's context to the current thread — TODO confirm.
     */
    public static void sync() {
        checkResult(JCuda.cudaDeviceSynchronize());
        KernelLauncher.setContext();
    }

    /**
     * General matrix-vector multiplication (double): C = alpha * A * B + beta * C.
     *
     * @param A     the matrix
     * @param B     the vector
     * @param C     the result vector, updated in place
     * @param alpha scale applied to A * B
     * @param beta  scale applied to C
     * @return C
     */
    public static INDArray gemv(INDArray A, INDArray B, INDArray C, double alpha, double beta) {

        DataTypeValidation.assertDouble(A, B, C);
        assertCudaBuffer(A.data(), B.data(), C.data());
        sync();

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasDgemv(
                    'N',
                    A.rows(),
                    A.columns(),
                    alpha,
                    cAPointer,
                    A.rows(),
                    cBPointer,
                    1,
                    beta,
                    cCPointer,
                    1);
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;
    }

    /**
     * General matrix-vector multiplication (float): C = alpha * A * B + beta * C.
     *
     * @param A     the matrix
     * @param B     the vector
     * @param C     the result vector, updated in place
     * @param alpha scale applied to A * B
     * @param beta  scale applied to C
     * @return C
     */
    public static INDArray gemv(INDArray A, INDArray B, INDArray C, float alpha, float beta) {

        DataTypeValidation.assertFloat(A, B, C);
        sync();

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasSgemv('N',
                    A.rows(),
                    A.columns(),
                    alpha,
                    cAPointer,
                    A.rows(),
                    cBPointer,
                    1,
                    beta,
                    cCPointer,
                    1);
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;
    }


    /**
     * General complex (double) matrix-vector multiplication: C = a * A * B + b * C.
     *
     * @param A the matrix
     * @param B the vector
     * @param a scale applied to A * B
     * @param C the result vector, updated in place
     * @param b scale applied to C
     * @return C
     */
    public static IComplexNDArray gemv(IComplexNDArray A, IComplexNDArray B, IComplexDouble a, IComplexNDArray C
            , IComplexDouble b) {
        DataTypeValidation.assertSameDataType(A, B, C);
        sync();

        // alpha is built from a and beta from b (previously alpha took b's imaginary part)
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(a.realComponent().doubleValue(), a.imaginaryComponent().doubleValue());
        cuDoubleComplex beta = cuDoubleComplex.cuCmplx(b.realComponent().doubleValue(), b.imaginaryComponent().doubleValue());

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasZgemv(
                    'n', //trans
                    A.rows(),  // m
                    A.columns(), // n
                    alpha,
                    cAPointer, // A
                    A.rows(),  // lda
                    cBPointer, // x
                    B.secondaryStride(), // incx
                    beta,  // beta
                    cCPointer, // y
                    C.secondaryStride()); // incy
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;

    }

    /**
     * General complex (float) matrix-vector multiplication: C = a * A * B + b * C.
     *
     * @param A the matrix
     * @param B the vector
     * @param a scale applied to A * B
     * @param C the result vector, updated in place
     * @param b scale applied to C
     * @return C
     */
    public static IComplexNDArray gemv(IComplexNDArray A, IComplexNDArray B, IComplexFloat a, IComplexNDArray C
            , IComplexFloat b) {
        DataTypeValidation.assertFloat(A, B, C);
        assertCudaBuffer(A, B, C);
        sync();

        // alpha is built from a and beta from b (previously alpha took b's imaginary part)
        cuComplex alpha = cuComplex.cuCmplx(a.realComponent().floatValue(), a.imaginaryComponent().floatValue());
        cuComplex beta = cuComplex.cuCmplx(b.realComponent().floatValue(), b.imaginaryComponent().floatValue());

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasCgemv(
                    'n', //trans
                    A.rows(),  // m
                    A.columns(), // n
                    alpha,
                    cAPointer, // A
                    A.rows(),  // lda
                    cBPointer, // x
                    B.secondaryStride(), // incx
                    beta,  // beta
                    cCPointer, // y
                    C.secondaryStride()); // incy
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;

    }


    /**
     * General complex (double) matrix multiply: C = a * A * B + b * C.
     *
     * @param A the left matrix
     * @param B the right matrix
     * @param a scale applied to A * B
     * @param C the result matrix, updated in place
     * @param b scale applied to C
     * @return C
     */
    public static IComplexNDArray gemm(IComplexNDArray A, IComplexNDArray B, IComplexDouble a, IComplexNDArray C
            , IComplexDouble b) {
        DataTypeValidation.assertSameDataType(A, B, C);
        sync();

        // alpha is built from a and beta from b (previously alpha took b's imaginary part)
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(a.realComponent().doubleValue(), a.imaginaryComponent().doubleValue());
        cuDoubleComplex beta = cuDoubleComplex.cuCmplx(b.realComponent().doubleValue(), b.imaginaryComponent().doubleValue());

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasZgemm(
                    'n', //transa
                    'n', //transb
                    C.rows(),  // m
                    C.columns(), // n
                    A.columns(), //k,
                    alpha,
                    cAPointer, // A
                    A.rows(),  // lda
                    cBPointer, // B
                    B.rows(), // ldb
                    beta,  // beta
                    cCPointer, // C
                    C.rows()); // ldc
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;

    }

    /**
     * General complex (float) matrix multiply: C = a * A * B + b * C.
     *
     * @param A the left matrix
     * @param B the right matrix
     * @param a scale applied to A * B
     * @param C the result matrix, updated in place
     * @param b scale applied to C
     * @return C
     */
    public static IComplexNDArray gemm(IComplexNDArray A, IComplexNDArray B, IComplexFloat a, IComplexNDArray C
            , IComplexFloat b) {
        DataTypeValidation.assertFloat(A, B, C);

        sync();

        // alpha is built from a and beta from b (previously alpha took b's imaginary part)
        cuComplex alpha = cuComplex.cuCmplx(a.realComponent().floatValue(), a.imaginaryComponent().floatValue());
        cuComplex beta = cuComplex.cuCmplx(b.realComponent().floatValue(), b.imaginaryComponent().floatValue());

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasCgemm(
                    'n', //transa
                    'n', //transb
                    C.rows(),  // m
                    C.columns(), // n
                    A.columns(), //k,
                    alpha,
                    cAPointer, // A
                    A.rows(),  // lda
                    cBPointer, // B
                    B.rows(), // ldb
                    beta,  // beta
                    cCPointer, // C
                    C.rows()); // ldc
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;

    }

    /**
     * General matrix multiply (double): C = alpha * A * B + beta * C.
     *
     * @param A     the left matrix
     * @param B     the right matrix
     * @param C     the result matrix, updated in place
     * @param alpha scale applied to A * B
     * @param beta  scale applied to C
     * @return C
     */
    public static INDArray gemm(INDArray A, INDArray B, INDArray C,
                                double alpha, double beta) {

        DataTypeValidation.assertDouble(A, B, C);

        sync();

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasDgemm(
                    'n', //transa
                    'n', //transb
                    C.rows(),  // m
                    C.columns(), // n
                    A.columns(), //k,
                    alpha,
                    cAPointer, // A
                    A.rows(),  // lda
                    cBPointer, // B
                    B.rows(), // ldb
                    beta,  // beta
                    cCPointer, // C
                    C.rows()); // ldc
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;

    }

    /**
     * General matrix multiply (float): C = alpha * A * B + beta * C.
     *
     * @param A     the left matrix
     * @param B     the right matrix
     * @param C     the result matrix, updated in place
     * @param alpha scale applied to A * B
     * @param beta  scale applied to C
     * @return C
     */
    public static INDArray gemm(INDArray A, INDArray B, INDArray C,
                                float alpha, float beta) {
        DataTypeValidation.assertFloat(A, B, C);
        sync();

        CublasPointer cAPointer = new CublasPointer(A);
        CublasPointer cBPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasSgemm(
                    'n', //transa
                    'n', //transb
                    C.rows(),  // m
                    C.columns(), // n
                    A.columns(), //k,
                    alpha,
                    cAPointer, // A
                    A.rows(),  // lda
                    cBPointer, // B
                    B.rows(), // ldb
                    beta,  // beta
                    cCPointer, // C
                    C.rows()); // ldc
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(cAPointer, cBPointer, cCPointer);
        }

        return C;

    }


    /**
     * Calculates the Euclidean (2-)norm of a complex ndarray, covering both the real and
     * imaginary parts of every element. Uses cublasScnrm2 for float data and cublasDznrm2
     * otherwise.
     *
     * @param A the ndarray to calculate the norm2 of
     * @return the 2-norm of A
     */
    public static double nrm2(IComplexNDArray A) {

        sync();

        CublasPointer cAPointer = new CublasPointer(A);
        try {
            // Previously used S/Dnrm2 with stride 2, which only summed the real parts
            if (A.data().dataType() == DataBuffer.Type.FLOAT)
                return JCublas.cublasScnrm2(A.length(), cAPointer, 1);
            return JCublas.cublasDznrm2(A.length(), cAPointer, 1);
        } finally {
            releaseCublasPointers(cAPointer);
        }

    }

    /**
     * Copy x to y
     *
     * @param x the origin
     * @param y the destination
     */
    public static void copy(IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertSameDataType(x, y);

        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            JCudaBuffer buff = (JCudaBuffer) x.data();
            // A major stride of 2 on both arrays is taken as "contiguous interleaved
            // complex data", so one device-to-device memcpy of 2 * length elements suffices
            if (x.majorStride() == 2 && y.majorStride() == 2)
                checkResult(JCuda.cudaMemcpy(
                        yCPointer
                        , xCPointer
                        , x.length() * buff.getElementSize() * 2
                        , cudaMemcpyKind.cudaMemcpyDeviceToDevice));
            else
                Nd4j.getExecutioner().exec(new CopyOp(x, y, y, x.length()));

            sync();

            yCPointer.copyToHost();
        } finally {
            releaseCublasPointers(yCPointer, xCPointer);
        }

    }


    /**
     * Return the index of the max in the given ndarray
     * (cublasIcamax for float data, cublasIzamax otherwise).
     * NOTE(review): this returns cuBLAS's 1-based index unchanged, whereas
     * {@link #iamax(INDArray)} subtracts one — confirm what callers expect.
     *
     * @param x the ndarray to get the max for
     * @return the 1-based index reported by cuBLAS
     */
    public static int iamax(IComplexNDArray x) {
        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            // Complex float data needs Icamax; Isamax would treat it as real floats
            if (x.data().dataType() == DataBuffer.Type.FLOAT)
                return JCublas.cublasIcamax(x.length(), xCPointer, 1);
            return JCublas.cublasIzamax(x.length(), xCPointer, 1);
        } finally {
            releaseCublasPointers(xCPointer);
        }

    }

    /**
     * Sum of |re| + |im| over all elements of a complex ndarray
     * (cublasScasum for float data, cublasDzasum for double data).
     *
     * @param x the complex ndarray
     * @return the absolute sum, narrowed to float
     */
    public static float asum(IComplexNDArray x) {
        CublasPointer xCPointer = new CublasPointer(x);
        try {
            if (x.data().dataType() == DataBuffer.Type.DOUBLE)
                return (float) JCublas.cublasDzasum(x.length(), xCPointer, 1);
            return JCublas.cublasScasum(x.length(), xCPointer, 1);
        } finally {
            releaseCublasPointers(xCPointer);
        }
    }


    /**
     * Swaps the elements of x and y, then copies both results back to the host.
     *
     * @param x the first ndarray
     * @param y the second ndarray
     */
    public static void swap(INDArray x, INDArray y) {

        DataTypeValidation.assertSameDataType(x, y);

        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            sync();

            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                JCublas.cublasSswap(
                        x.length(),
                        xCPointer,
                        1,
                        yCPointer,
                        1);

            } else {
                JCublas.cublasDswap(
                        x.length(),
                        xCPointer,
                        1,
                        yCPointer,
                        1);

            }
            sync();

            // Both arrays are modified on the device
            xCPointer.copyToHost();
            yCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer, yCPointer);
        }

    }

    /**
     * Sum of absolute values of the elements of x (unit stride).
     *
     * @param x the ndarray
     * @return the absolute sum
     */
    public static double asum(INDArray x) {

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            if (x.data().dataType() == DataBuffer.Type.FLOAT)
                return JCublas.cublasSasum(x.length(), xCPointer, 1);
            return JCublas.cublasDasum(x.length(), xCPointer, 1);
        } finally {
            releaseCublasPointers(xCPointer);
        }

    }

    /**
     * Returns the norm2 of the given ndarray
     *
     * @param x the ndarray (float or double)
     * @return the 2-norm of x
     * @throws IllegalStateException if x is neither float nor double
     */
    public static double nrm2(INDArray x) {

        DataBuffer.Type type = x.data().dataType();
        if (type != DataBuffer.Type.FLOAT && type != DataBuffer.Type.DOUBLE)
            throw new IllegalStateException("Illegal data type on array ");

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            if (type == DataBuffer.Type.FLOAT)
                return JCublas.cublasSnrm2(x.length(), xCPointer, 1);
            return JCublas.cublasDnrm2(x.length(), xCPointer, 1);
        } finally {
            releaseCublasPointers(xCPointer);
        }

    }

    /**
     * Returns the index of the max-magnitude element in the given ndarray
     *
     * @param x the ndarray (float or double)
     * @return the 0-based index (cuBLAS's 1-based result minus one)
     * @throws IllegalStateException if x is neither float nor double
     */
    public static int iamax(INDArray x) {

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                int max = JCublas.cublasIsamax(
                        x.length(),
                        xCPointer,
                        x.majorStride());

                return max - 1;
            } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
                int max = JCublas.cublasIdamax(
                        x.length(),
                        xCPointer,
                        x.majorStride());

                return max - 1;
            }

            throw new IllegalStateException("Illegal data type on array ");
        } finally {
            releaseCublasPointers(xCPointer);
        }
    }


    /**
     * Adds A scaled by da to B in place: B = da * A + B (float).
     *
     * @param da alpha
     * @param A  the element to add
     * @param B  the matrix to add to
     */
    public static void axpy(float da, INDArray A, INDArray B) {
        DataTypeValidation.assertFloat(A, B);

        CublasPointer xAPointer = new CublasPointer(A);
        CublasPointer xBPointer = new CublasPointer(B);
        try {
            sync();
            JCublas.cublasSaxpy(
                    A.length(),
                    da,
                    xAPointer,
                    A.majorStride(),
                    xBPointer,
                    B.majorStride());

            // NOTE(review): A is not modified by saxpy; this host copy of A's buffer looks redundant
            ((JCudaBuffer) A.data()).copyToHost();
            sync();

            xBPointer.copyToHost();
        } finally {
            releaseCublasPointers(xAPointer, xBPointer);
        }

    }

    /**
     * Complex float axpy: B = da * A + B (unit strides), then copies B back to the host.
     *
     * @param da the complex scale
     * @param A  the array to add
     * @param B  the array to add to, updated in place
     */
    public static void axpy(IComplexFloat da, IComplexNDArray A, IComplexNDArray B) {
        DataTypeValidation.assertFloat(A, B);

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        try {
            sync();

            JCublas.cublasCaxpy(
                    A.length(),
                    jcuda.cuComplex.cuCmplx(da.realComponent().floatValue(), da.imaginaryComponent().floatValue()),
                    aCPointer,
                    1,
                    bCPointer,
                    1
            );
            sync();

            bCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer);
        }

    }

    /**
     * Complex double axpy: B = da * A + B, then copies B back to the host.
     *
     * @param da the complex scale (kept in double precision)
     * @param A  the array to add
     * @param B  the array to add to, updated in place
     */
    public static void axpy(IComplexDouble da, IComplexNDArray A, IComplexNDArray B) {
        DataTypeValidation.assertDouble(A, B);

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        try {
            sync();

            JCublas.cublasZaxpy(
                    A.length(),
                    jcuda.cuDoubleComplex.cuCmplx(da.realComponent().doubleValue(), da.imaginaryComponent().doubleValue()),
                    aCPointer,
                    A.majorStride(),
                    bCPointer,
                    B.majorStride()
            );
            sync();

            bCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer);
        }

    }


    /**
     * Multiply the given ndarray
     * by alpha in place (double)
     *
     * @param alpha the scale
     * @param x     the ndarray to scale
     * @return x
     */
    public static INDArray scal(double alpha, INDArray x) {
        DataTypeValidation.assertDouble(x);

        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            JCublas.cublasDscal(
                    x.length(),
                    alpha,
                    xCPointer,
                    x.majorStride());
            sync();

            xCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer);
        }

        return x;

    }

    /**
     * Multiply the given ndarray
     * by alpha in place (float)
     *
     * @param alpha the scale
     * @param x     the ndarray to scale
     * @return x
     */
    public static INDArray scal(float alpha, INDArray x) {

        DataTypeValidation.assertFloat(x);
        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            JCublas.cublasSscal(
                    x.length(),
                    alpha,
                    xCPointer,
                    x.majorStride());
            sync();

            xCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer);
        }

        return x;

    }

    /**
     * Copy x to y
     *
     * @param x the src
     * @param y the destination
     */
    public static void copy(INDArray x, INDArray y) {
        DataTypeValidation.assertSameDataType(x, y);
        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            if (x.data().dataType() == DataBuffer.Type.DOUBLE)
                JCublas.cublasDcopy(x.length(), xCPointer, x.majorStride(), yCPointer, y.majorStride());
            if (x.data().dataType() == DataBuffer.Type.FLOAT)
                JCublas.cublasScopy(x.length(), xCPointer, x.majorStride(), yCPointer, y.majorStride());
            sync();

            yCPointer.copyToHost();
        } finally {
            releaseCublasPointers(yCPointer, xCPointer);
        }
    }

    /**
     * Dot product between 2 ndarrays
     *
     * @param x the first ndarray
     * @param y the second ndarray
     * @return the dot product between the two ndarrays
     */
    public static double dot(INDArray x, INDArray y) {
        DataTypeValidation.assertSameDataType(x, y);

        sync();
        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            double ret;
            if (x.data().dataType() == (DataBuffer.Type.FLOAT)) {
                ret = JCublas.cublasSdot(
                        x.length(),
                        xCPointer,
                        x.majorStride(),
                        yCPointer,
                        y.majorStride());
            } else {
                // x's own stride is used for x (previously y's stride was passed here)
                ret = JCublas.cublasDdot(
                        x.length(),
                        xCPointer,
                        x.majorStride(),
                        yCPointer,
                        y.majorStride());
            }
            sync();
            return ret;
        } finally {
            releaseCublasPointers(xCPointer, yCPointer);
        }

    }


    /**
     * Closes each non-null pointer.
     *
     * @param pointers the pointers to close; null entries are skipped
     * @throws RuntimeException wrapping the first failure from {@code close()}
     */
    private static void releaseCublasPointers(CublasPointer... pointers) {
        for (CublasPointer pointer : pointers)
            try {
                if (pointer != null)
                    pointer.close();
            } catch (Exception e) {
                throw new RuntimeException("Could not run cublas command", e);
            }
    }


    /**
     * Conjugated complex dot product (cublasCdotc for float data, cublasZdotc otherwise).
     *
     * @param x the first complex ndarray (conjugated)
     * @param y the second complex ndarray
     * @return the dot product as a complex double
     */
    public static IComplexDouble dot(IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertSameDataType(x, y);

        sync();

        CublasPointer aCPointer = new CublasPointer(x);
        CublasPointer bCPointer = new CublasPointer(y);
        try {
            IComplexDouble ret;
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                jcuda.cuComplex dott = JCublas.cublasCdotc(
                        x.length(),
                        aCPointer,
                        x.majorStride(),
                        bCPointer,
                        y.majorStride());
                ret = Nd4j.createDouble(dott.x, dott.y);
            } else {
                jcuda.cuDoubleComplex dott = JCublas.cublasZdotc(
                        x.length(),
                        aCPointer,
                        x.majorStride(),
                        bCPointer,
                        y.majorStride());
                ret = Nd4j.createDouble(dott.x, dott.y);
            }
            sync();
            return ret;
        } finally {
            releaseCublasPointers(aCPointer, bCPointer);
        }
    }


    /**
     * Rank-1 update (double): C = alpha * A * transpose(B) + C.
     * NOTE(review): m/n come from A's shape and the increments from A.rows()/B.rows();
     * confirm these match how callers pass vectors.
     *
     * @param A     the x vector
     * @param B     the y vector
     * @param C     the matrix to update in place
     * @param alpha the scale
     * @return C
     */
    public static INDArray ger(INDArray A, INDArray B, INDArray C, double alpha) {
        DataTypeValidation.assertDouble(A, B, C);
        sync();

        // = alpha * A * transpose(B) + C
        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasDger(
                    A.rows(),   // m
                    A.columns(),// n
                    alpha,      // alpha
                    aCPointer,        // d_A or x
                    A.rows(),   // incx
                    bCPointer,        // dB or y
                    B.rows(),   // incy
                    cCPointer,        // dC or A
                    C.rows()    // lda
            );
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer, cCPointer);
        }

        return C;
    }


    /**
     * Rank-1 update (float): C = alpha * A * transpose(B) + C.
     * NOTE(review): m/n come from A's shape and the increments from A.rows()/B.rows();
     * confirm these match how callers pass vectors.
     *
     * @param A     the x vector
     * @param B     the y vector
     * @param C     the matrix to update in place
     * @param alpha the scale
     * @return C
     */
    public static INDArray ger(INDArray A, INDArray B, INDArray C, float alpha) {
        DataTypeValidation.assertFloat(A, B, C);

        sync();
        // = alpha * A * transpose(B) + C

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasSger(
                    A.rows(),   // m
                    A.columns(),// n
                    alpha,      // alpha
                    aCPointer,        // d_A or x
                    A.rows(),   // incx
                    bCPointer,        // dB or y
                    B.rows(),   // incy
                    cCPointer,        // dC or A
                    C.rows()    // lda
            );
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer, cCPointer);
        }

        return C;
    }


    /**
     * Complex (float) multiplication of an ndarray in place
     *
     * @param alpha the complex scale
     * @param x     the ndarray to scale
     * @return x
     */
    public static IComplexNDArray scal(IComplexFloat alpha, IComplexNDArray x) {
        DataTypeValidation.assertFloat(x);

        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            JCublas.cublasCscal(
                    x.length(),
                    jcuda.cuComplex.cuCmplx(alpha.realComponent(), alpha.imaginaryComponent()),
                    xCPointer,
                    x.majorStride()
            );
            sync();

            xCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer);
        }

        return x;
    }

    /**
     * Complex (double) multiplication of an ndarray in place
     *
     * @param alpha the complex scale
     * @param x     the ndarray to scale
     * @return x
     */
    public static IComplexNDArray scal(IComplexDouble alpha, IComplexNDArray x) {
        DataTypeValidation.assertDouble(x);
        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        try {
            JCublas.cublasZscal(
                    x.length(),
                    jcuda.cuDoubleComplex.cuCmplx(alpha.realComponent(), alpha.imaginaryComponent()),
                    xCPointer,
                    x.majorStride()
            );
            sync();

            xCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer);
        }

        return x;
    }

    /**
     * Unconjugated complex dot product (cublasZdotu for double data, cublasCdotu otherwise)
     *
     * @param x the first complex ndarray
     * @param y the second complex ndarray
     * @return the dot product as a complex double
     */
    public static IComplexDouble dotu(IComplexNDArray x, IComplexNDArray y) {

        DataTypeValidation.assertSameDataType(x, y);
        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            IComplexDouble ret;
            if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
                jcuda.cuDoubleComplex dott = JCublas.cublasZdotu(x.length(), xCPointer, x.majorStride(), yCPointer, y.majorStride());
                ret = Nd4j.createDouble(dott.x, dott.y);
            } else {
                jcuda.cuComplex dott = JCublas.cublasCdotu(x.length(), xCPointer, x.majorStride(), yCPointer, y.majorStride());
                ret = Nd4j.createDouble(dott.x, dott.y);
            }
            sync();
            return ret;
        } finally {
            releaseCublasPointers(xCPointer, yCPointer);
        }
    }


    /**
     * Unconjugated complex (double) rank-1 update: C = Alpha * A * transpose(B) + C.
     *
     * @param A     the x vector
     * @param B     the y vector
     * @param C     the matrix to update in place
     * @param Alpha the complex scale
     * @return C
     */
    public static IComplexNDArray geru(IComplexNDArray A,
                                       IComplexNDArray B,
                                       IComplexNDArray C, IComplexDouble Alpha) {
        // = alpha * A * tranpose(B) + C
        sync();
        DataTypeValidation.assertDouble(A, B, C);

        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent());

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasZgeru(
                    A.rows(),   // m
                    A.columns(),// n
                    alpha,      // alpha
                    aCPointer,        // d_A or x
                    A.rows(),   // incx
                    bCPointer,        // d_B or y
                    B.rows(),   // incy
                    cCPointer,        // d_C or A
                    C.rows()    // lda
            );
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer, cCPointer);
        }

        return C;
    }

    /**
     * Conjugated complex (float) rank-1 update: C = Alpha * A * conjugate-transpose(B) + C.
     *
     * @param A     the x vector
     * @param B     the y vector
     * @param C     the matrix to update in place
     * @param Alpha the complex scale
     * @return C
     */
    public static IComplexNDArray gerc(IComplexNDArray A, IComplexNDArray B, IComplexNDArray C,
                                       IComplexFloat Alpha) {
        DataTypeValidation.assertFloat(A, B, C);
        // = alpha * A * tranpose(B) + C

        sync();

        cuComplex alpha = cuComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent());

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasCgerc(
                    A.rows(),   // m
                    A.columns(),// n
                    alpha,      // alpha
                    aCPointer,        // dA or x
                    A.rows(),   // incx
                    bCPointer,        // dB or y
                    B.rows(),   // incy
                    cCPointer,        // dC or A
                    C.rows()    // lda
            );
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer, cCPointer);
        }

        return C;
    }

    /**
     * Unconjugated complex (float) rank-1 update: C = Alpha * A * transpose(B) + C.
     * Uses the single-precision cublasCgeru (previously the double-precision Zgeru was
     * applied to float buffers).
     *
     * @param A     the x vector
     * @param B     the y vector
     * @param C     the matrix to update in place
     * @param Alpha the complex scale
     * @return C
     */
    public static IComplexNDArray geru(IComplexNDArray A,
                                       IComplexNDArray B,
                                       IComplexNDArray C, IComplexFloat Alpha) {

        DataTypeValidation.assertFloat(A, B, C);
        // = alpha * A * tranpose(B) + C
        sync();

        cuComplex alpha = cuComplex.cuCmplx(Alpha.realComponent().floatValue(), Alpha.imaginaryComponent().floatValue());

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasCgeru(
                    A.rows(),   // m
                    A.columns(),// n
                    alpha,      // alpha
                    aCPointer,        // d_A or x
                    A.rows(),   // incx
                    bCPointer,        // d_B or y
                    B.rows(),   // incy
                    cCPointer,        // d_C or A
                    C.rows()    // lda
            );
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer, cCPointer);
        }

        return C;
    }

    /**
     * Conjugated complex (double) rank-1 update: C = Alpha * A * conjugate-transpose(B) + C.
     *
     * @param A     the x vector
     * @param B     the y vector
     * @param C     the matrix to update in place
     * @param Alpha the complex scale
     * @return C
     */
    public static IComplexNDArray gerc(IComplexNDArray A, IComplexNDArray B, IComplexNDArray C,
                                       IComplexDouble Alpha) {

        DataTypeValidation.assertDouble(A, B, C);
        // = alpha * A * tranpose(B) + C

        sync();

        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(Alpha.realComponent(), Alpha.imaginaryComponent());

        CublasPointer aCPointer = new CublasPointer(A);
        CublasPointer bCPointer = new CublasPointer(B);
        CublasPointer cCPointer = new CublasPointer(C);
        try {
            JCublas.cublasZgerc(
                    A.rows(),   // m
                    A.columns(),// n
                    alpha,      // alpha
                    aCPointer,        // dA or x
                    A.rows(),   // incx
                    bCPointer,        // dB or y
                    B.rows(),   // incy
                    cCPointer,        // dC or A
                    C.rows()    // lda
            );
            sync();

            cCPointer.copyToHost();
        } finally {
            releaseCublasPointers(aCPointer, bCPointer, cCPointer);
        }

        return C;
    }

    /**
     * Double-precision axpy that takes its lengths and strides from the ndarrays:
     * y = alpha * x + y.
     *
     * @param alpha the alpha to scale by
     * @param x     the x
     * @param y     the y, updated in place
     */
    public static void axpy(double alpha, INDArray x, INDArray y) {
        DataTypeValidation.assertDouble(x, y);

        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            JCublas.cublasDaxpy(x.length(), alpha, xCPointer, x.majorStride(), yCPointer, y.majorStride());

            sync();

            yCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer, yCPointer);
        }

    }

    /**
     * Single-precision axpy that takes its lengths and strides from the ndarrays:
     * y = alpha * x + y.
     *
     * @param alpha the alpha to scale by
     * @param x     the x
     * @param y     the y, updated in place
     */
    public static void saxpy(float alpha, INDArray x, INDArray y) {
        DataTypeValidation.assertFloat(x, y);
        sync();

        CublasPointer xCPointer = new CublasPointer(x);
        CublasPointer yCPointer = new CublasPointer(y);
        try {
            JCublas.cublasSaxpy(x.length(), alpha, xCPointer, x.majorStride(), yCPointer, y.majorStride());
            sync();

            // y holds the result; previously x (unchanged) was copied back instead
            yCPointer.copyToHost();
        } finally {
            releaseCublasPointers(xCPointer, yCPointer);
        }

    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy