All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.examples.MultiINDArrayInterop Maven / Gradle / Ivy

package org.nd4j.examples;

import jcuda.jcublas.JCublas;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.BlasWrapper;
import org.nd4j.linalg.factory.NDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jblas.JblasNDArrayFactory;
import org.nd4j.linalg.jcublas.JCublasNDArrayFactory;
import org.nd4j.linalg.jcublas.JCublasWrapper;

/**
 * Cross-checks the jblas and JCublas {@link NDArrayFactory} backends against each other.
 *
 * <p>The same arrays are built with both factories. Before every backend-specific operation the
 * global {@link Nd4j} factory and {@link BlasWrapper} are switched with {@link #setJblas()} or
 * {@link #setJcublas()}. The program then verifies that transposes, reshapes, matrix products and
 * dot products agree between the two backends.
 *
 * <p>All checks run unconditionally (they do not depend on the JVM being started with {@code -ea})
 * and fail with an {@link AssertionError} describing the mismatch.
 *
 * <p>Originally created by agibsonccc on 9/15/14.
 */
public class MultiINDArrayInterop {

    static NDArrayFactory jblas = new JblasNDArrayFactory("double", 'f');
    static NDArrayFactory jcublas = new JCublasNDArrayFactory("double", 'f');
    static BlasWrapper wrapper = new org.nd4j.linalg.jblas.BlasWrapper();
    static BlasWrapper jcublasWrapper = new JCublasWrapper();

    /**
     * Absolute tolerance for scalar results computed on different backends. The backends may
     * accumulate in a different order, so exact {@code ==} comparison is too strict.
     */
    private static final double SCALAR_TOLERANCE = 1e-6;

    /**
     * Runs every backend comparison. The process fails with an {@link AssertionError} on the first
     * mismatch.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        INDArray jblasLinspace = jblas.linspace(1, 8, 8);
        INDArray jcublasLinspace = jcublas.linspace(1, 8, 8);

        compareVectorProducts(jblasLinspace, jcublasLinspace);
        compareReshapedProducts(jblasLinspace, jcublasLinspace);
        compareSquareTransposes();
        compareDotProducts(jblasLinspace, jcublasLinspace);
    }

    /** Compares {@code v.mmul(v.transpose())} computed on each backend. */
    private static void compareVectorProducts(INDArray jblasLinspace, INDArray jcublasLinspace) {
        setJblas();
        INDArray jblasTranspose = jblasLinspace.transpose();
        INDArray jblasProduct = jblasLinspace.mmul(jblasTranspose);

        setJcublas();
        INDArray jcublasTranspose = jcublasLinspace.transpose();
        // Mirror the jblas side exactly: the original vector times its transpose.
        INDArray jcublasProduct = jcublasLinspace.mmul(jcublasTranspose);

        check(jblasProduct.equals(jcublasProduct),
                "vector mmul differs between jblas and jcublas");
    }

    /** Reshapes the vectors to 2x4 and compares the reshape, its transpose and {@code A^T * A}. */
    private static void compareReshapedProducts(INDArray jblasLinspace, INDArray jcublasLinspace) {
        setJblas();
        INDArray jblasReshape = jblasLinspace.reshape(2, 4);
        INDArray jblasReshapeTranspose = jblasReshape.transpose();
        INDArray jblasProduct = jblasReshapeTranspose.mmul(jblasReshape);

        setJcublas();
        INDArray jcublasReshape = jcublasLinspace.reshape(2, 4);
        // Transpose the jcublas array, not the jblas one, so the comparison spans both backends.
        INDArray jcublasReshapeTranspose = jcublasReshape.transpose();
        check(jblasReshapeTranspose.equals(jcublasReshapeTranspose),
                "reshaped transpose differs between jblas and jcublas");
        check(jcublasReshape.equals(jblasReshape),
                "reshape differs between jblas and jcublas");

        INDArray jcublasProduct = jcublasReshapeTranspose.mmul(jcublasReshape);
        check(jblasProduct.equals(jcublasProduct),
                "reshaped mmul differs between jblas and jcublas");
    }

    /** Builds the same 2x2 array on both backends and compares it and its transpose. */
    private static void compareSquareTransposes() {
        setJblas();
        INDArray toTransposeJblas = jblas.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
        setJcublas();
        INDArray toTransposeJcublas = jcublas.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});

        setJblas();
        INDArray transposedJblas = toTransposeJblas.transpose();
        setJcublas();
        INDArray transposedJcublas = toTransposeJcublas.transpose();

        check(toTransposeJblas.equals(toTransposeJcublas),
                "2x2 array differs between jblas and jcublas");
        check(transposedJblas.equals(transposedJcublas),
                "2x2 transpose differs between jblas and jcublas");
    }

    /** Compares {@code dot(v, v)} computed through each backend's {@link BlasWrapper}. */
    private static void compareDotProducts(INDArray jblasLinspace, INDArray jcublasLinspace) {
        setJblas();
        double jblasDot = Nd4j.getBlasWrapper().dot(jblasLinspace, jblasLinspace);
        setJcublas();
        double jcublasDot = Nd4j.getBlasWrapper().dot(jcublasLinspace, jcublasLinspace);

        // NaN on either side fails this check, which is intended.
        check(Math.abs(jblasDot - jcublasDot) <= SCALAR_TOLERANCE,
                "dot differs between jblas (" + jblasDot + ") and jcublas (" + jcublasDot + ")");
    }

    /**
     * Throws an {@link AssertionError} carrying {@code message} when {@code condition} is false.
     * Unlike the {@code assert} keyword, this check runs even when assertions are disabled.
     */
    private static void check(boolean condition, String message) {
        if (!condition) {
            throw new AssertionError(message);
        }
    }

    /** Makes the JCublas factory and BLAS wrapper the global {@link Nd4j} defaults. */
    public static void setJcublas() {
        Nd4j.setFactory(jcublas);
        Nd4j.setBlasWrapper(jcublasWrapper);
    }

    /** Makes the jblas factory and BLAS wrapper the global {@link Nd4j} defaults. */
    public static void setJblas() {
        Nd4j.setFactory(jblas);
        Nd4j.setBlasWrapper(wrapper);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy