Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
package org.nd4j.linalg.api.ndarray;
import com.google.common.base.Function;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DoubleBuffer;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.dimensionfunctions.DimensionFunctions;
import org.nd4j.linalg.factory.NDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.reduceops.Ops;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import static org.nd4j.linalg.util.ArrayUtil.*;
/**
* NDArray: (think numpy)
*
* A few things of note.
*
* An NDArray can have any number of dimensions.
*
* An NDArray is accessed via strides.
*
* Strides are how to index over
* a contiguous block of data.
*
* This block of data has 2 orders(as of right now):
* fortran and c
*
*
*
* @author Adam Gibson
*/
public abstract class BaseNDArray implements INDArray {
protected int[] shape;
protected int[] stride;
protected int offset = 0;
protected char ordering;
protected DataBuffer data;
protected int rows,columns;
protected int length;
protected INDArray linearView;
public BaseNDArray() {}
public BaseNDArray(DataBuffer buffer) {
this.data = buffer;
initShape(new int[]{1,buffer.length()});
}
public BaseNDArray(DataBuffer buffer,int[] shape,int[] stride,int offset,char ordering) {
this.data = buffer;
if(ArrayUtil.prod(shape) > buffer.length())
throw new IllegalArgumentException("Shape must be <= buffer length");
this.stride = stride;
this.offset = offset;
this.ordering = ordering;
initShape(shape);
}
public BaseNDArray(double[][] data) {
this(data.length, data[0].length);
for (int r = 0; r < rows; r++) {
assert (data[r].length == columns);
}
this.data = Nd4j.createBuffer(length);
for (int r = 0; r < rows; r++) {
for (int c = 0; c < columns; c++) {
putScalar(new int[]{r, c}, data[r][c]);
}
}
}
/**
* Create this ndarray with the given data and shape and 0 offset
* @param data the data to use
* @param shape the shape of the ndarray
*/
public BaseNDArray(float[] data, int[] shape, char ordering) {
this(data,shape,0,ordering);
}
/**
*
* @param data the data to use
* @param shape the shape of the ndarray
* @param offset the desired offset
* @param ordering the ordering of the ndarray
*/
public BaseNDArray(float[] data, int[] shape, int offset, char ordering) {
this(data,shape,ordering == NDArrayFactory.C ? calcStrides(shape) : calcStridesFortran(shape),offset);
}
/**
* Construct an ndarray of the specified shape
* with an empty data array
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param offset the desired offset
* @param ordering the ordering of the ndarray
*/
public BaseNDArray(int[] shape, int[] stride, int offset, char ordering) {
this(Nd4j.createBuffer(ArrayUtil.prod(shape)),shape,stride,offset,ordering);
}
/**
* Create the ndarray with
* the specified shape and stride and an offset of 0
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param ordering the ordering of the ndarray
*/
public BaseNDArray(int[] shape, int[] stride, char ordering){
this(shape,stride,0,ordering);
}
public BaseNDArray(int[] shape, int offset, char ordering) {
this(shape, Nd4j.getStrides(shape, ordering),offset,ordering);
}
public BaseNDArray(int[] shape) {
this(shape,0, Nd4j.order());
}
/**
* Creates a new n times mDoubleMatrix.
*
* @param newRows the number of rows (n) of the new matrix.
* @param newColumns the number of columns (m) of the new matrix.
*/
public BaseNDArray(int newRows, int newColumns, char ordering) {
this.ordering = ordering;
initShape(new int[]{newRows,newColumns});
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape, char ordering) {
this(slices,shape,Nd4j.getStrides(shape),ordering);
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) {
DataBuffer ret = slices.get(0).data().dataType().equals(DataBuffer.FLOAT) ?
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) :
Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]);
this.stride = stride;
this.ordering = ordering;
this.data = ret;
initShape(shape);
for(int i = 0; i < slices(); i++) {
putSlice(i,slices.get(i));
}
}
public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) {
this(data,shape,stride,0,ordering);
}
public BaseNDArray(float[] data, int[] shape, int[] stride, int offset, char ordering) {
this.offset = offset;
this.stride = stride;
this.ordering = ordering;
initShape(shape);
if(data != null && data.length > 0) {
this.data = Nd4j.createBuffer(data);
if(offset >= data.length)
throw new IllegalArgumentException("invalid offset: must be < data.length");
}
}
public BaseNDArray(DataBuffer data, int[] shape, int[] stride, int offset) {
this.data = data;
this.stride = stride;
this.offset = offset;
this.ordering = Nd4j.order();
initShape(shape);
}
public BaseNDArray(DataBuffer data, int[] shape) {
this(data,shape,Nd4j.getStrides(shape),0,Nd4j.order());
}
public BaseNDArray(DataBuffer buffer, int[] shape, int offset) {
this(buffer,shape,Nd4j.getStrides(shape),offset);
}
public BaseNDArray(double[] data, int[] shape, char ordering) {
this(new DoubleBuffer(data),shape,ordering);
}
public BaseNDArray(double[] data, int[] shape, int[] stride, int offset, char ordering) {
this(new DoubleBuffer(data),shape,stride,offset,ordering);
}
public BaseNDArray(float[] data, char order) {
this(new FloatBuffer(data),order);
}
public BaseNDArray(FloatBuffer floatBuffer, char order) {
this(floatBuffer,new int[]{floatBuffer.length()},Nd4j.getStrides(new int[]{floatBuffer.length()}),0,order);
}
@Override
public INDArray linearViewColumnOrder() {
return Nd4j.create(data, new int[]{length,1},offset());
}
/**
* Returns a linear view reference of shape
* 1,length(ndarray)
*
* @return the linear view of this ndarray
*/
@Override
public INDArray linearView() {
if(isVector())
return this;
if(linearView == null)
linearView = Nd4j.create(data, new int[]{length}, offset());
return linearView;
}
/**
* Create this ndarray with the given data and shape and 0 offset
* @param data the data to use
* @param shape the shape of the ndarray
*/
public BaseNDArray(float[] data, int[] shape) {
this(data,shape,0);
}
public BaseNDArray(float[] data, int[] shape, int offset) {
this(data,shape,offset, Nd4j.order());
}
/**
* Construct an ndarray of the specified shape
* with an empty data array
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param offset the desired offset
*/
public BaseNDArray(int[] shape, int[] stride, int offset) {
this(new float[ArrayUtil.prod(shape)],shape,stride,offset, Nd4j.order());
}
/**
* Create the ndarray with
* the specified shape and stride and an offset of 0
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
*/
public BaseNDArray(int[] shape, int[] stride){
this(shape,stride,0);
}
public BaseNDArray(int[] shape, int offset) {
this(shape,calcStrides(shape),offset);
}
public BaseNDArray(int[] shape, char ordering) {
this(shape,0,ordering);
}
/**
* Creates a new n times mDoubleMatrix.
*
* @param newRows the number of rows (n) of the new matrix.
* @param newColumns the number of columns (m) of the new matrix.
*/
public BaseNDArray(int newRows, int newColumns) {
this(newRows,newColumns, Nd4j.order());
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape) {
this(slices,shape, Nd4j.order());
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape, int[] stride) {
this(slices,shape,stride, Nd4j.order());
}
public BaseNDArray(float[] data, int[] shape, int[] stride) {
this(data,shape,stride, Nd4j.order());
}
public BaseNDArray(float[] data, int[] shape, int[] stride, int offset) {
this(data,shape,stride,offset, Nd4j.order());
}
public BaseNDArray(float[] data) {
this.data = Nd4j.createBuffer(data);
}
public BaseNDArray(float[][] data) {
this(data.length, data[0].length);
for (int r = 0; r < rows; r++) {
assert (data[r].length == columns);
}
this.data = Nd4j.createBuffer(length);
for (int r = 0; r < rows; r++) {
for (int c = 0; c < columns; c++) {
putScalar(new int[]{r, c}, data[r][c]);
}
}
}
@Override
public void setData(float[] data) {
this.data = Nd4j.createBuffer(data);
}
@Override
public int majorStride() {
//return ordering == NDArrayFactory.C ? stride[0] : stride[stride.length - 1];r
return stride[0];
}
@Override
public int secondaryStride() {
if(stride.length >= 2) {
if(ordering == NDArrayFactory.C) {
if(isColumnVector())
return majorStride();
return stride[1];
}
else
return majorStride();
}
return majorStride();
}
/**
* Returns the number of possible vectors for a given dimension
*
* @param dimension the dimension to calculate the number of vectors for
* @return the number of possible vectors along a dimension
*/
@Override
public int vectorsAlongDimension(int dimension) {
if(dimension >= shape.length)
return length / size(shape.length - 1);
return length / size(dimension);
}
/**
* Get the vector along a particular dimension
*
* @param index the index of the vector to getScalar
* @param dimension the dimension to getScalar the vector from
* @return the vector along a particular dimension
*/
@Override
public INDArray vectorAlongDimension(int index, int dimension) {
assert dimension <= shape.length : "Invalid dimension " + dimension;
if(dimension > shape().length - 1)
dimension = shape.length - 1;
if(ordering == NDArrayFactory.C) {
if(dimension == shape.length - 1 && dimension != 0)
return Nd4j.create(data,
new int[]{shape[dimension]}
, new int[]{stride[shape.length - 1]},
offset + index * stride[dimension - 1]);
else if(dimension == 0)
return Nd4j.create(data,
new int[]{shape[dimension], 1}
, new int[]{stride[dimension], 1},
offset + index);
if(size(dimension) == 1) {
return Nd4j.create(data,
new int[]{shape[dimension], 1}
, new int[]{stride[dimension], 1},
offset + index);
}
else
return Nd4j.create(data,
new int[]{shape[dimension], 1}
, new int[]{stride[dimension], 1},
offset + index * stride[0]);
}
else if(ordering == NDArrayFactory.FORTRAN) {
if(dimension == shape.length - 1 && dimension != 0) {
if(size(dimension) == 1)
return Nd4j.create(data,
new int[]{1, shape[dimension]}
, ArrayUtil.removeIndex(stride, 0),
offset + index);
return Nd4j.create(data,
new int[]{1, shape[dimension]}
, ArrayUtil.removeIndex(stride, 0),
offset + index * stride[dimension - 1]);
}
else {
if(size(dimension) == 1)
return Nd4j.create(data,
new int[]{shape[dimension], 1}
, new int[]{stride[dimension], 1},
offset + index);
return Nd4j.create(data,
new int[]{shape[dimension], 1}
, new int[]{stride[dimension], 1},
offset + index * stride[stride.length - 1]);
}
}
throw new IllegalStateException("Illegal ordering..none declared");
}
/**
* Cumulative sum along a dimension
*
* @param dimension the dimension to perform cumulative sum along
* @return the cumulative sum along the specified dimension
*/
@Override
public INDArray cumsumi(int dimension) {
if(isVector()) {
double s = 0.0;
for (int i = 0; i < length; i++) {
if(data.dataType().equals(DataBuffer.FLOAT))
s += getDouble(i);
else
s+= getDouble(i);
putScalar(i, s);
}
}
else if(dimension == Integer.MAX_VALUE || dimension == shape.length - 1) {
INDArray flattened = ravel().dup();
double prevVal = flattened.getDouble(0);
for(int i = 1; i < flattened.length(); i++) {
double d = prevVal + flattened.getDouble(i);
flattened.putScalar(i,d);
prevVal = d;
}
return flattened;
}
else {
for(int i = 0; i < vectorsAlongDimension(dimension); i++) {
INDArray vec = vectorAlongDimension(i,dimension);
vec.cumsumi(0);
}
}
return this;
}
/**
* Cumulative sum along a dimension (in place)
*
* @param dimension the dimension to perform cumulative sum along
* @return the cumulative sum along the specified dimension
*/
@Override
public INDArray cumsum(int dimension) {
return dup().cumsumi(dimension);
}
/**
* Assign all of the elements in the given
* ndarray to this ndarray
*
* @param arr the elements to assign
* @return this
*/
@Override
public INDArray assign(INDArray arr) {
LinAlgExceptions.assertSameShape(this,arr);
INDArray other = arr.linearView();
INDArray thisArr = linearView();
for(int i = 0; i < other.length(); i++)
thisArr.putScalar(i, other.getDouble(i));
return this;
}
@Override
public INDArray putScalar(int i, double value) {
int idx = linearIndex(i);
data.put(idx,value);
return this;
}
@Override
public INDArray putScalar(int i, float value) {
int idx = linearIndex(i);
data.put(idx,value);
return this;
}
@Override
public INDArray putScalar(int i, int value) {
int idx = linearIndex(i);
data.put(idx,value);
return this;
}
@Override
public INDArray putScalar(int[] indexes, double value) {
int ix = offset;
for (int i = 0; i < shape.length; i++) {
ix += indexes[i] * stride[i];
}
data.put(ix,value);
return this;
}
@Override
public INDArray putScalar(int[] indexes, float value) {
int ix = offset;
for (int i = 0; i < shape.length; i++) {
ix += indexes[i] * stride[i];
}
data.put(ix,value);
return this;
}
@Override
public INDArray putScalar(int[] indexes, int value) {
int ix = offset;
for (int i = 0; i < shape.length; i++) {
ix += indexes[i] * stride[i];
}
data.put(ix,value);
return this;
}
/**
* Returns an ndarray with 1 if the element is epsilon equals
*
* @param other the number to compare
* @return a copied ndarray with the given
* binary conditions
*/
@Override
public INDArray eps(Number other) {
return dup().epsi(other);
}
/**
* Returns an ndarray with 1 if the element is epsilon equals
*
* @param other the number to compare
* @return a copied ndarray with the given
* binary conditions
*/
@Override
public INDArray epsi(Number other) {
INDArray linearView = linearView();
double otherVal = other.doubleValue();
for(int i = 0; i < linearView.length(); i++) {
double val = linearView.getDouble(i);
double diff = Math.abs(val - otherVal);
if(diff <= Nd4j.EPS_THRESHOLD)
linearView.putScalar(i,1);
else
linearView.putScalar(i,0);
}
return this;
}
/**
* epsilon equals than comparison:
* If the given number is less than the
* comparison number the item is 0 otherwise 1
*
* @param other the number to compare
* @return
*/
@Override
public INDArray eps(INDArray other) {
return dup().epsi(other);
}
/**
* In place epsilon equals than comparison:
* If the given number is less than the
* comparison number the item is 0 otherwise 1
*
* @param other the number to compare
* @return
*/
@Override
public INDArray epsi(INDArray other) {
INDArray linearView = linearView();
INDArray otherLinearView = other.linearView();
for(int i = 0; i < linearView.length(); i++) {
double val = linearView.getDouble(i);
double otherVal = otherLinearView.getDouble(i);
double diff = Math.abs(val - otherVal);
if(diff <= Nd4j.EPS_THRESHOLD)
linearView.putScalar(i,1);
else
linearView.putScalar(i,0);
}
return this;
}
@Override
public INDArray lt(Number other) {
return dup().lti(other);
}
@Override
public INDArray lti(Number other) {
return lti(Nd4j.scalar(other));
}
@Override
public INDArray eq(Number other) {
return dup().eqi(other);
}
@Override
public INDArray eqi(Number other) {
return eqi(Nd4j.scalar(other));
}
@Override
public INDArray gt(Number other) {
return dup().gti(other);
}
@Override
public INDArray gti(Number other) {
return gti(Nd4j.scalar(other));
}
@Override
public INDArray lt(INDArray other) {
return dup().lti(other);
}
@Override
public INDArray lti(INDArray other) {
INDArray linear = linearView();
INDArray otherLinear = other.linearView();
for(int i = 0; i < linear.length(); i++) {
linear.putScalar(i,linear.getDouble(i) < otherLinear.getDouble(i) ? 1 : 0);
}
return this;
}
@Override
public INDArray neq(INDArray other) {
return dup().neqi(other);
}
@Override
public INDArray neqi(INDArray other) {
INDArray linear = linearView();
INDArray otherLinear = other.linearView();
for(int i = 0; i < linear.length(); i++) {
linear.putScalar(i,linear.getDouble(i) != otherLinear.getDouble(i) ? 1 : 0);
}
return this;
}
@Override
public INDArray eq(INDArray other) {
return dup().eqi(other);
}
@Override
public INDArray eqi(INDArray other) {
INDArray linear = linearView();
INDArray otherLinear = other.linearView();
for(int i = 0; i < linear.length(); i++) {
linear.putScalar(i,linear.getDouble(i) == otherLinear.getDouble(i) ? 1 : 0);
}
return this;
}
@Override
public INDArray gt(INDArray other) {
return dup().gti(other);
}
@Override
public INDArray gti(INDArray other) {
INDArray linear = linearView();
INDArray otherLinear = other.linearView();
for(int i = 0; i < linear.length(); i++) {
linear.putScalar(i,linear.getDouble(i) > otherLinear.getDouble(i) ? 1 : 0);
}
return this;
}
/**
* Negate each element.
*/
@Override
public INDArray neg() {
return dup().negi();
}
/**
* Negate each element (in-place).
*/
@Override
public INDArray negi() {
return Transforms.neg(this);
}
@Override
public INDArray rdiv(Number n, INDArray result) {
return dup().rdivi(n,result);
}
@Override
public INDArray rdivi(Number n, INDArray result) {
INDArray linear = linearView();
INDArray resultLinear = result.linearView();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, n.floatValue() / linear.getDouble(i));
}
return result;
}
@Override
public INDArray rsub(Number n, INDArray result) {
return dup().rsubi(n,result);
}
@Override
public INDArray rsubi(Number n, INDArray result) {
INDArray linear = linearView();
INDArray resultLinear = result.linearView();
double val = n.doubleValue();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, val - linear.getDouble(i));
}
return result;
}
@Override
public INDArray div(Number n, INDArray result) {
return dup().divi(n,result);
}
@Override
public INDArray divi(Number n, INDArray result) {
INDArray linear = linearView();
INDArray resultLinear = result.linearView();
double val = n.doubleValue();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, linear.getDouble(i) / val);
}
return result;
}
@Override
public INDArray mul(Number n, INDArray result) {
return dup().muli(n,result);
}
@Override
public INDArray muli(Number n, INDArray result) {
INDArray linear = linearView();
INDArray resultLinear = result.linearView();
double val = n.doubleValue();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, linear.getDouble(i) * val);
}
return result;
}
@Override
public INDArray sub(Number n, INDArray result) {
return dup().subi(n,result);
}
@Override
public INDArray subi(Number n, INDArray result) {
INDArray linear = linearView();
INDArray resultLinear = result.linearView();
double val = n.doubleValue();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, linear.getDouble(i) - val);
}
return result;
}
@Override
public INDArray add(Number n, INDArray result) {
return dup().addi(n,result);
}
@Override
public INDArray addi(Number n, INDArray result) {
INDArray linear = linearView();
INDArray resultLinear = result.linearView();
double val = n.doubleValue();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, linear.getDouble(i) + val);
}
return result;
}
/**
* Returns the element at the specified row/column
* This will throw an exception if the
*
* @param row the row of the element to return
* @param column the row of the element to return
* @return a scalar indarray of the element at this index
*/
@Override
public INDArray getScalar(int row, int column) {
return getScalar(new int[]{row,column});
}
@Override
public INDArray dup() {
INDArray ret = Nd4j.create(shape());
INDArray linear = linearView();
INDArray retLinear = ret.linearView();
for(int i = 0; i < ret.length(); i++) {
retLinear.putScalar(i,linear.getDouble(i));
}
return ret;
}
/**
* Returns the elements at the the specified indices
*
* @param indices the indices to getScalar
* @return the array with the specified elements
*/
@Override
public int getInt(int... indices) {
int ix = offset;
for (int i = 0; i< indices.length; i++)
ix += indices[i] * stride[i];
return data.getInt(ix);
}
/**
* Returns the elements at the the specified indices
*
* @param indices the indices to getScalar
* @return the array with the specified elements
*/
@Override
public double getDouble(int... indices) {
int ix = offset;
for (int i = 0; i< indices.length; i++)
ix += indices[i] * stride[i];
return data.getDouble(ix);
}
/**
* Returns the elements at the the specified indices
*
* @param indices the indices to getScalar
* @return the array with the specified elements
*/
@Override
public float getFloat(int... indices) {
int ix = offset;
for (int i = 0; i< indices.length; i++)
ix += indices[i] * stride[i];
return data.getFloat(ix);
}
/**
* Test whether a matrix is scalar.
*/
@Override
public boolean isScalar() {
if(shape.length == 0)
return true;
else if(shape.length == 1 && shape[0] == 1)
return true;
else if(shape.length >= 2) {
for(int i = 0; i < shape.length; i++)
if(shape[i] != 1)
return false;
}
return length == 1;
}
/**
* Inserts the element at the specified index
*
* @param indices the indices to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int[] indices, INDArray element) {
if(!element.isScalar())
throw new IllegalArgumentException("Unable to insert anything but a scalar");
int ix = offset;
for (int i = 0; i< indices.length; i++)
ix += indices[i] * stride[i];
data.put(ix,element.getDouble(0));
return this;
}
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, int j, INDArray element) {
return put(new int[]{i,j},element);
}
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, int j, Number element) {
return put(i,j, Nd4j.scalar(element));
}
/**
* Assigns the given matrix (put) to the specified slice
* @param slice the slice to assign
* @param put the slice to applyTransformToDestination
* @return this for chainability
*/
@Override
public INDArray putSlice(int slice,INDArray put) {
if(isScalar()) {
assert put.isScalar() : "Invalid dimension. Can only insert a scalar in to another scalar";
put(0,put.getScalar(0));
return this;
}
else if(isVector()) {
assert put.isScalar() || put.isVector() &&
put.length() == length() : "Invalid dimension on insertion. Can only insert scalars input vectors";
if(put.isScalar())
putScalar(slice,put.getDouble(0));
else
for(int i = 0; i < length(); i++)
putScalar(i,put.getDouble(i));
return this;
}
assertSlice(put,slice);
INDArray view = slice(slice);
if(put.isScalar())
putScalar(slice, put.getDouble(0));
else if(put.isVector())
for(int i = 0; i < put.length(); i++)
view.putScalar(i, put.getDouble(i));
else if(put.shape().length == 2)
for(int i = 0; i < put.rows(); i++)
for(int j = 0; j < put.columns(); j++)
view.put(i,j,put.getDouble(i, j));
else {
assert put.slices() == view.slices() : "Slices must be equivalent.";
for(int i = 0; i < put.slices(); i++)
view.slice(i).putSlice(i,view.slice(i));
}
return this;
}
protected void assertSlice(INDArray put,int slice) {
assert slice <= slices() : "Invalid slice specified " + slice;
int[] sliceShape = put.shape();
int[] requiredShape = ArrayUtil.removeIndex(shape(),0);
//no need to compare for scalar; primarily due to shapes either being [1] or length 0
if(put.isScalar())
return;
assert Shape.shapeEquals(sliceShape,requiredShape) : String.format("Invalid shape size of %s . Should have been %s ",Arrays.toString(sliceShape),Arrays.toString(requiredShape));
}
/**
* Returns true if this ndarray is 2d
* or 3d with a singleton element
* @return true if the element is a matrix, false otherwise
*/
public boolean isMatrix() {
return (shape().length == 2
&& (shape[0] != 1 && shape[1] != 1));
}
/**
*
* http://docs.scipy.org/doc/numpy/reference/generated/numpy.ufunc.reduce.html
* @param op the operation to do
* @param dimension the dimension to return from
* @return the results of the reduce (applying the operation along the specified
* dimension)
*/
@Override
public INDArray reduce(Ops.DimensionOp op,int dimension) {
if (isScalar())
return this;
if (isVector())
return Nd4j.scalar(reduceVector(op, this));
int[] shape = ArrayUtil.removeIndex(this.shape, dimension);
if (dimension == 0) {
double[] data2 = new double[ArrayUtil.prod(shape)];
int dataIter = 0;
//iterating along the dimension is relative to the number of slices
//in the return dimension
int numTimes = ArrayUtil.prod(shape);
for (int offset = this.offset; offset < numTimes; offset++) {
double reduce = op(dimension, offset, op);
data2[dataIter++] = reduce;
}
return Nd4j.create(data2, shape);
} else {
double[] data2 = new double[ArrayUtil.prod(shape)];
int dataIter = 0;
//want the milestone to slice[1] and beyond
int[] sliceIndices = endsForSlices();
int currOffset = 0;
//iterating along the dimension is relative to the number of slices
//in the return dimension
int numTimes = ArrayUtil.prod(shape);
for (int offset = this.offset; offset < numTimes; offset++) {
if (dataIter >= data2.length || currOffset >= sliceIndices.length)
break;
//do the operation,, and look for whether it exceeded the current slice
IterationResult pair = op(dimension, offset, op, sliceIndices[currOffset]);
//append the result
double reduce = pair.getResult();
data2[dataIter++] = reduce;
//go to next slice and iterate over that
if (pair.isNextSlice()) {
//will update to next step
offset = sliceIndices[currOffset];
numTimes += sliceIndices[currOffset];
currOffset++;
}
}
return Nd4j.create(data2, shape);
}
}
//get one result along one dimension based on the given offset
public DimensionSlice vectorForDimensionAndOffset(int dimension, int offset) {
if(isScalar() && dimension == 0 && offset == 0)
return new DimensionSlice(false,getScalar(offset),new int[]{offset});
//need whole vector
else if (isVector()) {
if(dimension == 0) {
int[] indices = new int[length];
for(int i = 0; i < indices.length; i++)
indices[i] = i;
return new DimensionSlice(false,dup(),indices);
}
else if(dimension == 1)
return new DimensionSlice(false,getScalar(offset),new int[]{offset});
else
throw new IllegalArgumentException("Illegal dimension for vector " + dimension);
}
else {
int count = 0;
List indices = new ArrayList<>();
INDArray ret = Nd4j.create(new int[]{shape[dimension]});
for(int j = offset; count < this.shape[dimension]; j+= this.stride[dimension]) {
double d = data.getDouble(j);
ret.putScalar(count++, d);
indices.add(j);
}
return new DimensionSlice(false,ret,ArrayUtil.toArray(indices));
}
}
//getFromOrigin one result along one dimension based on the given offset
private IterationResult op(int dimension, int offset, Ops.DimensionOp op,int currOffsetForSlice) {
double[] dim = new double[this.shape[dimension]];
int count = 0;
boolean newSlice = false;
for(int j = offset; count < dim.length; j+= this.stride[dimension]) {
double d = data.getDouble(j);
dim[count++] = d;
if(j >= currOffsetForSlice)
newSlice = true;
}
return new IterationResult(reduceVector(op, Nd4j.create(dim)),newSlice);
}
//getFromOrigin one result along one dimension based on the given offset
private double op(int dimension, int offset, Ops.DimensionOp op) {
double[] dim = new double[this.shape[dimension]];
int count = 0;
for(int j = offset; count < dim.length; j+= this.stride[dimension]) {
double d = data.getDouble(j);
dim[count++] = d;
}
return reduceVector(op, Nd4j.create(dim));
}
@Override
public int index(int row, int column) {
if(!isMatrix()) {
if(isColumnVector()) {
int idx = linearIndex(row);
return idx;
}
else if (isRowVector()) {
int idx = linearIndex(column);
return idx;
}
else
throw new IllegalStateException("Unable to getFromOrigin row/column from a non matrix");
}
return offset + (row * stride[0] + column * stride[1]);
}
protected INDArray newShape(int[] newShape,char ordering) {
if(Arrays.equals(newShape,this.shape()))
return this;
else if(Shape.isVector(newShape) && isVector()) {
if(isRowVector() && Shape.isColumnVectorShape(newShape)) {
return Nd4j.create(data,newShape,new int[]{stride[0],1},offset);
}
else if(isColumnVector() && Shape.isRowVectorShape(newShape)) {
return Nd4j.create(data,newShape,new int[]{stride[1]},offset);
}
}
INDArray newCopy = this;
int[] newStrides = null;
//create a new copy of the ndarray
if(shape().length > 1 && ((ordering == NDArrayFactory.C && this.ordering != NDArrayFactory.C) ||
(ordering == NDArrayFactory.FORTRAN && this.ordering != NDArrayFactory.FORTRAN))) {
newStrides = noCopyReshape(newShape,ordering);
if(newStrides == null) {
newCopy = Nd4j.create(shape(),ordering);
for(int i = 0; i < vectorsAlongDimension(0); i++) {
INDArray copyFrom = vectorAlongDimension(i,0);
INDArray copyTo = newCopy.vectorAlongDimension(i,0);
for(int j = 0; j < copyFrom.length(); j++) {
copyTo.putScalar(j,copyFrom.getDouble(i));
}
}
}
}
//needed to copy data
if(newStrides == null)
newStrides = Nd4j.getStrides(newShape);
if(this instanceof IComplexNDArray)
return Nd4j.createComplex(newCopy.data(),newShape,newStrides,offset);
return Nd4j.create(newCopy.data(),newShape,newStrides,offset);
}
/**
* Return the new strides based on the shape and ordering or null
* if we can't do a reshape
* @param newShape the new shape
* @param ordering the ordering of the new shape
* @return the new strides or null if we can't reshape
*/
protected int[] noCopyReshape(int[] newShape,char ordering) {
List oldDims = new ArrayList<>();
List oldStrides = new ArrayList<>();
for (int i = 0; i < shape.length; i++) {
if (size(i) != 1) {
oldDims.add(size(i));
oldStrides.add(stride[i]);
}
}
int np = 1;
for (int ni = 0; ni < newShape.length; ni++) {
np *= newShape[ni];
}
int op = 1;
for (int oi = 0; oi < oldDims.size(); oi++) {
op *= oldDims.get(oi);
}
if (np != op) {
/* different total sizes; no hope */
return null;
}
if (np == 0) {
/* the current code does not handle 0-sized arrays, so give up */
return null;
}
/* oi to oj and ni to nj give the axis ranges currently worked with */
int oi = 0;
int oj = 1;
int ni = 0;
int nj = 1;
List newStrides = new ArrayList<>();
while (ni < newShape.length && oi < oldDims.size()) {
np = newShape[ni];
op = oldDims.get(oi);
while (np != op) {
if (np < op)
np *= newShape[nj++];
else
op *= oldDims.get(oj++);
}
/* Check whether the original axes can be combined */
for (int ok = oi; ok < oj - 1; ok++) {
if (ordering == NDArrayFactory.FORTRAN) {
if (oldStrides.get(ok + 1) != oldDims.get(ok) * oldStrides.get(ok)) {
/* not contiguous enough */
return null;
}
} else {
/* C order */
if (oldStrides.get(ok) != oldDims.get(ok + 1) * oldStrides.get(ok + 1)) {
/* not contiguous enough */
return null;
}
}
}
/* Calculate new strides for all axes currently worked with */
if (ordering == NDArrayFactory.FORTRAN) {
newStrides.set(ni, oldStrides.get(oi));
for (int nk = ni + 1; nk < nj; nk++) {
newStrides.set(nk, newStrides.get(nk - 1) * newShape[nk - 1]);
}
} else {
/* C order */
newStrides.set(nj - 1, oldStrides.get(oj - 1));
for (int nk = nj - 1; nk > ni; nk--) {
newStrides.set(nk - 1, newStrides.get(nk) * newShape[nk]);
}
}
ni = nj++;
oi = oj++;
}
/*
* Set strides corresponding to trailing 1s of the new shape.
*/
int lastStride;
if (ni >= 1) {
lastStride = newStrides.get(ni - 1);
} else {
lastStride = length;
}
if (ordering == NDArrayFactory.FORTRAN) {
lastStride *= newShape[ni - 1];
}
for (int nk = ni; nk < newShape.length; nk++) {
newStrides.set(nk, lastStride);
}
return ArrayUtil.toArray(newStrides);
}
protected double reduceVector(Ops.DimensionOp op,INDArray vector) {
switch(op) {
case SUM:
return vector.sum(Integer.MAX_VALUE).getDouble(0);
case MEAN:
return vector.mean(Integer.MAX_VALUE).getDouble(0);
case MIN:
return vector.min(Integer.MAX_VALUE).getDouble(0);
case MAX:
return vector.max(Integer.MAX_VALUE).getDouble(0);
case NORM_1:
return vector.norm1(Integer.MAX_VALUE).getDouble(0);
case NORM_2:
return vector.norm2(Integer.MAX_VALUE).getDouble(0);
case NORM_MAX:
return vector.normmax(Integer.MAX_VALUE).getDouble(0);
default: throw new IllegalArgumentException("Illegal operation");
}
}
/**
* Returns the squared (Euclidean) distance.
*/
@Override
public double squaredDistance(INDArray other) {
double sd = 0.0;
for (int i = 0; i < length; i++) {
double d = getDouble(i) - other.getDouble(i);
sd += d * d;
}
return sd;
}
/**
* Returns the (euclidean) distance.
*/
@Override
public double distance2(INDArray other) {
return Math.sqrt(squaredDistance(other));
}
/**
* Returns the (1-norm) distance.
*/
@Override
public double distance1(INDArray other) {
return other.sub(this).sum(Integer.MAX_VALUE).getDouble(0);
}
@Override
public INDArray put(NDArrayIndex[] indices, INDArray element) {
if(Indices.isContiguous(indices)) {
INDArray get = get(indices);
INDArray linear = get.linearView();
if(element.isScalar()) {
for(int i = 0; i < linear.length(); i++) {
linear.putScalar(i,element.getDouble(0));
}
}
if(Shape.shapeEquals(element.shape(),get.shape()) || element.length() <= get.length()) {
INDArray elementLinear = element.linearView();
for(int i = 0; i < elementLinear.length(); i++) {
linear.putScalar(i,elementLinear.getDouble(i));
}
}
}
else {
if(isVector()) {
assert indices.length == 1 : "Indices must only be of length 1.";
assert element.isScalar() || element.isVector() : "Unable to assign elements. Element is not a vector.";
assert indices[0].length() == element.length() : "Number of specified elements in index does not match length of element.";
int[] assign = indices[0].indices();
for(int i = 0; i < element.length(); i++) {
putScalar(assign[i],element.getDouble(i));
}
return this;
}
if(element.isVector())
slice(indices[0].indices()[0]).put(Arrays.copyOfRange(indices,1,indices.length),element);
else {
for(int i = 0; i < element.slices(); i++) {
INDArray slice = slice(indices[0].indices()[i]);
slice.put(Arrays.copyOfRange(indices,1,indices.length),element.slice(i));
}
}
}
return this;
}
@Override
public INDArray put(NDArrayIndex[] indices, Number element) {
return put(indices, Nd4j.scalar(element));
}
/**
* Iterate over every column of every slice
* @param op the operation to apply
*/
@Override
public void iterateOverAllColumns(SliceOp op) {
if(isVector())
op.operate(this);
else if(isMatrix()) {
for(int i = 0; i < columns(); i++) {
op.operate(getColumn(i));
}
}
else {
for(int i = 0; i < slices(); i++) {
slice(i).iterateOverAllRows(op);
}
}
}
/**
* Iterate over every row of every slice
* @param op the operation to apply
*/
@Override
public void iterateOverAllRows(SliceOp op) {
if(isVector())
op.operate(this);
else if(isMatrix()) {
for(int i = 0; i < rows(); i++) {
op.operate(getRow(i));
}
}
else {
for(int i = 0; i < slices(); i++) {
slice(i).iterateOverAllRows(op);
}
}
}
/**
* Mainly here for people coming from numpy.
* This is equivalent to a call to permute
* @param dimension the dimension to swap
* @param with the one to swap it with
* @return the swapped axes view
*/
@Override
public INDArray swapAxes(int dimension,int with) {
int[] shape = ArrayUtil.range(0,shape().length);
shape[dimension] = with;
shape[with] = dimension;
return permute(shape);
}
/**
* Gives the indices for the ending of each slice
* @return the off sets for the beginning of each slice
*/
@Override
public int[] endsForSlices() {
int[] ret = new int[slices()];
int currOffset = offset + stride[0] - 1;
for(int i = 0; i < slices(); i++) {
ret[i] = currOffset;
currOffset += stride[0];
}
return ret;
}
@Override
public DataBuffer data() {
return data;
}
@Override
public void setData(DataBuffer data) {
this.data = data;
}
/**
* Number of slices: aka shape[0]
* @return the number of slices
* for this nd array
*/
@Override
public int slices() {
if(shape.length < 1)
return 0;
return shape[0];
}
@Override
public INDArray subArray(int[] offsets, int[] shape,int[] stride) {
int n = shape.length;
if(shape.length < 1)
return Nd4j.create(Nd4j.createBuffer(shape));
if (offsets.length != n)
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
if (shape.length != n)
throw new IllegalArgumentException("Invalid shape " + Arrays.toString(shape));
if (Arrays.equals(shape, this.shape)) {
if (ArrayUtil.isZero(offsets)) {
return this;
} else {
throw new IllegalArgumentException("Invalid subArray offsets");
}
}
int offset = this.offset + ArrayUtil.dotProduct(offsets, this.stride);
return Nd4j.create(
data
, Arrays.copyOf(shape, shape.length)
, stride
, offset, ordering
);
}
@Override
public INDArray cond(Condition condition) {
return dup().condi(condition);
}
@Override
public INDArray condi(Condition condition) {
INDArray linear = linearView();
for(int i = 0 ;i < length(); i++) {
boolean met = condition.apply(linear.getDouble(i));
linear.putScalar(i,met ? 1 : 0);
}
return this;
}
/**
* Iterate along a dimension.
* This encapsulates the process of sum, mean, and other processes
* take when iterating over a dimension.
* @param dimension the dimension to iterate over
* @param op the operation to apply
* @param modify whether to modify this array while iterating
*/
@Override
public void iterateOverDimension(int dimension,SliceOp op,boolean modify) {
if(dimension >= shape.length)
throw new IllegalArgumentException("Unable to remove dimension " + dimension + " was >= shape length");
if(isScalar()) {
if(dimension > 0)
throw new IllegalArgumentException("Dimension must be 0 for a scalar");
else {
DimensionSlice slice = this.vectorForDimensionAndOffset(0,0);
op.operate(slice);
if(modify && slice.getIndices() != null) {
INDArray result = (INDArray) slice.getResult();
for(int i = 0; i < slice.getIndices().length; i++) {
data.put(slice.getIndices()[i],result.getDouble(i));
}
}
}
}
else if(isVector()) {
if(dimension == 0) {
DimensionSlice slice = this.vectorForDimensionAndOffset(0,0);
op.operate(slice);
if(modify && slice.getIndices() != null) {
INDArray result = (INDArray) slice.getResult();
for(int i = 0; i < slice.getIndices().length; i++) {
data.put(slice.getIndices()[i],result.getDouble(i));
}
}
}
else if(dimension == 1) {
for(int i = 0; i < length; i++) {
DimensionSlice slice = vectorForDimensionAndOffset(dimension,i);
op.operate(slice);
if(modify && slice.getIndices() != null) {
INDArray result = (INDArray) slice.getResult();
for(int j = 0; j < slice.getIndices().length; j++) {
data.put(slice.getIndices()[j],result.getDouble(j));
}
}
}
}
else
throw new IllegalArgumentException("Illegal dimension for vector " + dimension);
}
else {
int vectors = vectorsAlongDimension(dimension);
for(int i = 0; i < vectors; i++) {
INDArray vector = vectorAlongDimension(i,dimension);
op.operate(vector);
}
}
}
@Override
public void setStride(int[] stride) {
this.stride = stride;
}
protected void initShape(int[] shape) {
this.shape = shape;
if(this.shape.length == 1) {
rows = 1;
columns = this.shape[0];
}
else if(this.shape().length == 2) {
if(shape[0] == 1) {
this.shape = new int[1];
this.shape[0] = shape[1];
rows = 1;
columns = shape[1];
//weird case here: when its fortran contiguous, we have scencrios
//where we need to retain the non zero stride of the ndarray.
//this is in response to a strange scenario that happens in subArray
//hopefully we can work out something better eventually
if(stride != null && ordering == NDArrayFactory.FORTRAN)
this.stride = new int[]{ArrayUtil.nonOneStride(this.stride)};
}
else {
rows = shape[0];
columns = shape[1];
}
}
//default row vector
else if(this.shape.length == 1) {
columns = this.shape[0];
rows = 1;
}
//null character
if(this.ordering == '\u0000')
this.ordering = Nd4j.order();
this.length = ArrayUtil.prod(this.shape);
if(this.stride == null) {
if(ordering == NDArrayFactory.FORTRAN)
this.stride = ArrayUtil.calcStridesFortran(shape);
else
this.stride = ArrayUtil.calcStrides(this.shape);
}
//recalculate stride: this should only happen with row vectors
if(this.stride.length != this.shape.length) {
if(ordering == NDArrayFactory.FORTRAN)
this.stride = ArrayUtil.calcStridesFortran(this.shape);
else
this.stride = ArrayUtil.calcStrides(this.shape);
}
}
@Override
public INDArray getScalar(int i) {
if(!isVector() && !isScalar())
throw new IllegalArgumentException("Unable to do linear indexing with dimensions greater than 1");
int idx = linearIndex(i);
return Nd4j.scalar(data.getDouble(idx));
}
protected void asserColumnVector(INDArray column) {
assert column.isColumnVector() || column.columns() == columns() && column.rows() == 1: "Must only add a row vector";
assert column.length() == rows() || column.columns() == columns() && column.rows() == 1: "Illegal row vector must have the same length as the number of rows in this ndarray";
}
/**
* Do a row wise op (a,s,m,d)
* a : add
* s : subtract
* m : multiply
* d : divide
* h : reverse subtraction
* t : reverse division
* @param columnVector the column vector
* @param operation the operation
* @return
*/
protected INDArray doColumnWise(INDArray columnVector,char operation) {
asserColumnVector(columnVector);
for(int i = 0; i < columns(); i++) {
INDArray slice = slice(i,0);
switch(operation) {
case 'a' : slice.addi(columnVector); break;
case 's' : slice.subi(columnVector); break;
case 'm' : slice.muli(columnVector); break;
case 'd' : slice.divi(columnVector); break;
case 'h' : slice.rsubi(columnVector); break;
case 't' : slice.rdivi(columnVector); break;
}
}
return this;
}
protected void assertRowVector(INDArray rowVector) {
assert rowVector.isRowVector() || rowVector.rows() == rows() && rowVector.columns() == 1: "Must only add a row vector";
assert rowVector.length() == columns() || rowVector.rows() == rows() && rowVector.columns() == 1: "Illegal row vector must have the same length as the number of rows in this ndarray";
}
/**
* Do a row wise op (a,s,m,d)
* a : add
* s : subtract
* m : multiply
* d : divide
* h : reverse subtraction
* t : reverse division
* @param rowVector the row vector
* @param operation the operation
* @return
*/
protected INDArray doRowWise(INDArray rowVector,char operation) {
assertRowVector(rowVector);
for(int i = 0; i < rows(); i++) {
switch(operation) {
case 'a' : getRow(i).addi(rowVector); break;
case 's' : getRow(i).subi(rowVector); break;
case 'm' : getRow(i).muli(rowVector); break;
case 'd' : getRow(i).divi(rowVector); break;
case 'h' : getRow(i).rsubi(rowVector); break;
case 't' : getRow(i).rdivi(rowVector); break;
}
}
return this;
}
@Override
public INDArray rdiviColumnVector(INDArray columnVector) {
return doColumnWise(columnVector,'t');
}
@Override
public INDArray rdivColumnVector(INDArray columnVector) {
return dup().rdiviColumnVector(columnVector);
}
@Override
public INDArray rdiviRowVector(INDArray rowVector) {
return doRowWise(rowVector,'t');
}
@Override
public INDArray rdivRowVector(INDArray rowVector) {
return dup().rdiviRowVector(rowVector);
}
@Override
public INDArray rsubiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector,'h');
}
@Override
public INDArray rsubColumnVector(INDArray columnVector) {
return dup().rsubiColumnVector(columnVector);
}
@Override
public INDArray rsubiRowVector(INDArray rowVector) {
return doRowWise(rowVector,'h');
}
@Override
public INDArray rsubRowVector(INDArray rowVector) {
return dup().rsubiRowVector(rowVector);
}
/**
* Inserts the element at the specified index
*
* @param i the index insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, INDArray element) {
if(element == null)
throw new IllegalArgumentException("Unable to insert null element");
assert element.isScalar() : "Unable to insert non scalar element";
int idx = linearIndex(i);
data.put(idx,element.getDouble(0));
return this;
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray diviColumnVector(INDArray columnVector) {
return doColumnWise(columnVector,'d');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray divColumnVector(INDArray columnVector) {
return dup().diviColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray diviRowVector(INDArray rowVector) {
return doRowWise(rowVector,'d');
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray divRowVector(INDArray rowVector) {
return dup().diviRowVector(rowVector);
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray muliColumnVector(INDArray columnVector) {
return doColumnWise(columnVector,'m');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray mulColumnVector(INDArray columnVector) {
return dup().muliColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray muliRowVector(INDArray rowVector) {
return doRowWise(rowVector,'m');
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray mulRowVector(INDArray rowVector) {
return dup().muliRowVector(rowVector);
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray subiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector,'s');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray subColumnVector(INDArray columnVector) {
return dup().subiColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray subiRowVector(INDArray rowVector) {
return doRowWise(rowVector,'s');
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray subRowVector(INDArray rowVector) {
return dup().subiRowVector(rowVector);
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray addiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector,'a');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray addColumnVector(INDArray columnVector) {
return dup().addiColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray addiRowVector(INDArray rowVector) {
return doRowWise(rowVector,'a');
}
/**
* In place addition of a column vector
*
* @param rowVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray addRowVector(INDArray rowVector) {
return dup().addiRowVector(rowVector);
}
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmul(INDArray other) {
int[] shape = {rows(),other.columns()};
char order = Nd4j.factory().order();
boolean switchedOrder = false;
if(order != NDArrayFactory.FORTRAN) {
Nd4j.factory().setOrder(NDArrayFactory.FORTRAN);
switchedOrder = true;
}
INDArray result = Nd4j.create(shape);
if(switchedOrder && order != NDArrayFactory.FORTRAN)
Nd4j.factory().setOrder(NDArrayFactory.C);
return mmuli(other,result);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmul(INDArray other, INDArray result) {
return dup().mmuli(other,result);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override
public INDArray div(INDArray other) {
return dup().divi(other);
}
/**
* copy (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override
public INDArray div(INDArray other, INDArray result) {
return dup().divi(other,result);
}
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the addition
*/
@Override
public INDArray mul(INDArray other) {
return dup().muli(other);
}
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override
public INDArray mul(INDArray other, INDArray result) {
return dup().muli(other,result);
}
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override
public INDArray sub(INDArray other) {
return dup().subi(other);
}
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override
public INDArray sub(INDArray other, INDArray result) {
return dup().subi(other,result);
}
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override
public INDArray add(INDArray other) {
return dup().addi(other);
}
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override
public INDArray add(INDArray other, INDArray result) {
return dup().addi(other,result);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other) {
return dup().mmuli(other,this);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, INDArray result) {
INDArray otherArray = other;
INDArray resultArray = result;
if(other.shape().length > 2) {
for(int i = 0; i < other.slices(); i++) {
result.putSlice(i,slice(i).mmul(other.slice(i)));
}
return result;
}
LinAlgExceptions.assertMultiplies(this,other);
if (other.isScalar()) {
return muli(otherArray.getDouble(0), resultArray);
}
if (isScalar()) {
return otherArray.muli(getDouble(0), resultArray);
}
/* check sizes and resize if necessary */
//assertMultipliesWith(other);
if (result == this || result == other) {
/* actually, blas cannot do multiplications in-place. Therefore, we will fake by
* allocating a temporary object on the side and copy the result later.
*/
INDArray temp = Nd4j.create(resultArray.shape(), ArrayUtil.calcStridesFortran(resultArray.shape()));
if (otherArray.columns() == 1) {
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().gemv(1.0, this, otherArray, 0.0, temp);
else
Nd4j.getBlasWrapper().gemv(1.0f,this,otherArray,0.0f,temp);
} else {
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().gemm(1.0, this, otherArray, 0.0, temp);
else
Nd4j.getBlasWrapper().gemm(1.0f,this,otherArray,0.0f,temp);
}
Nd4j.getBlasWrapper().copy(temp, resultArray);
} else {
if (otherArray.columns() == 1)
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().gemv(1.0, this, otherArray, 0.0, resultArray);
else
Nd4j.getBlasWrapper().gemv(1.0f,this,otherArray,0.0f,resultArray);
else {
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().gemm(1.0, this, otherArray, 0.0, resultArray);
else
Nd4j.getBlasWrapper().gemm(1.0f,this,otherArray,0.0f,resultArray);
}
}
return resultArray;
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override
public INDArray divi(INDArray other) {
return divi(other,this);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override
public INDArray divi(INDArray other, INDArray result) {
if (other.isScalar()) {
return divi(other.getDouble(0), result);
}
if (isScalar()) {
return other.divi(getDouble(0), result);
}
INDArray otherLinear = other.linearView();
INDArray resultLinear = result.linearView();
INDArray linearView = linearView();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, linearView.getDouble(i) / otherLinear.getDouble(i));
}
return result;
}
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the addition
*/
@Override
public INDArray muli(INDArray other) {
return muli(other,this);
}
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override
public INDArray muli(INDArray other, INDArray result) {
if (other.isScalar()) {
return muli(other.getDouble(0), result);
}
if (isScalar()) {
return other.muli(getDouble(0), result);
}
INDArray otherLinear = other.linearView();
INDArray resultLinear = result.linearView();
INDArray linearView = linearView();
for (int i = 0; i < length; i++) {
resultLinear.putScalar(i, linearView.getDouble(i) * otherLinear.getDouble(i));
}
return result;
}
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override
public INDArray subi(INDArray other) {
return subi(other,this);
}
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override
public INDArray subi(INDArray other, INDArray result) {
if (other.isScalar()) {
return subi(other.getDouble(0), result);
}
if (isScalar()) {
return other.rsubi(getDouble(0), result);
}
if (result == this) {
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().axpy(-1.0, other, result);
else
Nd4j.getBlasWrapper().axpy(-1.0f,other,result);
}
else if (result == other) {
if(data.dataType().equals(DataBuffer.DOUBLE)) {
Nd4j.getBlasWrapper().scal(-1.0, result);
Nd4j.getBlasWrapper().axpy(1.0, this, result);
}
else {
Nd4j.getBlasWrapper().scal(-1.0f, result);
Nd4j.getBlasWrapper().axpy(1.0f, this, result);
}
}
else {
if(data.dataType().equals(DataBuffer.FLOAT)) {
Nd4j.getBlasWrapper().copy(this, result);
Nd4j.getBlasWrapper().axpy(-1.0f, other, result);
}
else {
Nd4j.getBlasWrapper().copy(this, result);
Nd4j.getBlasWrapper().axpy(-1.0, other, result);
}
}
return result;
}
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override
public INDArray addi(INDArray other) {
return addi(other,this);
}
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override
public INDArray addi(INDArray other, INDArray result) {
if (other.isScalar()) {
return result.addi(other.getDouble(0),result);
}
if (isScalar()) {
return other.addi(getDouble(0), result);
}
if (result == this) {
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().axpy(1.0, other, result);
else
Nd4j.getBlasWrapper().axpy(1.0f,other,result);
}
else if (result == other) {
if(data.dataType().equals(DataBuffer.DOUBLE))
Nd4j.getBlasWrapper().axpy(1.0, this, result);
else
Nd4j.getBlasWrapper().axpy(1.0f,this,result);
}
else {
INDArray resultLinear = result.linearView();
INDArray otherLinear = other.linearView();
INDArray linear = linearView();
for(int i = 0; i < resultLinear.length(); i++) {
resultLinear.putScalar(i,otherLinear.getDouble(i) + linear.getDouble(i));
}
}
return result;
}
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override
public INDArray normmax(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.normmax(dimension),dimension);
return doDimensionWise(
DimensionFunctions.normmax(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Reverse division
*
* @param other the matrix to divide from
* @return
*/
@Override
public INDArray rdiv(INDArray other) {
return dup().rdivi(other);
}
/**
* Reverse divsion (in place)
*
* @param other
* @return
*/
@Override
public INDArray rdivi(INDArray other) {
return rdivi(other,this);
}
/**
* Reverse division
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override
public INDArray rdiv(INDArray other, INDArray result) {
return dup().rdivi(other,result);
}
/**
* Reverse division (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override
public INDArray rdivi(INDArray other, INDArray result) {
return other.divi(this, result);
}
/**
* Reverse subtraction
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override
public INDArray rsub(INDArray other, INDArray result) {
return dup().rsubi(other,result);
}
/**
* @param other
* @return
*/
@Override
public INDArray rsub(INDArray other) {
return dup().rsubi(other);
}
/**
* @param other
* @return
*/
@Override
public INDArray rsubi(INDArray other) {
return rsubi(other,this);
}
/**
* Reverse subtraction (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override
public INDArray rsubi(INDArray other, INDArray result) {
return other.subi(this, result);
}
/**
* Set the value of the ndarray to the specified value
*
* @param value the value to assign
* @return the ndarray with the values
*/
@Override
public INDArray assign(Number value) {
INDArray one = linearView();
for(int i = 0; i < one.length(); i++)
one.put(i, Nd4j.scalar(value.doubleValue()));
return this;
}
@Override
public int linearIndex(int i) {
int realStride = stride[0];
int idx = offset + i * realStride;
if(data != null && idx >= data.length())
throw new IllegalArgumentException("Illegal index " + idx + " derived from " + i + " with offset of " + offset + " and stride of " + realStride);
return idx;
}
/**
* Returns the specified slice of this matrix.
* In matlab, this would be equivalent to (given a 2 x 2 x 2):
* A(:,:,x) where x is the slice you want to return.
*
* The slice is always relative to the final dimension of the matrix.
*
* @param slice the slice to return
* @return the specified slice of this matrix
*/
@Override
public INDArray slice(int slice) {
if (shape.length == 0)
throw new IllegalArgumentException("Can't slice a 0-d NDArray");
//slice of a vector is a scalar
else if (shape.length == 1) {
if(size(0) == 1)
return Nd4j.create(data, ArrayUtil.empty(), ArrayUtil.empty(), offset + slice);
else
return Nd4j.create(data, ArrayUtil.empty(), ArrayUtil.empty(), offset + slice * stride[0]);
}
//slice of a matrix is a vector
else if (shape.length == 2) {
if(size(0) == 1) {
INDArray slice2 = Nd4j.create(
data,
ArrayUtil.of(shape[1]),
Arrays.copyOfRange(stride, 1, stride.length),
offset + slice, ordering
);
return slice2;
}
else {
INDArray slice2 = Nd4j.create(
data,
ArrayUtil.of(shape[1]),
Arrays.copyOfRange(stride, 1, stride.length),
offset + slice * stride[0], ordering
);
return slice2;
}
}
else {
int offset = this.offset + (slice * stride[0]);
if(size(0) == 1) {
INDArray slice2 = Nd4j.create(data,
Arrays.copyOfRange(shape, 1, shape.length),
Arrays.copyOfRange(stride, 1, stride.length),
offset, ordering);
return slice2;
}
else {
INDArray slice2 = Nd4j.create(data,
Arrays.copyOfRange(shape, 1, shape.length),
Arrays.copyOfRange(stride, 1, stride.length),
offset, ordering);
return slice2;
}
}
}
/**
* Returns the slice of this from the specified dimension
* @param slice the dimension to return from
* @param dimension the dimension of the slice to return
* @return the slice of this matrix from the specified dimension
* and dimension
*/
@Override
public INDArray slice(int slice, int dimension) {
if(shape.length == 2) {
//rows
if(dimension == 1)
return getRow(slice);
else if(dimension == 0)
return getColumn(slice);
else throw new IllegalAccessError("Illegal dimension for matrix");
}
if (slice == shape.length - 1)
return slice(dimension);
INDArray slice2 = Nd4j.create(data,
ArrayUtil.removeIndex(shape, dimension),
ArrayUtil.removeIndex(stride, dimension),
offset + slice * stride[dimension], ordering);
return slice2;
}
/**
* Fetch a particular number on a multi dimensional scale.
* @param indexes the indexes to getFromOrigin a number from
* @return the number at the specified indices
*/
@Override
public INDArray getScalar(int... indexes) {
int ix = offset;
for (int i = 0; i < indexes.length; i++) {
ix += indexes[i] * stride[i];
}
if(ix >= data.length())
throw new IllegalArgumentException("Illegal index " + Arrays.toString(indexes));
return Nd4j.scalar(data.getDouble(ix));
}
@Override
public INDArray rdiv(Number n) {
return dup().rdivi(n);
}
@Override
public INDArray rdivi(Number n) {
return rdivi(Nd4j.valueArrayOf(shape(), n.doubleValue()));
}
@Override
public INDArray rsub(Number n) {
return dup().rsubi(n);
}
@Override
public INDArray rsubi(Number n) {
return rsubi(Nd4j.valueArrayOf(shape(), n.doubleValue()));
}
@Override
public INDArray div(Number n) {
return dup().divi(n);
}
@Override
public INDArray divi(Number n) {
return divi(n,this);
}
@Override
public INDArray mul(Number n) {
return dup().muli(n);
}
@Override
public INDArray muli(Number n) {
return muli(n,this);
}
@Override
public INDArray sub(Number n) {
return dup().subi(n);
}
@Override
public INDArray subi(Number n) {
return subi(n,this);
}
@Override
public INDArray add(Number n) {
return dup().addi(n);
}
@Override
public INDArray addi(Number n) {
return addi(n,this);
}
/**
* Replicate and tile array to fill out to the given shape
*
* @param shape the new shape of this ndarray
* @return the shape to fill out to
*/
@Override
public INDArray repmat(int[] shape) {
int[] newShape = new int[shape.length];
assert shape.length <= newShape.length : "Illegal shape: The passed in shape must be <= the current shape length";
int[] oldShape = isRowVector() ? new int[]{1,this.shape[0]} : Arrays.copyOf(this.shape,2);
for(int i = 0; i < newShape.length; i++) {
if(i < this.shape.length)
newShape[i] = oldShape[i] * shape[i];
else
newShape[i] = oldShape[i];
}
INDArray result = Nd4j.create(newShape);
//nd copy
if(isScalar()) {
for(int i = 0; i < result.length(); i++) {
result.put(i,getScalar(0));
}
}
else if(isRowVector()) {
if(Shape.isColumnVectorShape(newShape))
return transpose();
else if(Shape.isMatrix(newShape)) {
INDArray ret = Nd4j.create(newShape);
for(int i = 0; i < ret.rows(); i++) {
ret.putRow(i,this);
}
return ret;
}
}
else if(isMatrix()) {
for (int c = 0; c < shape()[1]; c++) {
for (int r = 0; r < shape()[0]; r++) {
for (int i = 0; i < rows(); i++) {
for (int j = 0; j < columns(); j++) {
result.put(r * rows() + i, c * columns() + j, getScalar(i, j));
}
}
}
}
}
else {
int[] sliceRepmat = ArrayUtil.removeIndex(shape,0);
for(int i = 0; i < result.slices(); i++) {
result.putSlice(i,repmat(sliceRepmat));
}
}
INDArray ret = result;
return ret;
}
/**
* Insert a row in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param row the row insert into
* @param toPut the row to insert
* @return this
*/
@Override
public INDArray putRow(int row, INDArray toPut) {
assert toPut.isVector() && toPut.length() == columns : "Illegal length for row " + toPut.length() + " should have been " + columns;
INDArray r = getRow(row);
for(int i = 0; i < r.length(); i++)
r.putScalar(i,toPut.getDouble(i));
return this;
}
/**
* Insert a column in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param column the column to insert
* @param toPut the array to put
* @return this
*/
@Override
public INDArray putColumn(int column, INDArray toPut) {
assert toPut.isVector() && toPut.length() == rows() : "Illegal length for row " + toPut.length() + " should have been " + rows();
INDArray r = getColumn(column);
for(int i = 0; i < r.length(); i++)
r.putScalar(i,toPut.getDouble(i));
return this;
}
@Override
public E getElement(int i) {
int idx = linearIndex(i);
if(idx < 0)
throw new IllegalStateException("Illegal index " + i);
return data.getElement(idx);
}
@Override
public E getElement(int i, int j) {
int idx = index(i,j);
return data.getElement(idx);
}
@Override
public double getDouble(int i) {
int idx = linearIndex(i);
if(idx < 0)
throw new IllegalStateException("Illegal index " + i);
return data.getDouble(idx);
}
@Override
public double getDouble(int i, int j) {
int idx = index(i,j);
return data.getDouble(idx);
}
@Override
public float getFloat(int i) {
int idx = linearIndex(i);
if(idx < 0)
throw new IllegalStateException("Illegal index " + i);
return data.getFloat(idx);
}
@Override
public float getFloat(int i, int j) {
int idx = index(i,j);
return data.getFloat(idx);
}
/**
* Mainly an internal method (public for testing)
* for given an offset and stride, and index,
* calculating the beginning index of a query given indices
* @param offset the desired offset
* @param stride the desired stride
* @param indexes the desired indexes to test on
* @return the index for a query given stride and offset
*/
public static int getIndex(int offset,int[] stride,int...indexes) {
if(stride.length > indexes.length)
throw new IllegalArgumentException("Invalid number of items in stride array: should be <= number of indexes");
int ix = offset;
for (int i = 0; i < indexes.length; i++) {
ix += indexes[i] * stride[i];
}
return ix;
}
/**
* Return transposed copy of this matrix.
*/
@Override
public INDArray transpose() {
if(isRowVector())
return Nd4j.create(data, new int[]{shape[0], 1}, offset);
else if(isColumnVector())
return Nd4j.create(data, new int[]{shape[0]}, offset);
if(isMatrix()) {
INDArray reverse = Nd4j.create(new int[]{shape[1],shape[0]});
for (int i = 0; i < rows; i++) {
for (int j = 0; j < columns; j++) {
reverse.putScalar(new int[]{j, i}, getDouble(i, j));
}
}
return reverse;
}
INDArray ret = permute(ArrayUtil.range(shape.length -1,-1));
return ret;
}
/**
* Reshape the ndarray in to the specified dimensions,
* possible errors being thrown for invalid shapes
* @param shape
* @return
*/
@Override
public INDArray reshape(int[] shape) {
assert ArrayUtil.prod(shape) == ArrayUtil.prod(this.shape()) : "Illegal reshape must be of same length as data";
return newShape(shape,ordering);
}
@Override
public void checkDimensions(INDArray other) {
assert Arrays.equals(shape,other.shape()) : " Other array should have been shape: " + Arrays.toString(shape) + " but was " + Arrays.toString(other.shape());
assert Arrays.equals(stride,other.stride()) : " Other array should have been stride: " + Arrays.toString(stride) + " but was " + Arrays.toString(other.stride());
assert offset == other.offset() : "Offset of this array is " + offset + " but other was " + other.offset();
}
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @return the product along the specified dimension
*/
@Override
public INDArray prod(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.prod(dimension),dimension);
return doDimensionWise(
DimensionFunctions.prod(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray mean(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.mean(dimension),dimension);
return doDimensionWise(
DimensionFunctions.mean(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the overall variance of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray var(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.var(dimension),dimension);
return doDimensionWise(
DimensionFunctions.var(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray max(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.max(dimension),dimension);
return doDimensionWise(
DimensionFunctions.max(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray min(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.min(dimension),dimension);
return doDimensionWise(
DimensionFunctions.min(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the sum along the last dimension of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
@Override
public INDArray sum(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.sum(dimension),dimension);
return doDimensionWise(
DimensionFunctions.sum(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override
public INDArray norm1(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.norm1(dimension),dimension);
return doDimensionWise(
DimensionFunctions.norm1(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Standard deviation of an ndarray along a dimension
*
* @param dimension the dimension to getScalar the std along
* @return the standard deviation along a particular dimension
*/
@Override
public INDArray std(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.std(dimension),dimension);
return doDimensionWise(
DimensionFunctions.std(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @return the norm2 along the specified dimension
*/
@Override
public INDArray norm2(int dimension) {
Triple pair = getOp(new AtomicInteger(0),DimensionFunctions.norm2(dimension),dimension);
return doDimensionWise(
DimensionFunctions.norm2(),
pair.getLeft(),pair.getMiddle(),pair.getRight(),dimension,false);
}
private Triple getOp(final AtomicInteger i,final Function func,int dimension) {
int[] shape = shape().length == 1 || dimension == Integer.MAX_VALUE ? new int[] {1} : ArrayUtil.removeIndex(shape(),dimension);
final INDArray put = Nd4j.create(new int[]{ArrayUtil.prod(shape)});
SliceOp op = new SliceOp() {
@Override
public void operate(DimensionSlice nd) {
INDArray arr2 = (INDArray) nd.getResult();
put.put(i.get(),func.apply(arr2));
i.incrementAndGet();
}
/**
* Operates on an ndarray slice
*
* @param nd the result to operate on
*/
@Override
public void operate(INDArray nd) {
put.put(i.get(),func.apply(nd));
i.incrementAndGet();
}
};
return Triple.of(op,put,shape);
}
/**
* Number of columns (shape[1]), throws an exception when
* called when not 2d
* @return the number of columns in the array (only 2d)
*/
@Override
public int columns() {
if(isMatrix()) {
if (shape().length > 2)
return Shape.squeeze(shape)[1];
else if (shape().length == 2)
return shape[1];
}
if(isVector()) {
if(isColumnVector())
return 1;
else
return shape[0];
}
throw new IllegalStateException("Unable to getFromOrigin number of of rows for a non 2d matrix");
}
/**
* Returns the number of rows
* in the array (only 2d) throws an exception when
* called when not 2d
* @return the number of rows in the matrix
*/
@Override
public int rows() {
if(isMatrix()) {
if (shape().length > 2)
return Shape.squeeze(shape)[0];
else if (shape().length == 2)
return shape[0];
}
else if(isVector()) {
if(isRowVector())
return 1;
else
return shape[0];
}
throw new IllegalStateException("Unable to getFromOrigin number of of rows for a non 2d matrix");
}
protected INDArray doDimensionWise(Function baseCase,SliceOp sliceOp,INDArray arr,int[] newShape,int dimension,boolean modify) {
if(dimension == Integer.MAX_VALUE || isVector())
return baseCase.apply(this.linearView());
else {
iterateOverDimension(dimension, sliceOp,modify);
return arr.reshape(newShape);
}
}
/**
* Flattens the array for linear indexing
* @return the flattened version of this array
*/
@Override
public INDArray ravel() {
INDArray ret = Nd4j.create(length, ordering);
int dimension = shape.length == 2 ? 1 : shape.length;
int count = 0;
for(int i = 0; i < vectorsAlongDimension(dimension); i++) {
INDArray vec = vectorAlongDimension(i,dimension);
for(int j = 0; j < vec.length(); j++) {
ret.putScalar(count++,vec.getDouble(j));
}
}
return ret;
}
/**
* Flattens the array for linear indexing
* @return the flattened version of this array
*/
@Override
public void sliceVectors(List list) {
if(isVector())
list.add(this);
else {
for(int i = 0; i < slices(); i++) {
slice(i).sliceVectors(list);
}
}
}
/**
* Reshape the matrix. Number of elements must not change.
*
* @param newRows
* @param newColumns
*/
@Override
public INDArray reshape(int newRows, int newColumns) {
assert newRows * newColumns == length : "Illegal new shape " + newRows + " x " + newColumns;
return reshape(new int[]{newRows,newColumns});
}
/**
* Get the specified column
*
* @param c
*/
@Override
public INDArray getColumn(int c) {
if(shape.length == 2) {
if(ordering == NDArrayFactory.C) {
INDArray ret = Nd4j.create(
data,
new int[]{shape[0], 1},
new int[]{stride[0], 1},
offset + c, ordering
);
return ret;
}
else {
INDArray ret = Nd4j.create(
data,
new int[]{shape[0],1},
new int[]{stride[0],1},
offset + c * rows(), ordering
);
return ret;
}
}
else if(isColumnVector() && c == 0)
return this;
else
throw new IllegalArgumentException("Unable to getFloat scalar column of non 2d matrix");
}
/**
* Get whole rows from the passed indices.
*
* @param rindices
*/
@Override
public INDArray getRows(int[] rindices) {
INDArray rows = Nd4j.create(rindices.length, columns());
for(int i = 0; i < rindices.length; i++) {
rows.putRow(i,getRow(rindices[i]));
}
return rows;
}
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override
public INDArray get(NDArrayIndex... indexes) {
//fill in to match the rest of the dimensions: aka grab all the content
//in the dimensions not filled in
//also prune indices greater than the shape to be the shape instead
indexes = Indices.adjustIndices(shape(),indexes);
int[] offsets = Indices.offsets(indexes);
int[] shape = Indices.shape(shape(),indexes);
//no stride will help here, need to do manually
if(!Indices.isContiguous(indexes)) {
INDArray ret = Nd4j.create(shape);
if(ret.isVector() && isVector()) {
int[] indices = indexes[0].indices();
for(int i = 0; i < ret.length(); i++) {
ret.putScalar(i,getDouble(indices[i]));
}
return ret;
}
if(!ret.isVector()) {
//overrides when shouldn't
for(int i = 0; i < ret.slices(); i++) {
INDArray putSlice = slice(i).get(Arrays.copyOfRange(indexes,1,indexes.length));
ret.putSlice(i,putSlice);
}
}
else {
INDArray putSlice = slice(0).get(Arrays.copyOfRange(indexes,1,indexes.length));
ret.putSlice(0,putSlice);
}
return ret;
}
if(ArrayUtil.prod(shape) > length())
return this;
int[] strides = null;
strides = ArrayUtil.copy(stride());
if(offsets.length != shape.length)
offsets = Arrays.copyOfRange(offsets,0,shape.length);
if(strides.length != shape.length)
strides = Arrays.copyOfRange(strides,0,shape.length);
return subArray(offsets,shape,strides);
}
/**
* Get whole columns from the passed indices.
*
* @param cindices
*/
@Override
public INDArray getColumns(int[] cindices) {
INDArray rows = Nd4j.create(rows(), cindices.length);
for(int i = 0; i < cindices.length; i++) {
rows.putColumn(i,getColumn(cindices[i]));
}
return rows;
}
/**
* Get a copy of a row.
*
* @param r
*/
@Override
public INDArray getRow(int r) {
if(shape.length == 2) {
if(ordering == NDArrayFactory.C) {
INDArray ret = Nd4j.create(
data,
new int[]{shape[1]},
new int[]{stride[1]},
offset + r * columns(),
ordering
);
return ret;
}
else {
INDArray ret = Nd4j.create(
data,
new int[]{shape[1]},
new int[]{stride[1]},
offset + r,
ordering
);
return ret;
}
}
else if(isRowVector() && r == 0)
return this;
else
throw new IllegalArgumentException("Unable to getFloat row of non 2d matrix");
}
/**
* Compare two matrices. Returns true if and only if other is also a
* DoubleMatrix which has the same size and the maximal absolute
* difference in matrix elements is smaller than 1e-6.
*
* @param o
*/
@Override
public boolean equals(Object o) {
INDArray n = null;
if(!(o instanceof INDArray))
return false;
if(n == null)
n = (INDArray) o;
//epsilon equals
if(isScalar() && n.isScalar()) {
if(data.dataType().equals(DataBuffer.FLOAT)) {
double val = getDouble(0);
double val2 = n.getDouble(0);
return Math.abs(val - val2) < 1e-6;
}
else {
double val = getDouble(0);
double val2 = n.getDouble(0);
return Math.abs(val - val2) < 1e-6;
}
}
else if(isVector() && n.isVector()) {
for(int i = 0; i < length; i++) {
if(data.dataType().equals(DataBuffer.FLOAT)) {
double curr = getDouble(i);
double comp = n.getDouble(i);
if(Math.abs(curr - comp) > 1e-3)
return false;
}
else {
double curr = getDouble(i);
double comp = n.getDouble(i);
if(Math.abs(curr - comp) > 1e-3)
return false;
}
}
return true;
}
if(!Shape.shapeEquals(shape(),n.shape()))
return false;
if(slices() != n.slices())
return false;
for (int i = 0; i < slices(); i++) {
INDArray slice = slice(i);
INDArray nSlice = n.slice(i);
if (!slice.equals(nSlice))
return false;
}
return true;
}
/**
* Returns the shape(dimensions) of this array
* @return the shape of this matrix
*/
public int[] shape() {
return shape;
}
/**
* Returns the stride(indices along the linear index for which each slice is accessed) of this array
* @return the stride of this array
*/
@Override
public int[] stride() {
return stride;
}
@Override
public int offset() {
return offset;
}
@Override
public char ordering() {
return ordering;
}
/**
* Returns the size of this array
* along a particular dimension
* @param dimension the dimension to return from
* @return the shape of the specified dimension
*/
@Override
public int size(int dimension) {
if(isScalar()) {
if(dimension == 0)
return length;
else
throw new IllegalArgumentException("Illegal dimension for scalar " + dimension);
}
else if(isVector()) {
if(dimension == 0)
return length;
else if(dimension == 1)
return 1;
}
return shape[dimension];
}
/**
* Returns the total number of elements in the ndarray
*
* @return the number of elements in the ndarray
*/
@Override
public int length() {
return length;
}
/**
* Broadcasts this ndarray to be the specified shape
*
* @param shape the new shape of this ndarray
* @return the broadcasted ndarray
*/
@Override
public INDArray broadcast(int[] shape) {
int dims = this.shape.length;
int targetDimensions = shape.length;
if (targetDimensions < dims) {
throw new IllegalArgumentException("Invalid shape to broad cast " + Arrays.toString(shape));
}
else if(isColumnVector()) {
int n = shape[1];
INDArray ret = Nd4j.create(shape);
for(int i = 0; i < n; i++) {
ret.putColumn(i,this);
}
return ret;
}
//edge case where column vector looks like a matrix
else if (dims == targetDimensions && !isColumnVector()) {
if(isVector())
return this;
if (Shape.shapeEquals(shape, this.shape()))
return this;
throw new IllegalArgumentException("Invalid shape to broad cast " + Arrays.toString(shape));
}
else {
int n = shape[0];
INDArray s = broadcast(Arrays.copyOfRange(shape, 1, targetDimensions));
return Nd4j.repeat(s,n);
}
}
/**
* Dimshuffle: an extension of permute that adds the ability
* to broadcast various dimensions.
*
* See theano for more examples.
* This will only accept integers and xs.
*
* An x indicates a dimension should be broadcasted rather than permuted.
*
* @param rearrange the dimensions to swap to
* @return the newly permuted array
*/
@Override
public INDArray dimShuffle(Object[] rearrange,int[] newOrder,boolean[] broadCastable) {
assert broadCastable.length == shape.length : "The broadcastable dimensions must be the same length as the current shape";
boolean broadcast = false;
Set