Please wait. This can take a few minutes...
Downloading a project requires considerable server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
org.nd4j.linalg.api.test.NDArrayTests Maven / Gradle / Ivy
package org.nd4j.linalg.api.test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.DimensionSlice;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.SliceOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.reduceops.Ops;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.Shape;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.*;
/**
* NDArrayTests
* @author Adam Gibson
*/
public abstract class NDArrayTests {
/** Logger shared by every test in this class; never reassigned, hence final. */
private static final Logger log = LoggerFactory.getLogger(NDArrayTests.class);
/**
 * Shared 2x2x2 fixture holding linspace(1, 8). NOTE(review): JUnit initialises this when the
 * test instance is constructed, i.e. before {@code before()} sets the ordering.
 */
private INDArray n = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[]{2, 2, 2});
/**
 * Puts the factory into C ordering before each test; several tests switch to 'f'.
 */
@Before
public void before() {
    char defaultOrder = 'c';
    Nd4j.factory().setOrder(defaultOrder);
}
/**
 * Restores C ordering after each test so an 'f' setting does not leak into the next one.
 */
@After
public void after() {
    char defaultOrder = 'c';
    Nd4j.factory().setOrder(defaultOrder);
}
/**
 * Applies the in-place scalar operators to a 3x3x3 array of ones, then checks the
 * full sum of a fresh array of ones and the shape of one of its slices.
 */
@Test
public void testScalarOps() {
    INDArray cube = Nd4j.create(Nd4j.ones(27).data(), new int[]{3, 3, 3});
    assertEquals(27d, cube.length(), 1e-1);
    cube.checkDimensions(cube.addi(Nd4j.scalar(1d)));
    cube.checkDimensions(cube.subi(Nd4j.scalar(1.0d)));
    cube.checkDimensions(cube.muli(Nd4j.scalar(1.0d)));
    cube.checkDimensions(cube.divi(Nd4j.scalar(1.0d)));

    // start again from an untouched array of ones before reducing
    cube = Nd4j.create(Nd4j.ones(27).data(), new int[]{3, 3, 3});
    double total = cube.sum(Integer.MAX_VALUE).getDouble(0);
    assertEquals(27, total, 1e-1);
    INDArray lastSlice = cube.slice(2);
    assertTrue(Arrays.equals(new int[]{3, 3}, lastSlice.shape()));
}
/** Subtracting row 1 from every row of [[1,2],[3,4]] in place yields [[-2,-2],[0,0]]. */
@Test
public void testSubiRowVector() {
    INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray secondRow = matrix.getRow(1);
    matrix.subiRowVector(secondRow);
    INDArray expected = Nd4j.create(new float[]{-2, -2, 0, 0}, new int[]{2, 2});
    assertEquals(expected, matrix);
}
/**
 * Sorting [[1,2],[3,4]] along dimension 1 ascending leaves it unchanged;
 * sorting descending reverses each row.
 */
@Test
public void testSort() {
    INDArray original = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray ascending = Nd4j.sort(original.dup(), 1, true);
    // every row is already in ascending order
    assertEquals(original, ascending);
    INDArray rowsReversed = Nd4j.create(new float[]{2, 1, 4, 3}, new int[]{2, 2});
    INDArray descending = Nd4j.sort(original.dup(), 1, false);
    assertEquals(rowsReversed, descending);
}
/** Adding 1 in place to the row view at offset 2 must be visible in the parent matrix. */
@Test
public void testAddVectorWithOffset() {
    INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray secondRow = matrix.getRow(1);
    secondRow.addi(1);
    assertEquals(Nd4j.create(new float[]{1, 2, 4, 5}, new int[]{2, 2}), matrix);
}
/** Values written through a linear view of a 2x2 matrix can be read back from it. */
@Test
public void testLinearViewGetAndPut() {
    INDArray square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray flatView = square.linearView();
    flatView.putScalar(2, 6);
    flatView.putScalar(3, 7);
    assertEquals(6, flatView.getFloat(2), 1e-1);
    assertEquals(7, flatView.getFloat(3), 1e-1);
}
/**
 * Interval indexing on a Fortran-ordered 4x3x2 array of 1..24: checks the shape,
 * values and offset of the resulting views and of slices taken from them.
 */
@Test
public void testGetIndices() {
    /* Fortran-ordered layout of the source array:
       [[[1.0 ,13.0],[5.0 ,17.0],[9.0 ,21.0]],[[2.0 ,14.0],[6.0 ,18.0],[10.0 ,22.0]],
        [[3.0 ,15.0],[7.0 ,19.0],[11.0 ,23.0]],[[4.0 ,16.0],[8.0 ,20.0],[12.0 ,24.0]]] */
    Nd4j.factory().setOrder('f');
    INDArray source = Nd4j.linspace(1, 24, 24).reshape(new int[]{4, 3, 2});

    INDArray firstCut = source.get(NDArrayIndex.interval(1, 2), NDArrayIndex.interval(1, 3));
    assertTrue(Arrays.equals(new int[]{1, 2, 2}, firstCut.shape()));
    assertEquals(Nd4j.create(new float[]{6, 10, 18, 22}, new int[]{1, 2, 2}), firstCut);

    INDArray expectedSecondCut = Nd4j.create(new float[]{6, 7, 10, 11, 18, 19, 22, 23}, new int[]{2, 1, 2});
    INDArray secondCut = source.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(1, 2));
    assertEquals(5, secondCut.offset());
    assertTrue(Arrays.equals(new int[]{2, 1, 2}, secondCut.shape()));
    assertEquals(secondCut, expectedSecondCut);

    INDArray firstSliceLinear = secondCut.slice(0).linearView();
    assertEquals(10, firstSliceLinear.getFloat(1), 1e-1);
    assertEquals(Nd4j.create(new float[]{7, 11}), secondCut.slice(1));
}
/** After swapping axes 2 and 1 of a Fortran-ordered 3x5x2 array, the first vector is 1,4,...,13. */
@Test
public void testSwapAxesFortranOrder() {
    Nd4j.factory().setOrder('f');
    INDArray cube = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[]{3, 5, 2});
    INDArray swapped = cube.swapAxes(2, 1);
    INDArray expected = Nd4j.create(new double[]{1, 4, 7, 10, 13});
    INDArray firstVector = swapped.slice(0).slice(0);
    assertEquals(expected, firstVector);
}
/** interval(1, 3) on the vector 1..4 selects the elements [2, 3]. */
@Test
public void testGetIndicesVector() {
    INDArray oneToFour = Nd4j.linspace(1, 4, 4);
    INDArray expected = Nd4j.create(new float[]{2, 3});
    assertEquals(expected, oneToFour.get(NDArrayIndex.interval(1, 3)));
}
/**
 * Interval indexing on a Fortran-ordered 3x2 matrix of 1..6 agrees with getRow/getRows,
 * and a 1x1 interval selects the single element 5.
 */
@Test
public void testGetIndices2d() {
    Nd4j.factory().setOrder('f');
    INDArray threeByTwo = Nd4j.linspace(1, 6, 6).reshape(3, 2);
    INDArray row0 = threeByTwo.getRow(0);
    INDArray row1 = threeByTwo.getRow(1);
    INDArray rows1And2 = threeByTwo.getRows(new int[]{1, 2});

    assertEquals(row0, threeByTwo.get(NDArrayIndex.interval(0, 1)));
    assertEquals(row1, threeByTwo.get(NDArrayIndex.interval(1, 2)));

    INDArray element = threeByTwo.get(NDArrayIndex.interval(1, 2), NDArrayIndex.interval(1, 2));
    // result discarded; presumably a smoke check that toString works on a view -- TODO confirm
    element.toString();
    assertEquals(Nd4j.create(new float[]{5}), element);

    assertEquals(rows1And2, threeByTwo.get(NDArrayIndex.interval(1, 3)));
}
/** dup() of a vector, a matrix, a row view and another matrix compares equal to its source. */
@Test
public void testDup() {
    INDArray vector = Nd4j.linspace(1, 4, 4);
    assertEquals(vector, vector.dup());
    INDArray matrix = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    assertEquals(matrix, matrix.dup());
    INDArray rowView = matrix.getRow(1);
    assertEquals(rowView, rowView.dup());
    INDArray reversedRows = Nd4j.create(new float[]{2, 1, 4, 3}, new int[]{2, 2});
    assertEquals(reversedRows, reversedRows.dup());
}
/**
 * sortWithIndices (descending, dimension 1) returns the same values as sort and the
 * index permutation [[1,0],[1,0]].
 */
@Test
public void testSortWithIndicesDescending() {
    INDArray input = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    // element [0] holds the indices, element [1] the sorted data
    INDArray[] indicesAndValues = Nd4j.sortWithIndices(input.dup(), 1, false);
    INDArray valuesOnly = Nd4j.sort(input.dup(), 1, false);
    assertEquals(indicesAndValues[1], valuesOnly);
    INDArray expectedIndices = Nd4j.create(new float[]{1, 0, 1, 0}, new int[]{2, 2});
    assertEquals(expectedIndices, indicesAndValues[0]);
}
/**
 * sortWithIndices (ascending, dimension 1) returns the same values as sort and the
 * identity permutation [[0,1],[0,1]] because each row is already sorted.
 */
@Test
public void testSortWithIndices() {
    INDArray input = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    // element [0] holds the indices, element [1] the sorted data
    INDArray[] indicesAndValues = Nd4j.sortWithIndices(input.dup(), 1, true);
    INDArray valuesOnly = Nd4j.sort(input.dup(), 1, true);
    assertEquals(indicesAndValues[1], valuesOnly);
    INDArray expectedIndices = Nd4j.create(new float[]{0, 1, 0, 1}, new int[]{2, 2});
    assertEquals(expectedIndices, indicesAndValues[0]);
}
/** Inserting a broadcast axis ('x') between the two axes of a 2x2 matrix yields shape 2x1x2. */
@Test
public void testDimShuffle() {
    INDArray square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray inserted = square.dimShuffle(new Object[]{0, 'x', 1}, new int[]{0, 1}, new boolean[]{false, false});
    assertTrue(Arrays.equals(new int[]{2, 1, 2}, inserted.shape()));
    INDArray swapped = square.dimShuffle(new Object[]{1, 'x', 0}, new int[]{1, 0}, new boolean[]{false, false});
    assertTrue(Arrays.equals(new int[]{2, 1, 2}, swapped.shape()));
}
/** getFloat and getDouble agree on element (0,1) under the default and Fortran orderings. */
@Test
public void testGetVsGetScalar() {
    INDArray defaultOrdered = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    float asFloat = defaultOrdered.getFloat(0, 1);
    double asDouble = defaultOrdered.getDouble(0, 1);
    assertEquals(asFloat, asDouble, 1e-1);

    Nd4j.factory().setOrder('f');
    INDArray fortranOrdered = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    float fortranFloat = fortranOrdered.getFloat(0, 1);
    double fortranDouble = fortranOrdered.getDouble(0, 1);
    assertEquals(fortranFloat, fortranDouble, 1e-1);
}
/** Element-wise division: x / x is all ones, and 0.5 / [0.3,0.6,0.9,0.1] matches hand-computed values. */
@Test
public void testDivide() {
    INDArray twos = Nd4j.create(new float[]{2, 2, 2, 2});
    INDArray quotient = twos.div(twos);
    assertEquals(Nd4j.ones(4), quotient);

    INDArray numerator = Nd4j.create(new float[]{0.5f, 0.5f, 0.5f, 0.5f}, new int[]{2, 2});
    INDArray denominator = Nd4j.create(new float[]{0.3f, 0.6f, 0.9f, 0.1f}, new int[]{2, 2});
    INDArray expected = Nd4j.create(new float[]{1.6666666f, 0.8333333f, 0.5555556f, 5}, new int[]{2, 2});
    assertEquals(expected, numerator.div(denominator));
}
/** Transforms.sigmoid on 1..4 matches 1 / (1 + e^-x). */
@Test
public void testSigmoid() {
    INDArray input = Nd4j.create(new float[]{1, 2, 3, 4});
    INDArray expected = Nd4j.create(new float[]{0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
    assertEquals(expected, Transforms.sigmoid(input));
}
/** Transforms.neg flips the sign of every element. */
@Test
public void testNeg() {
    INDArray input = Nd4j.create(new float[]{1, 2, 3, 4});
    INDArray expected = Nd4j.create(new float[]{-1, -2, -3, -4});
    assertEquals(expected, Transforms.neg(input));
}
/** The L2 norm of [1,2,3,4] is sqrt(30); that of the row view [3,4] is 5. */
@Test
public void testNorm2() {
    INDArray vector = Nd4j.create(new float[]{1, 2, 3, 4});
    float expectedVectorNorm = 5.47722557505f;
    assertEquals(expectedVectorNorm, vector.norm2(Integer.MAX_VALUE).getFloat(0), 1e-1);

    INDArray matrix = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    INDArray secondRow = matrix.getRow(1);
    float rowNorm = secondRow.norm2(Integer.MAX_VALUE).getFloat(0);
    float expectedRowNorm = 5.0f;
    assertEquals(expectedRowNorm, rowNorm, 1e-1);
}
/**
 * Nd4j.diag of the vector 1..8 places those values on the main diagonal. The diag of
 * the 2x4 reshaping is computed once and logged (no assertion for it yet).
 */
@Test
public void testDiag() {
    INDArray values = Nd4j.linspace(1, 8, 8);
    INDArray diag = Nd4j.diag(values);
    for (int i = 0; i < values.length(); i++) {
        assertEquals(i + 1, diag.getFloat(i, i), 1e-1);
    }
    INDArray matrix = values.reshape(2, 4);
    // compute once and log that result instead of calling Nd4j.diag a second time
    INDArray diagMatrix = Nd4j.diag(matrix);
    log.info("diag {}", diagMatrix);
}
/** Cosine similarity is 1 for identical vectors and about 0.98 for two nearby positive vectors. */
@Test
public void testCosineSim() {
    INDArray left = Nd4j.create(new float[]{1, 2, 3, 4});
    INDArray right = Nd4j.create(new float[]{1, 2, 3, 4});
    double identical = Transforms.cosineSim(left, right);
    assertEquals(1, identical, 1e-1);

    INDArray smaller = Nd4j.create(new float[]{0.2f, 0.3f, 0.4f, 0.5f});
    INDArray larger = Nd4j.create(new float[]{0.6f, 0.7f, 0.8f, 0.9f});
    double close = Transforms.cosineSim(smaller, larger);
    assertEquals(0.98, close, 1e-1);
}
/**
 * BLAS scal: scaling [1,2,3,4] by 2 gives [2,4,6,8], and scaling the row view [3,4]
 * of a 2x2 matrix by 5 gives [15,20].
 */
@Test
public void testScal() {
    INDArray input = Nd4j.create(new float[]{1, 2, 3, 4});
    float alpha = 2;
    INDArray answer = Nd4j.create(new float[]{2, 4, 6, 8});
    // Scale the input, not the expected result: scaling `answer` in place and comparing
    // it with itself could never fail.
    assertEquals(answer, Nd4j.getBlasWrapper().scal(alpha, input));

    INDArray matrix = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    INDArray row1 = matrix.getRow(1);
    float rowAlpha = 5.0f;
    INDArray answer2 = Nd4j.create(new float[]{15, 20});
    assertEquals(answer2, Nd4j.getBlasWrapper().scal(rowAlpha, row1));
}
/** Transforms.exp on 1..4 matches e^1 .. e^4. */
@Test
public void testExp() {
    INDArray exponents = Nd4j.create(new float[]{1, 2, 3, 4});
    INDArray expected = Nd4j.create(new float[]{2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
    assertEquals(expected, Transforms.exp(exponents));
}
/** Every second-level slice of a 4x3x2 array reports 2 slices. */
@Test
public void testSlices() {
    INDArray cube = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
    for (int s = 0; s < cube.slices(); s++) {
        INDArray inner = cube.slice(s).slice(1);
        assertEquals(2, inner.slices());
    }
}
/** Nd4j.scalar(1.0) is a scalar and equals a 1x1 array holding 1. */
@Test
public void testScalar() {
    INDArray fromFactory = Nd4j.scalar(1.0);
    assertTrue(fromFactory.isScalar());
    INDArray oneByOne = Nd4j.create(new float[]{1.0f}, new int[]{1, 1});
    assertEquals(oneByOne, fromFactory);
    assertTrue(oneByOne.isScalar());
}
/** Basic row/column/length bookkeeping for matrices and vectors built from linspace. */
@Test
public void testWrap() {
    int[] shape = {2, 4};
    INDArray original = Nd4j.linspace(1, 8, 8).reshape(shape[0], shape[1]);
    INDArray sameRef = original;
    assertEquals(original.rows(), sameRef.rows());
    assertEquals(original.columns(), sameRef.columns());

    INDArray vector = Nd4j.linspace(1, 3, 3);
    INDArray vectorRef = vector;
    for (int k = 0; k < vector.length(); k++) {
        assertEquals(vector.getDouble(k), vectorRef.getDouble(k), 1e-1);
    }
    assertEquals(3, vectorRef.length());
    assertTrue(vectorRef.isVector());
    assertTrue(Shape.shapeEquals(new int[]{3}, vectorRef.shape()));

    INDArray tall = Nd4j.linspace(1, 2, 2).reshape(2, 1);
    INDArray wide = Nd4j.linspace(3, 4, 2).reshape(1, 2);
    assertEquals(tall.rows(), 2);
    assertEquals(tall.columns(), 1);
    assertEquals(wide.rows(), 1);
    assertEquals(wide.columns(), 2);
}
/** In Fortran order the 2x2 matrix from 1..4 has rows [1,3] and [2,4]. */
@Test
public void testGetRowFortran() {
    Nd4j.factory().setOrder('f');
    INDArray matrix = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
    INDArray expectedRow0 = Nd4j.create(new float[]{1, 3});
    INDArray expectedRow1 = Nd4j.create(new float[]{2, 4});
    INDArray actualRow0 = matrix.getRow(0);
    INDArray actualRow1 = matrix.getRow(1);
    assertEquals(expectedRow0, actualRow0);
    assertEquals(expectedRow1, actualRow1);
    Nd4j.factory().setOrder('c');
}
/** In Fortran order the 2x2 matrix from 1..4 has columns [1,2] and [3,4]. */
@Test
public void testGetColumnFortran() {
    Nd4j.factory().setOrder('f');
    INDArray matrix = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
    INDArray expectedCol0 = Nd4j.create(new float[]{1, 2});
    INDArray expectedCol1 = Nd4j.create(new float[]{3, 4});
    INDArray actualCol0 = matrix.getColumn(0);
    INDArray actualCol1 = matrix.getColumn(1);
    assertEquals(expectedCol0, actualCol0);
    assertEquals(expectedCol1, actualCol1);
    Nd4j.factory().setOrder('c');
}
/** Shapes {4} and {1,4} create row vectors; shape {4,1} creates a column vector. */
@Test
public void testVectorInit() {
    DataBuffer buffer = Nd4j.linspace(1, 4, 4).data();
    INDArray flat = Nd4j.create(buffer, new int[]{4});
    assertTrue(flat.isRowVector());
    INDArray oneByFour = Nd4j.create(buffer, new int[]{1, 4});
    assertTrue(oneByFour.isRowVector());
    INDArray fourByOne = Nd4j.create(buffer, new int[]{4, 1});
    assertTrue(fourByOne.isColumnVector());
}
/** putColumn/getColumn round-trip on 3x2 and 2x2 matrices, plus column reads of [[1,2],[3,4]]. */
@Test
public void testColumns() {
    INDArray threeByTwo = Nd4j.create(new int[]{3, 2});
    INDArray initialColumn = threeByTwo.getColumn(0);
    assertTrue(Shape.shapeEquals(new int[]{3, 1}, initialColumn.shape()));

    INDArray firstValues = Nd4j.create(new float[]{1, 2, 3}, new int[]{3});
    threeByTwo.putColumn(0, firstValues);
    INDArray readBack0 = threeByTwo.getColumn(0);
    assertEquals(firstValues, readBack0);

    INDArray secondValues = Nd4j.create(new float[]{4, 5, 6}, new int[]{3});
    threeByTwo.putColumn(1, secondValues);
    assertTrue(Shape.shapeEquals(new int[]{3, 1}, threeByTwo.getColumn(1).shape()));
    INDArray readBack1 = threeByTwo.getColumn(1);
    assertEquals(secondValues, readBack1);

    INDArray square = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    INDArray replacement = Nd4j.create(new float[]{5, 6}, new int[]{2});
    square.putColumn(1, replacement);
    INDArray replaced = square.getColumn(1);
    assertEquals(replacement, replaced);

    INDArray oneToFour = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
    INDArray firstColumn = oneToFour.getColumn(0);
    assertEquals(firstColumn, Nd4j.create(new float[]{1, 3}, new int[]{2}));
    INDArray secondColumn = oneToFour.getColumn(1);
    assertEquals(secondColumn, Nd4j.create(new float[]{2, 4}, new int[]{2}));
}
/**
 * putRow/getRow on 2x2 and 3x2 matrices, plus row reads from slices of a 4x2x2 array.
 */
@Test
public void testPutRow() {
    INDArray d = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray n = d.dup();
    // In C (row-major) order reshape(linspace(1,4,4), 2, 2) is [[1,2],[3,4]], so (0,1) is 2.
    // (MATLAB's reshape is column-major, so its A(1,2) would be 3 -- not a like-for-like check.)
    float nFirst = 2;
    float dFirst = d.getFloat(0, 1);
    assertEquals(nFirst, dFirst, 1e-1);
    assertEquals(d.data(), n.data());
    assertTrue(Arrays.equals(new int[]{2, 2}, n.shape()));

    INDArray newRow = Nd4j.linspace(5, 6, 2);
    n.putRow(0, newRow);
    d.putRow(0, newRow);
    INDArray testRow = n.getRow(0);
    assertEquals(newRow.length(), testRow.length());
    assertTrue(Shape.shapeEquals(new int[]{2}, testRow.shape()));

    INDArray nLast = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
    INDArray row = nLast.getRow(1);
    INDArray row1 = Nd4j.create(new float[]{3, 4}, new int[]{2});
    assertEquals(row, row1);

    INDArray arr = Nd4j.create(new int[]{3, 2});
    INDArray evenRow = Nd4j.create(new float[]{1, 2}, new int[]{2});
    arr.putRow(0, evenRow);
    INDArray firstRow = arr.getRow(0);
    assertTrue(Shape.shapeEquals(new int[]{2}, firstRow.shape()));
    INDArray testRowEven = arr.getRow(0);
    assertEquals(evenRow, testRowEven);
    INDArray row12 = Nd4j.create(new float[]{5, 6}, new int[]{2});
    arr.putRow(1, row12);
    assertTrue(Shape.shapeEquals(new int[]{2}, arr.getRow(0).shape()));
    INDArray testRow1 = arr.getRow(1);
    assertEquals(row12, testRow1);

    // 4x2x2 array holding 1..16: slice(1) is [[5,6],[7,8]], slice(2) is [[9,10],[11,12]].
    INDArray multiSliceTest = Nd4j.create(Nd4j.linspace(1, 16, 16).data(), new int[]{4, 2, 2});
    INDArray test = Nd4j.create(new float[]{7, 8}, new int[]{2});
    INDArray test2 = Nd4j.create(new float[]{9, 10}, new int[]{2});
    assertEquals(test, multiSliceTest.slice(1).getRow(1));
    // slice(1) only has rows 0 and 1; [9,10] is row 0 of the next slice. Reading "row 2"
    // of slice(1) would depend on an unchecked out-of-range access.
    assertEquals(test2, multiSliceTest.slice(2).getRow(0));
}
/**
 * The same data read at (0,1) differs between C and Fortran ordering; C ordering
 * yields 2 for [1,2,3,4] as a 2x2 matrix.
 */
@Test
public void testOrdering() {
    // C ordering first
    Nd4j.factory().setOrder('c');
    Nd4j.factory().setDType("float");
    INDArray rowMajor = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    assertEquals(2.0, (float) rowMajor.getScalar(0, 1).element(), 1e-1);

    Nd4j.factory().setOrder('f');
    INDArray columnMajor = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    assertNotEquals(columnMajor.getDouble(0, 1), rowMajor.getDouble(0, 1), 1e-1);
    Nd4j.factory().setOrder('c');
}
/** Summing a 2x2x2 array of 1..8 over its last axis gives [[3,7],[11,15]]. */
@Test
public void testSum() {
    INDArray cube = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[]{2, 2, 2});
    INDArray expected = Nd4j.create(new float[]{3, 7, 11, 15}, new int[]{2, 2});
    int lastAxis = cube.shape().length - 1;
    assertEquals(expected, cube.sum(lastAxis));
}
/** In Fortran order, the inner product of 1..10 with its transpose is the scalar 385. */
@Test
public void testMmulF() {
    Nd4j.factory().setOrder('f');
    DataBuffer buffer = Nd4j.linspace(1, 10, 10).data();
    INDArray row = Nd4j.create(buffer, new int[]{10});
    INDArray column = row.transpose();
    assertTrue(row.isRowVector());
    assertTrue(column.isColumnVector());
    // a rows x columns copy of the data; built but not otherwise checked
    INDArray copy = Nd4j.create(row.data().dup(), new int[]{row.rows(), row.columns()});
    INDArray inner = row.mmul(column);
    assertEquals(Nd4j.scalar(385), inner);
}
/** Row sums (dimension 1) of [[1,2],[3,4]] are [3,7]. */
@Test
public void testSum2() {
    INDArray test = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    INDArray sum = test.sum(1);
    log.info("Sum{}", sum);
    // previously only logged; pin the value (same expectation as testSums uses for sum(1))
    assertEquals(Nd4j.create(new float[]{3, 7}), sum);
}
/**
 * Matrix multiplication in double precision:
 * <ul>
 *   <li>inner and outer products of the vector 1..10;</li>
 *   <li>a column-times-row product;</li>
 *   <li>a product involving a row view of a 3-D slice;</li>
 *   <li>the 16x16 outer product of 0..15 with itself.</li>
 * </ul>
 * {@code Nd4j.dtype} is a global setting. It is switched to "double" only for the duration
 * of this test and restored afterwards, so the change does not leak into other tests.
 */
@Test
public void testMmul() {
    Nd4j.factory().setOrder('c');
    String previousDtype = Nd4j.dtype;
    Nd4j.dtype = "double";
    try {
        DataBuffer data = Nd4j.linspace(1, 10, 10).data();
        INDArray n = Nd4j.create(data, new int[]{10});
        INDArray transposed = n.transpose();
        assertTrue(n.isRowVector());
        assertTrue(transposed.isColumnVector());
        // smoke check: setData onto a freshly created array of the same shape
        INDArray d = Nd4j.create(n.rows(), n.columns());
        d.setData(n.data());

        // 1^2 + 2^2 + ... + 10^2 = 385
        INDArray innerProduct = n.mmul(transposed);
        assertEquals(Nd4j.scalar(385), innerProduct);
        INDArray outerProduct = transposed.mmul(n);
        assertTrue(Shape.shapeEquals(new int[]{10, 10}, outerProduct.shape()));

        // smoke checks: these products are computed but not compared against anything
        INDArray testMatrix = Nd4j.create(data, new int[]{5, 2});
        INDArray row1 = testMatrix.getRow(0).transpose();
        INDArray row2 = testMatrix.getRow(1);
        INDArray row12 = Nd4j.linspace(1, 2, 2).reshape(2, 1);
        INDArray row22 = Nd4j.linspace(3, 4, 2).reshape(1, 2);
        INDArray rowResult2 = row12.mmul(row22);

        // [1,2]^T x [3,4] = [[3,4],[6,8]]
        INDArray d3 = Nd4j.create(new double[]{1, 2}).reshape(2, 1);
        INDArray d4 = Nd4j.create(new double[]{3, 4});
        INDArray resultNDArray = d3.mmul(d4);
        INDArray result = Nd4j.create(new double[][]{{3, 4}, {6, 8}});
        assertEquals(result, resultNDArray);

        // row 1 of slice 0 of a 3x5x2 array holding 1..30 is [3,4]
        INDArray three = Nd4j.create(new double[]{3, 4}, new int[]{2});
        INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[]{3, 5, 2});
        INDArray sliceRow = test.slice(0).getRow(1);
        assertEquals(three, sliceRow);
        INDArray twoSix = Nd4j.create(new double[]{2, 6}, new int[]{2, 1});
        INDArray threeTwoSix = three.mmul(twoSix);
        INDArray sliceRowTwoSix = sliceRow.mmul(twoSix);
        assertEquals(threeTwoSix, sliceRowTwoSix);

        // expected outer product of 0..15 with itself: element (i, j) is i * j
        double[] outer = new double[16 * 16];
        for (int i = 0; i < 16; i++) {
            for (int j = 0; j < 16; j++) {
                outer[i * 16 + j] = i * j;
            }
        }
        INDArray vectorVector = Nd4j.create(outer, new int[]{16, 16});
        INDArray n1 = Nd4j.create(Nd4j.linspace(0, 15, 16).data(), new int[]{16});
        INDArray k1 = n1.transpose();
        INDArray testVectorVector = k1.mmul(n1);
        assertEquals(vectorVector, testVectorVector);
    } finally {
        Nd4j.dtype = previousDtype;
    }
}
/** rows()/columns() for a 2x3 matrix, a 6x1 column and a flat 6-element array. */
@Test
public void testRowsColumns() {
    DataBuffer buffer = Nd4j.linspace(1, 6, 6).data();
    INDArray twoByThree = Nd4j.create(buffer, new int[]{2, 3});
    assertEquals(2, twoByThree.rows());
    assertEquals(3, twoByThree.columns());

    INDArray sixByOne = Nd4j.create(buffer, new int[]{6, 1});
    assertEquals(6, sixByOne.rows());
    assertEquals(1, sixByOne.columns());

    // a 1-D array reports itself as a single row
    INDArray flat = Nd4j.create(buffer, new int[]{6});
    assertEquals(1, flat.rows());
    assertEquals(6, flat.columns());
}
/** transpose() reverses a 3-D shape, swaps row/column vectors and transposes 2x2 matrices in both orders. */
@Test
public void testTranspose() {
    INDArray cube = Nd4j.create(Nd4j.ones(100).data(), new int[]{5, 5, 4});
    INDArray cubeT = cube.transpose();
    assertEquals(cube.length(), cubeT.length());
    assertTrue(Arrays.equals(new int[]{4, 5, 5}, cubeT.shape()));

    INDArray row = Nd4j.linspace(1, 10, 10);
    assertTrue(row.isRowVector());
    INDArray column = row.transpose();
    assertTrue(column.isColumnVector());

    INDArray square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray expected = Nd4j.create(new float[]{1, 3, 2, 4}, new int[]{2, 2});
    assertEquals(expected, square.transpose());

    Nd4j.factory().setOrder('f');
    square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    // same data, now interpreted in Fortran order
    INDArray expectedFortran = Nd4j.create(new float[]{1, 3, 2, 4}, new int[]{2, 2});
    INDArray actualFortran = square.transpose();
    assertEquals(actualFortran, expectedFortran);
    Nd4j.factory().setOrder('c');
}
/** putSlice(0, zeros) replaces the first slice of a 3x3x3 array of ones. */
@Test
public void testPutSlice() {
    INDArray ones = Nd4j.create(Nd4j.ones(27).data(), new int[]{3, 3, 3});
    INDArray zeros = Nd4j.zeros(3, 3);
    ones.putSlice(0, zeros);
    assertEquals(zeros, ones.slice(0));
}
/** permute({1,0}) matches transpose(); permute({2,1,0}) on a 2x2x2 array matches hand-computed data. */
@Test
public void testPermute() {
    INDArray fiveByFour = Nd4j.create(Nd4j.linspace(1, 20, 20).data(), new int[]{5, 4});
    INDArray viaTranspose = fiveByFour.transpose();
    INDArray viaPermute = fiveByFour.permute(new int[]{1, 0});
    assertEquals(viaPermute, viaTranspose);
    assertEquals(viaTranspose.length(), viaPermute.length(), 1e-1);

    INDArray cube = Nd4j.create(Nd4j.linspace(0, 7, 8).data(), new int[]{2, 2, 2});
    INDArray reversedAxes = cube.permute(new int[]{2, 1, 0});
    INDArray expected = Nd4j.create(new float[]{0, 4, 2, 6, 1, 5, 3, 7}, new int[]{2, 2, 2});
    assertEquals(reversedAxes, expected);
}
/** Slicing the shared 2x2x2 fixture and a 4x3x2 array of 1..24 yields the expected sub-matrices. */
@Test
public void testSlice() {
    assertEquals(8, n.length());
    assertTrue(Arrays.equals(new int[]{2, 2, 2}, n.shape()));
    INDArray first = n.slice(0);
    assertTrue(Arrays.equals(new int[]{2, 2}, first.shape()));
    INDArray second = n.slice(1);
    assertNotEquals(first, second);

    INDArray cube = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
    INDArray expected0 = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new int[]{3, 2});
    INDArray expected1 = Nd4j.create(new float[]{7, 8, 9, 10, 11, 12}, new int[]{3, 2});
    INDArray actual0 = cube.slice(0);
    INDArray actual1 = cube.slice(1);
    assertEquals(expected0, actual0);
    assertEquals(expected1, actual1);
}
/** Axis permutation on a 2x2x2 array, and swapAxes(2,1) on a 3x5x2 array exposing 1,3,5,7,9. */
@Test
public void testSwapAxes() {
    INDArray cube = Nd4j.create(Nd4j.linspace(0, 7, 8).data(), new int[]{2, 2, 2});
    INDArray reversed = cube.permute(new int[]{2, 1, 0});
    INDArray expected = Nd4j.create(new float[]{0, 4, 2, 6, 1, 5, 3, 7}, new int[]{2, 2, 2});
    assertEquals(expected, reversed);

    INDArray thirty = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2);
    INDArray firstVector = thirty.swapAxes(2, 1).slice(0).slice(0);
    assertEquals(Nd4j.create(new double[]{1, 3, 5, 7, 9}), firstVector);
}
/** For a plain 8-element vector, linearIndex(i) == i and element i holds i + 1. */
@Test
public void testLinearIndex() {
    INDArray vector = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[]{8});
    for (int pos = 0; pos < vector.length(); pos++) {
        int mapped = vector.linearIndex(pos);
        assertEquals(pos, mapped);
        double value = vector.getDouble(pos);
        assertEquals(pos + 1, value, 1e-1);
    }
}
/** Building a 5-element vector from a list of scalar arrays 1..5. */
@Test
public void testSliceConstructor() {
    final int count = 5;
    // typed (not raw) list, presized for the known number of scalars
    List<INDArray> testList = new ArrayList<>(count);
    for (int i = 0; i < count; i++) {
        testList.add(Nd4j.scalar(i + 1));
    }
    INDArray test = Nd4j.create(testList, new int[]{testList.size()});
    INDArray expected = Nd4j.create(new float[]{1, 2, 3, 4, 5}, new int[]{5});
    assertEquals(expected, test);
}
/**
 * Iterates the 2x2 matrix [[1,2],[3,4]] over dimension 1, expecting [1,2] then [3,4],
 * and over dimension 0, expecting [1,3] then [2,4].
 */
@Test
public void testVectorDimension() {
    final INDArray matrix = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
    final AtomicInteger visited = new AtomicInteger(0);

    // row wise
    matrix.iterateOverDimension(1, new SliceOp() {
        @Override
        public void operate(DimensionSlice nd) {
            checkVector((INDArray) nd.getResult());
        }

        /**
         * Operates on an ndarray slice
         *
         * @param nd the result to operate on
         */
        @Override
        public void operate(INDArray nd) {
            checkVector(nd);
        }

        /** First visit must see [1,2]; every later visit must see [3,4]. */
        private void checkVector(INDArray actual) {
            float[] values = visited.get() == 0 ? new float[]{1, 2} : new float[]{3, 4};
            INDArray expected = Nd4j.create(values, new int[]{2});
            assertEquals(expected, actual);
            visited.incrementAndGet();
        }
    }, false);

    visited.set(0);
    // columnwise
    matrix.iterateOverDimension(0, new SliceOp() {
        @Override
        public void operate(DimensionSlice nd) {
            log.info("Operator " + nd);
            checkVector((INDArray) nd.getResult());
        }

        /**
         * Operates on an ndarray slice
         *
         * @param nd the result to operate on
         */
        @Override
        public void operate(INDArray nd) {
            log.info("Operator " + nd);
            checkVector(nd);
        }

        /** First visit must see [1,3]; every later visit must see [2,4]. */
        private void checkVector(INDArray actual) {
            float[] values = visited.get() == 0 ? new float[]{1, 3} : new float[]{2, 4};
            INDArray expected = Nd4j.create(values, new int[]{2});
            assertEquals(expected, actual);
            visited.incrementAndGet();
        }
    }, false);
}
/**
 * Checks slice(index, dimension). On the 2x2 matrix [[1,2],[3,4]], dimension 1 yields
 * the rows and dimension 0 the columns. Two slices along dimension 1 of a 4x3x2 array
 * are then taken.
 */
@Test
public void testDimension() {
INDArray test = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
//row
INDArray slice0 = test.slice(0,1);
INDArray slice02 = test.slice(1,1);
INDArray assertSlice0 = Nd4j.create(new float[]{1, 2});
INDArray assertSlice02 = Nd4j.create(new float[]{3, 4});
assertEquals(assertSlice0,slice0);
assertEquals(assertSlice02,slice02);
//column
INDArray assertSlice1 = Nd4j.create(new float[]{1, 3});
INDArray assertSlice12 = Nd4j.create(new float[]{2, 4});
INDArray slice1 = test.slice(0,0);
INDArray slice12 = test.slice(1,0);
assertEquals(assertSlice1,slice1);
assertEquals(assertSlice12,slice12);
INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
INDArray firstSliceFirstDimension = arr.slice(0,1);
INDArray secondSliceFirstDimension = arr.slice(1,1);
// NOTE(review): firstSliceFirstDimension is never compared against this expected value.
INDArray firstSliceFirstDimensionAssert = Nd4j.create(new float[]{1, 2, 7, 8, 13, 14, 19, 20});
// NOTE(review): the +1 offset looks suspect. Assuming C order, the values above are the
// elements whose dimension-1 index is 0, and moving to dimension-1 index 1 adds 2, not 1.
// Confirm the slice(index, dimension) semantics before wiring this up.
INDArray secondSliceFirstDimension2Test = firstSliceFirstDimensionAssert.add(1);
// NOTE(review): this compares the value with itself, so it can never fail. It was
// probably meant to compare against secondSliceFirstDimension2Test.
assertEquals(secondSliceFirstDimension,secondSliceFirstDimension);
}
/** appendBias(x) equals x flattened with a trailing scalar 1. */
@Test
public void testAppendBias() {
    INDArray column = Nd4j.linspace(1, 25, 25).transpose();
    INDArray withBias = Nd4j.appendBias(column);
    INDArray expected = Nd4j.toFlattened(column, Nd4j.scalar(1));
    assertEquals(expected, withBias);
}
/**
 * Reshapes a 4x3x2 array to 2x3x4, then checks that after swapping the last two axes
 * of a 3x5x2 array of 1..30 the first two vectors of slice 0 are 1,3,5,7,9 and 2,4,6,8,10.
 */
@Test
public void testReshape() {
    INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
    INDArray reshaped = arr.reshape(new int[]{2, 3, 4});
    assertEquals(arr.length(), reshaped.length());
    assertTrue(Arrays.equals(new int[]{4, 3, 2}, arr.shape()));
    assertTrue(Arrays.equals(new int[]{2, 3, 4}, reshaped.shape()));

    INDArray n2 = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[]{3, 5, 2});
    INDArray swapped = n2.swapAxes(n2.shape().length - 1, 1);

    INDArray firstSlice2 = swapped.slice(0).slice(0);
    INDArray oneThreeFiveSevenNine = Nd4j.create(new float[]{1, 3, 5, 7, 9});
    assertEquals(firstSlice2, oneThreeFiveSevenNine);
    // Reshape the view itself (not the expected vector twice) so this comparison can fail.
    INDArray raveled = oneThreeFiveSevenNine.reshape(5, 1);
    INDArray raveledFirstSlice2 = firstSlice2.reshape(5, 1);
    assertEquals(raveled, raveledFirstSlice2);

    INDArray firstSlice3 = swapped.slice(0).slice(1);
    INDArray twoFourSixEightTen = Nd4j.create(new float[]{2, 4, 6, 8, 10});
    INDArray raveled2 = twoFourSixEightTen.reshape(5, 1);
    INDArray raveled3 = firstSlice3.reshape(5, 1);
    assertEquals(raveled2, raveled3);
}
/** BLAS dot: [1,2,3,4]·[1,2,3,4] = 30 and [3,4]·[3,4] = 25 for a row view. */
@Test
public void testDot() {
    INDArray left = Nd4j.create(new float[]{1, 2, 3, 4});
    INDArray right = Nd4j.create(new float[]{1, 2, 3, 4});
    assertEquals(30, Nd4j.getBlasWrapper().dot(left, right), 1e-1);
    INDArray square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray secondRow = square.getRow(1);
    assertEquals(25, Nd4j.getBlasWrapper().dot(secondRow, secondRow), 1e-1);
}
/**
 * Smoke test for reduce(MAX, dimension) along every axis of a 4x3x2 array. Results are
 * logged, not asserted.
 */
@Test
public void reduceTest() {
    INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
    // cover each axis once (dimension 1 was previously reduced twice and dimension 0 never)
    for (int dimension = 0; dimension < arr.shape().length; dimension++) {
        INDArray reduced = arr.reduce(Ops.DimensionOp.MAX, dimension);
        log.info("Reduced {}", reduced);
    }
}
/** Nd4j.eye(5) is 5x5 under both the default and Fortran orderings. */
@Test
public void testIdentity() {
    INDArray identity = Nd4j.eye(5);
    assertTrue(Arrays.equals(new int[]{5, 5}, identity.shape()));
    Nd4j.factory().setOrder('f');
    INDArray fortranIdentity = Nd4j.eye(5);
    assertTrue(Arrays.equals(new int[]{5, 5}, fortranIdentity.shape()));
}
/** In Fortran order, [[1,3],[2,4]] plus the column [1,2] gives data {2,4,4,6}. */
@Test
public void testColumnVectorOpsFortran() {
    Nd4j.factory().setOrder('f');
    INDArray matrix = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{2, 2});
    INDArray column = Nd4j.create(new float[]{1, 2}, new int[]{2, 1});
    matrix.addiColumnVector(column);
    INDArray expected = Nd4j.create(new float[]{2, 4, 4, 6}, new int[]{2, 2});
    assertEquals(expected, matrix);
}
/** get/put with a non-contiguous column index ({0,2}) on the 2x3 matrix [[1,2,3],[4,5,6]]. */
@Test
public void testGetNonContiguous() {
    INDArray twoByThree = Nd4j.linspace(1, 6, 6).reshape(2, 3);
    NDArrayIndex[] firstRowOuterColumns = {
        NDArrayIndex.interval(0, 1),
        new NDArrayIndex(new int[]{0, 2})
    };
    INDArray expectedOuter = Nd4j.create(new double[]{1, 3});
    assertEquals(twoByThree.get(firstRowOuterColumns), expectedOuter);

    INDArray expectedAfterPut = Nd4j.create(new double[]{4, 7});
    twoByThree.put(firstRowOuterColumns, Nd4j.create(new double[]{4, 7}));
    assertEquals(expectedAfterPut, twoByThree.get(firstRowOuterColumns));

    INDArray block = Nd4j.create(new double[]{5, 6, 7, 8}, new int[]{2, 2});
    NDArrayIndex[] bothRowsOuterColumns = {
        NDArrayIndex.interval(0, 2),
        new NDArrayIndex(0, 2)
    };
    twoByThree.put(bothRowsOuterColumns, block);
    INDArray readBack = twoByThree.get(bothRowsOuterColumns);
    assertEquals(block, readBack);
}
/** Column means, row means and overall mean of [[1,2],[3,4]]. */
@Test
public void testMeans() {
    INDArray square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    assertEquals(Nd4j.create(new float[]{2, 3}), square.mean(0));
    assertEquals(Nd4j.create(new float[]{1.5f, 3.5f}), square.mean(1));
    double overall = square.mean(Integer.MAX_VALUE).getDouble(0);
    assertEquals(2.5, overall, 1e-1);
}
/** Column sums, row sums and overall sum of [[1,2],[3,4]]. */
@Test
public void testSums() {
    INDArray square = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    assertEquals(Nd4j.create(new float[]{4, 6}), square.sum(0));
    assertEquals(Nd4j.create(new float[]{3, 7}), square.sum(1));
    double overall = square.sum(Integer.MAX_VALUE).getDouble(0);
    assertEquals(10, overall, 1e-1);
}
/** Cumulative sums on a vector, and along the last axis and axis 0 of a 4x3x2 array of 1..24. */
@Test
public void testCumSum() {
    INDArray vector = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{4});
    INDArray expectedVector = Nd4j.create(new float[]{1, 3, 6, 10}, new int[]{4});
    INDArray actualVector = vector.cumsum(0);
    assertEquals(expectedVector, actualVector);

    INDArray cube = Nd4j.linspace(1, 24, 24).reshape(new int[]{4, 3, 2});
    // running totals of 1..24 (the triangular numbers 1, 3, 6, ..., 300)
    double[] runningTotals = new double[24];
    double total = 0;
    for (int k = 0; k < runningTotals.length; k++) {
        total += k + 1;
        runningTotals[k] = total;
    }
    INDArray expectedLastAxis = Nd4j.create(runningTotals, new int[]{24});
    INDArray actualLastAxis = cube.cumsum(cube.shape().length - 1);
    assertEquals(expectedLastAxis, actualLastAxis);

    INDArray expectedAxis0 = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 18, 21, 24, 27, 30, 33, 36, 40, 44, 48, 52, 56, 60}, cube.shape());
    INDArray actualAxis0 = cube.cumsum(0);
    assertEquals(expectedAxis0, actualAxis0);
}
/** In-place reverse subtraction of 1 from a vector of ones leaves zeros. */
@Test
public void testRSubi() {
    INDArray ones = Nd4j.ones(2);
    INDArray expected = Nd4j.zeros(2);
    INDArray nRsubi = ones.rsubi(1);
    assertEquals(expected, nRsubi);
}
/** Concatenating 2x2x2 and 3x2x2 arrays along dimension 0 yields shape 5x2x2. */
@Test
public void testConcat() {
    INDArray first = Nd4j.linspace(1, 8, 8).reshape(2, 2, 2);
    INDArray second = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2);
    INDArray joined = Nd4j.concat(0, first, second);
    assertTrue(Arrays.equals(new int[]{5, 2, 2}, joined.shape()));
}
/**
 * Horizontally stacking two length-5 vectors of ones should keep the row count and double the
 * column count.
 */
@Test
public void testConcatHorizontally() {
    INDArray right = Nd4j.ones(5);
    INDArray left = Nd4j.ones(5);
    INDArray stacked = Nd4j.hstack(left, right);

    int expectedRows = right.rows();
    assertEquals(expectedRows, stacked.rows());

    int expectedColumns = right.columns() * 2;
    assertEquals(expectedColumns, stacked.columns());
}
/**
 * Vertically stacking two length-5 vectors of ones should double the row count and keep the
 * column count.
 */
@Test
public void testConcatVertically() {
    INDArray bottom = Nd4j.ones(5);
    INDArray top = Nd4j.ones(5);
    INDArray stacked = Nd4j.vstack(top, bottom);

    int expectedRows = bottom.rows() * 2;
    assertEquals(expectedRows, stacked.rows());

    int expectedColumns = bottom.columns();
    assertEquals(expectedColumns, stacked.columns());
}
/**
 * Applying {@code rdivi(2)} to a length-2 vector of 4s is expected to produce 0.5 in every
 * element (the scalar divided by each element).
 */
@Test
public void testRDivi() {
    INDArray fours = Nd4j.valueArrayOf(new int[]{2}, 4);
    INDArray expectedHalves = Nd4j.valueArrayOf(new int[]{2}, 0.5);
    // Named for the operation actually under test (previously copy-pasted as "nRsubi").
    INDArray quotient = fours.rdivi(2);
    assertEquals(expectedHalves, quotient);
}
/**
 * Checks {@code vectorAlongDimension(index, dimension)} on a 4x3x2 linspace array and a 2x2
 * linspace matrix.
 * <ul>
 *   <li>Dimension 2 of the 4x3x2 array is expected to yield consecutive pairs
 *       ({1, 2}, {3, 4}).</li>
 *   <li>Dimension 1 is expected to yield triples of values two apart
 *       ({1, 3, 5}, {7, 9, 11}, {13, 15, 17}).</li>
 * </ul>
 */
@Test
public void testVectorAlongDimension() {
    INDArray arr = Nd4j.linspace(1, 24, 24).reshape(new int[]{4, 3, 2});

    // Dimension 2 (last axis).
    INDArray firstPair = Nd4j.create(new float[]{1, 2}, new int[]{2});
    assertEquals(Nd4j.create(new float[]{3, 4}, new int[]{2}), arr.vectorAlongDimension(1, 2));
    assertEquals(firstPair, arr.vectorAlongDimension(0, 2));

    // Dimension 1. Expected value always comes first so JUnit labels failures correctly
    // (the first of these previously had the arguments reversed).
    assertEquals(Nd4j.create(new float[]{1, 3, 5}), arr.vectorAlongDimension(0, 1));
    assertEquals(Nd4j.create(new float[]{7, 9, 11}), arr.vectorAlongDimension(1, 1));
    assertEquals(Nd4j.create(new float[]{13, 15, 17}), arr.vectorAlongDimension(2, 1));

    // 2x2 matrix, dimension 0: expected {1, 3} at index 0 and {2, 4} at index 1.
    INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(new int[]{2, 2});
    assertEquals(Nd4j.create(new float[]{1, 3}), matrix.vectorAlongDimension(0, 0));
    assertEquals(Nd4j.create(new float[]{2, 4}), matrix.vectorAlongDimension(1, 0));
    // The former trailing re-check of vectorAlongDimension(0, 2) duplicated the assertion
    // on firstPair above and has been dropped.
}
/**
 * On a 2x2x2 linspace array, the vectors along dimension 2 at indices 0 and 1 are expected to
 * be {1, 2} and {3, 4}.
 */
@Test
public void testSquareMatrix() {
    INDArray cube = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[]{2, 2, 2});

    INDArray firstVector = cube.vectorAlongDimension(0, 2);
    assertEquals(Nd4j.create(new float[]{1, 2}, new int[]{2}), firstVector);

    INDArray secondVector = cube.vectorAlongDimension(1, 2);
    assertEquals(Nd4j.create(new float[]{3, 4}), secondVector);
}
/**
 * A 4x3x2 array is expected to contain 12 vectors along dimension 2 (4 * 3).
 */
@Test
public void testNumVectorsAlongDimension() {
    int lastDimension = 2;
    INDArray cube = Nd4j.linspace(1, 24, 24).reshape(new int[]{4, 3, 2});
    assertEquals(12, cube.vectorsAlongDimension(lastDimension));
}
/**
 * Every element of the vector {1, 2, 3, 4}, read through {@code getScalar(i)}, should equal a
 * freshly created float scalar holding i + 1.
 */
@Test
public void testGetScalar() {
    INDArray vector = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{4});
    assertTrue(vector.isVector());

    int index = 0;
    while (index < vector.length()) {
        float expectedValue = index + 1f;
        assertEquals(Nd4j.scalar(expectedValue), vector.getScalar(index));
        index++;
    }
}
/**
 * Repeats the {@code getScalar} checks with the factory order set to 'f'. It then checks that
 * column 0 of {{1, 2}, {3, 4}} is {1, 3}, both as an array and element by element.
 * The order is restored to 'c' by {@link #after()}.
 */
@Test
public void testGetScalarFortran() {
    Nd4j.factory().setOrder('f');
    // A local, rather than overwriting the class fixture field n as before.
    INDArray vector = Nd4j.create(new float[]{1, 2, 3, 4}, new int[]{4});
    for (int i = 0; i < vector.length(); i++) {
        INDArray scalar = Nd4j.scalar((float) i + 1);
        assertEquals(scalar, vector.getScalar(i));
    }

    INDArray twoByTwo = Nd4j.create(new float[][]{{1, 2}, {3, 4}});
    INDArray column = twoByTwo.getColumn(0);
    assertEquals(Nd4j.create(new float[]{1, 3}), column);
    assertEquals(1, column.getFloat(0), 1e-1);
    assertEquals(3, column.getFloat(1), 1e-1);
    assertEquals(Nd4j.scalar(1), column.getScalar(0));
    assertEquals(Nd4j.scalar(3), column.getScalar(1));
}
/**
 * The fixture {@code n} (linspace(1, 8) shaped 2x2x2) should have 8 elements and shape
 * {2, 2, 2}, and should hold 8.0 at index {1, 1, 1}.
 */
@Test
public void testGetMulti() {
    assertEquals(8, n.length());
    // assertArrayEquals shows the mismatching element instead of "expected true".
    assertArrayEquals(ArrayUtil.of(2, 2, 2), n.shape());
    double val = n.getDouble(new int[]{1, 1, 1});
    assertEquals(8.0, val, 1e-6);
}
/**
 * Reshaping linspace(1, 4) to 2x2 under 'c' versus 'f' factory order should place different
 * values at position (0, 1).
 */
@Test
public void testGetRowOrdering() {
    INDArray cOrdered = Nd4j.linspace(1, 4, 4).reshape(2, 2);

    Nd4j.factory().setOrder('f');
    INDArray fOrdered = Nd4j.linspace(1, 4, 4).reshape(2, 2);

    double cValue = cOrdered.getFloat(0, 1);
    double fValue = fOrdered.getFloat(0, 1);
    assertNotEquals(cValue, fValue, 1e-1);

    Nd4j.factory().setOrder('c');
}
/**
 * Checks {@code broadcast} on the vector linspace(1, 4):
 * <ul>
 *   <li>broadcasting to 5x4 should repeat it in every row;</li>
 *   <li>re-broadcasting the first row to 5x4 should reproduce that matrix;</li>
 *   <li>broadcasting its transpose to 4x5 should repeat the transpose in every column.</li>
 * </ul>
 */
@Test
public void testBroadCast() {
    INDArray vector = Nd4j.linspace(1, 4, 4);
    INDArray broadCasted = vector.broadcast(new int[]{5, 4});
    for (int i = 0; i < broadCasted.rows(); i++) {
        assertEquals(vector, broadCasted.getRow(i));
    }
    INDArray broadCast2 = broadCasted.getRow(0).broadcast(5, 4);
    assertEquals(broadCasted, broadCast2);

    // Transpose once instead of recomputing it on every loop iteration.
    INDArray column = vector.transpose();
    INDArray columnBroadcast = column.broadcast(4, 5);
    for (int i = 0; i < columnBroadcast.columns(); i++) {
        // Expected first, actual second (the arguments were previously reversed).
        assertEquals(column, columnBroadcast.getColumn(i));
    }
}
/**
 * Writes {5, 6} into row 1 of a 2x2 linspace(1, 4) matrix, once under 'c' and once under 'f'
 * factory order. The two matrices are expected to differ overall while their row 1 matches.
 */
@Test
public void testPutRowGetRowOrdering() {
    INDArray cMatrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray cReplacement = Nd4j.create(new float[]{5, 6});
    cMatrix.putRow(1, cReplacement);

    Nd4j.factory().setOrder('f');
    INDArray fMatrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray fReplacement = Nd4j.create(new float[]{5, 6});
    fMatrix.putRow(1, fReplacement);

    assertNotEquals(cMatrix, fMatrix);

    INDArray cSecondRow = cMatrix.getRow(1);
    INDArray fSecondRow = fMatrix.getRow(1);
    assertEquals(cSecondRow, fSecondRow);

    Nd4j.factory().setOrder('c');
}
/**
 * After {5, 6} is written into row 1, two matrices are expected to be equal:
 * linspace(1, 4) reshaped to 2x2 under 'c' order, and {{1, 2}, {3, 4}} created under 'f' order.
 */
@Test
public void testPutRowFortran() {
    INDArray fromLinspace = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    fromLinspace.putRow(1, Nd4j.create(new float[]{5, 6}));

    Nd4j.factory().setOrder('f');
    INDArray fromNestedArray = Nd4j.create(new float[][]{{1, 2}, {3, 4}});
    fromNestedArray.putRow(1, Nd4j.create(new float[]{5, 6}));

    assertEquals(fromLinspace, fromNestedArray);

    Nd4j.factory().setOrder('c');
}
/**
 * Checks the non-in-place scalar operations add, sub, mul and div. Each result must have the
 * expected value. The add, sub and mul results must also not compare equal to their left
 * operand.
 */
@Test
public void testElementWiseOps() {
    INDArray n1 = Nd4j.scalar(1);
    INDArray n2 = Nd4j.scalar(2);
    // Computed once and reused (the sum was previously recomputed for the inequality check).
    INDArray sum = n1.add(n2);
    assertEquals(Nd4j.scalar(3), sum);
    assertNotEquals(n1, sum);

    INDArray n3 = Nd4j.scalar(3);
    INDArray n4 = Nd4j.scalar(4);
    INDArray subbed = n4.sub(n3);
    INDArray mulled = n4.mul(n3);
    INDArray div = n4.div(n3);
    assertNotEquals(n4, subbed);
    assertNotEquals(n4, mulled);
    assertEquals(Nd4j.scalar(1), subbed);
    assertEquals(Nd4j.scalar(12), mulled);
    // 4.0 / 3.0 is the same double the former literal 1.333333333333333333333 rounded to.
    assertEquals(Nd4j.scalar(4.0 / 3.0), div);
}
/**
 * Smoke test for {@code slice(index, dimension)} on a 3-D and a 4-D array. It passes as long
 * as the calls do not throw. No values are asserted; the 4-D result is only logged.
 */
@Test
public void testSlicing() {
    // Exercised only for the absence of exceptions; the result is not inspected.
    n.slice(1, 1);
    INDArray n2 = Nd4j.create(Nd4j.linspace(1, 16, 16).data(), new int[]{2, 2, 2, 2});
    // Parameterized SLF4J call; produces the same text as the former string concatenation.
    log.info("N2 shape {}", n2.slice(1, 1).slice(1));
}
/**
 * For a 4x3x2 linspace array, {@code endsForSlices()} is expected to return {5, 11, 17, 23}:
 * one entry per slice along dimension 0.
 */
@Test
public void testEndsForSlices() {
    INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
    int[] endsForSlices = arr.endsForSlices();
    // assertArrayEquals reports which element differs rather than "expected true".
    assertArrayEquals(new int[]{5, 11, 17, 23}, endsForSlices);
}
/**
 * Checks {@code ravel} in three cases:
 * <ul>
 *   <li>A 2x2 matrix flattens to a 1xN vector that keeps the values 1..4 in order.</li>
 *   <li>A 3x3x3 array of ones flattens to a vector.</li>
 *   <li>A 4x3x2 reshape of linspace(1, 24) ravels back to linspace(1, 24).</li>
 * </ul>
 */
@Test
public void testFlatten() {
    INDArray arr = Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[]{2, 2});
    INDArray flattened = arr.ravel();
    assertEquals(arr.length(), flattened.length());
    assertTrue(Shape.shapeEquals(new int[]{1, arr.length()}, flattened.shape()));
    for (int i = 0; i < arr.length(); i++) {
        assertEquals(i + 1, flattened.getFloat(i), 1e-1);
    }
    assertTrue(flattened.isVector());

    // Named "cube" so it no longer shadows the fixture field n.
    INDArray cube = Nd4j.create(Nd4j.ones(27).data(), new int[]{3, 3, 3});
    INDArray cubeFlattened = cube.ravel();
    assertTrue(cubeFlattened.isVector());

    INDArray n1 = Nd4j.linspace(1, 24, 24);
    assertEquals(n1, Nd4j.linspace(1, 24, 24).reshape(new int[]{4, 3, 2}).ravel());
}
/**
 * Iterates over the last dimension of a 4x3x2 linspace array with a {@link SliceOp}. The first
 * six values handed to the callback are checked against fixed expectations.
 * Both {@code operate} overloads advance one shared counter, and calls past the sixth are not
 * checked.
 * NOTE(review): the two overloads expect different things. The DimensionSlice path expects
 * length-4 vectors whose elements are 6 apart; the INDArray path expects length-2 consecutive
 * pairs. Which overload {@code iterateOverDimension(..., false)} invokes is not visible here;
 * confirm both sets of expectations are meant to be reachable.
 */
@Test
public void testVectorDimensionMulti() {
    INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[]{4, 3, 2});
    final AtomicInteger count = new AtomicInteger(0);
    arr.iterateOverDimension(arr.shape().length - 1, new SliceOp() {
        /** Expected DimensionSlice results, indexed by the shared call counter. */
        private final float[][] dimensionSliceAnswers = {
                {1, 7, 13, 19},
                {2, 8, 14, 20},
                {3, 9, 15, 21},
                {4, 10, 16, 22},
                {5, 11, 17, 23},
                {6, 12, 18, 24}
        };

        /** Expected ndarray slices, indexed by the shared call counter. */
        private final float[][] ndArraySliceAnswers = {
                {1, 2},
                {3, 4},
                {5, 6},
                {7, 8},
                {9, 10},
                {11, 12}
        };

        @Override
        public void operate(DimensionSlice nd) {
            INDArray actual = (INDArray) nd.getResult();
            int call = count.get();
            if (call >= 0 && call < dimensionSliceAnswers.length) {
                assertEquals(Nd4j.create(dimensionSliceAnswers[call], new int[]{4}), actual);
            }
            count.incrementAndGet();
        }

        /**
         * Operates on an ndarray slice
         *
         * @param nd the result to operate on
         */
        @Override
        public void operate(INDArray nd) {
            int call = count.get();
            if (call >= 0 && call < ndArraySliceAnswers.length) {
                assertEquals(Nd4j.create(ndArraySliceAnswers[call], new int[]{2}), nd);
            }
            count.incrementAndGet();
        }
    }, false);
}
}