All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.factory.BlasWrapperTests Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.factory;

import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;

import static org.junit.Assert.*;


/**
 * Created by agibsonccc on 9/11/14.
 */
public abstract class BlasWrapperTests {

    @Test
    public void axpyTest() {
        INDArray a = Nd4j.getBlasWrapper().axpy(1.0,Nd4j.ones(3),Nd4j.ones(3));
        INDArray a2 = Nd4j.create(new float[]{2,2,2});
        assertEquals(a2,a);

        INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
        INDArray row = matrix.getRow(1);
        INDArray result = Nd4j.create(new double[]{1,2});
        Nd4j.getBlasWrapper().axpy(1.0,row,result);
        assertEquals(Nd4j.create(new double[]{3,6}),result);



    }

    @Test
    public void testAxpyFortran() {
        Nd4j.factory().setOrder('f');
        INDArray threeByFour = Nd4j.linspace(1,12,12).reshape(3,4);
        INDArray row = threeByFour.getRow(1);
        Nd4j.getBlasWrapper().axpy(2.0,row,row);
    }


    @Test
    public void testIaMax() {
        INDArray test = Nd4j.create(new float[]{1, 2, 3, 4});
        int max = Nd4j.getBlasWrapper().iamax(test);
        assertEquals(3,max);


        INDArray rows = Nd4j.create(new float[]{1, 3, 2, 4}, new int[]{2, 2});
        for(int i = 0; i < rows.rows(); i++) {
            INDArray row = rows.getRow(i);
            int max2 = Nd4j.getBlasWrapper().iamax(row);
            assertEquals(1,max2);
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy