Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.api.ndarray;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import net.ericaro.neoitertools.Generator;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.ByteOrder;
import org.nd4j.graph.FlatArray;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.instrumentation.Instrumentation;
import org.nd4j.linalg.api.iter.FirstAxisIterator;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.impl.accum.*;
import org.nd4j.linalg.api.ops.impl.accum.Max;
import org.nd4j.linalg.api.ops.impl.accum.Min;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.broadcast.*;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.scalar.*;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.*;
import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.transforms.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.Nd4jNoSuchWorkspaceException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.string.NDArrayStrings;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.nd4j.linalg.util.LongUtils;
import org.nd4j.linalg.util.NDArrayMath;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.IntBuffer;
import java.util.*;
import static org.nd4j.linalg.factory.Nd4j.*;
/**
* NDArray: (think numpy)
*
* A few things of note.
*
* An NDArray can have any number of dimensions.
*
* An NDArray is accessed via strides.
*
* Strides are how to index over
* a contiguous block of data.
*
* This block of data has 2 orders(as of right now):
* fortran and c
*
* @author Adam Gibson
*/
@Slf4j
public abstract class BaseNDArray implements INDArray, Iterable {
private static final long serialVersionUID = 3285982317165542614L;
protected transient volatile DataBuffer shapeInformation;
protected transient volatile DataBuffer data;
//protected transient DataBuffer shape;
//protected transient DataBuffer stride;
protected transient boolean compressed = false;
// this field holds jvm copy of shapeInfo
protected int[] javaShapeInformation;
//Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over
private static final int[][] tadFinalPermuteDimensions;
static {
tadFinalPermuteDimensions = new int[32][0];
tadFinalPermuteDimensions[1] = new int[] {1, 0}; //Edge case for 1d tensors: selectively apply to column vectors
for (int i = 2; i < 32; i++) {
tadFinalPermuteDimensions[i] = new int[i];
for (int k = i - 1, j = 0; k >= 0; k--, j++)
tadFinalPermuteDimensions[i][j] = k;
}
}
public BaseNDArray() {}
/**
* Returns true if this array is compressed, and false otherwise
* @return
*/
@Override
public boolean isCompressed() {
return compressed;
}
/**
* This method marks INDArray instance as compressed
* PLEASE NOTE: Do not use this method unless you 100% have to
*
* @param reallyCompressed
*/
@Override
public void markAsCompressed(boolean reallyCompressed) {
this.compressed = reallyCompressed;
}
/**
*
* @param buffer
*/
public BaseNDArray(DataBuffer buffer) {
this.data = buffer;
if (buffer.length() >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE");
int[] shape = {1, (int) buffer.length()};
int[] stride = Nd4j.getStrides(shape);
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, 1, Nd4j.order()));
init(shape, stride);
}
/**
*
* @param buffer
* @param shape
* @param stride
* @param offset
* @param ordering
*/
public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) {
this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)) : buffer;
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, offset,
Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering));
init(shape, stride);
// Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f'));
}
/**
* Initialize the ndarray as a matrix
* with the given data (indices preserved)
* @param data
*/
public BaseNDArray(double[][] data) {
this(data, Nd4j.order());
}
/**
*
* @param data
* @param ordering
*/
public BaseNDArray(double[][] data, char ordering) {
this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)),
new int[] {data.length, data[0].length},
Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering);
for (int r = 0; r < rows(); r++) {
assert (data[r].length == columns());
}
}
/**
* Create with the specified shape and buffer
*
* @param shape the shape
* @param buffer the buffer
*/
public BaseNDArray(int[] shape, DataBuffer buffer) {
this.data = buffer;
init(shape, Nd4j.getStrides(shape));
}
/**
* Create this ndarray with the given data and shape and 0 offset
*
* @param data the data to use
* @param shape the shape of the ndarray
*/
public BaseNDArray(float[] data, int[] shape, char ordering) {
this(data, shape, 0, ordering);
}
/**
* @param data the data to use
* @param shape the shape of the ndarray
* @param offset the desired offset
* @param ordering the ordering of the ndarray
*/
public BaseNDArray(float[] data, int[] shape, long offset, char ordering) {
this(data, shape, Nd4j.getStrides(shape, ordering), offset);
}
/**
* Construct an ndarray of the specified shape
* with an empty data array
*
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param offset the desired offset
* @param ordering the ordering of the ndarray
*/
public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) {
this(Nd4j.createBuffer(ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
}
/**
* Construct an ndarray of the specified shape.
*
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param offset the desired offset
* @param ordering the ordering of the ndarray
* @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't.
*/
public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) {
this(Nd4j.createBuffer(ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
}
/**
* Create the ndarray with
* the specified shape and stride and an offset of 0
*
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param ordering the ordering of the ndarray
*/
public BaseNDArray(int[] shape, int[] stride, char ordering) {
this(shape, stride, 0, ordering);
}
/**
*
* @param shape
* @param offset
* @param ordering
*/
public BaseNDArray(int[] shape, long offset, char ordering) {
this(shape, Nd4j.getStrides(shape, ordering), offset, ordering);
}
/**
* Create an ndarray
* with the given shape
* @param shape
*/
public BaseNDArray(int[] shape) {
this(shape, 0, Nd4j.order());
}
/**
* Creates a new n times mDoubleMatrix.
*
* @param newRows the number of rows (n) of the new matrix.
* @param newColumns the number of columns (m) of the new matrix.
*/
public BaseNDArray(int newRows, int newColumns, char ordering) {
this.data = Nd4j.createBuffer((long) newRows * newColumns);
int[] shape = new int[] {newRows, newColumns};
int[] stride = Nd4j.getStrides(shape, ordering);
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0,
Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering));
init(shape, stride);
// Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f'));
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
*
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape, char ordering) {
this(slices, shape, Nd4j.getStrides(shape, ordering), ordering);
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
*
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) {
DataBuffer ret = slices.get(0).data().dataType() == (DataBuffer.Type.FLOAT)
? Nd4j.createBuffer(new float[ArrayUtil.prod(shape)])
: Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]);
this.data = ret;
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0,
Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering));
init(shape, stride);
// Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f'));
if (slices.get(0).isScalar()) {
for (int i = 0; i < length(); i++) {
putScalar(i, slices.get(i).getDouble(0));
}
} else {
for (int i = 0; i < slices(); i++) {
putSlice(i, slices.get(i));
}
}
}
/**
*
* @param data
* @param shape
* @param stride
* @param ordering
*/
public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) {
this(data, shape, stride, 0, ordering);
}
/**
*
* @param data
* @param shape
* @param stride
* @param offset
* @param ordering
*/
public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, offset,
Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering));
if (data != null && data.length > 0) {
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
this.data = internalCreateBuffer(data, offset);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
if (offset >= data.length)
throw new IllegalArgumentException("invalid offset: must be < data.length");
}
init(shape, stride);
// Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape,stride,ordering == 'f'));
}
/**
*
* @param data
* @param shape
* @param stride
* @param offset
*/
public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) {
this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape));
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, offset, Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order()));
init(shape, stride);
// Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'));
}
/**
*
* @param data
* @param shape
* @param strides
*/
public BaseNDArray(int[] data, int[] shape, int[] strides) {
this(internalCreateBuffer(data), shape, strides);
}
/**
*
* @param data
* @param shape
*/
public BaseNDArray(DataBuffer data, int[] shape) {
this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order());
}
/**
*
* @param buffer
* @param shape
* @param offset
*/
public BaseNDArray(DataBuffer buffer, int[] shape, long offset) {
this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, Nd4j.getStrides(shape), offset,
Nd4j.order());
}
/**
*
* @param buffer
* @param shape
* @param ordering
*/
public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) {
this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering);
}
/**
*
* @param data
* @param shape
* @param ordering
*/
public BaseNDArray(double[] data, int[] shape, char ordering) {
this(internalCreateBuffer(data), shape, ordering);
}
/**
*
* @param data
* @param shape
* @param stride
* @param offset
* @param ordering
*/
public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) {
this(internalCreateBuffer(data, offset), shape, stride, offset, ordering);
}
/**
*
* @param data
* @param order
*/
public BaseNDArray(float[] data, char order) {
this(internalCreateBuffer(data), order);
}
protected static DataBuffer internalCreateBuffer(float[] data) {
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
val buffer = Nd4j.createBuffer(data);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
return buffer;
}
protected static DataBuffer internalCreateBuffer(double[] data) {
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
val buffer = Nd4j.createBuffer(data);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
return buffer;
}
protected static DataBuffer internalCreateBuffer(int[] data) {
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
val buffer = Nd4j.createBuffer(data);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
return buffer;
}
protected static DataBuffer internalCreateBuffer(float[] data, long offset) {
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
val buffer = Nd4j.createBuffer(data, offset);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
return buffer;
}
protected static DataBuffer internalCreateBuffer(double[] data, long offset) {
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
val buffer = Nd4j.createBuffer(data, offset);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
return buffer;
}
protected static DataBuffer internalCreateBuffer(int[] data, long offset) {
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
val buffer = Nd4j.createBuffer(data, offset);
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(), MemcpyDirection.HOST_TO_HOST);
return buffer;
}
/**
*
* @param floatBuffer
* @param order
*/
public BaseNDArray(DataBuffer floatBuffer, char order) {
this(floatBuffer, new int[] {(int) floatBuffer.length()},
Nd4j.getStrides(new int[] {(int) floatBuffer.length()}, order), 0, order);
if (floatBuffer.length() >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE");
}
/**
*
* @param buffer
* @param shape
* @param strides
*/
public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) {
this(buffer, shape, strides, 0, Nd4j.order());
}
/**
* Create this ndarray with the given data and shape and 0 offset
*
* @param data the data to use
* @param shape the shape of the ndarray
*/
public BaseNDArray(float[] data, int[] shape) {
this(data, shape, 0);
}
/**
*
* @param data
* @param shape
* @param offset
*/
public BaseNDArray(float[] data, int[] shape, long offset) {
this(data, shape, offset, Nd4j.order());
}
/**
* Construct an ndarray of the specified shape
* with an empty data array
*
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
* @param offset the desired offset
*/
public BaseNDArray(int[] shape, int[] stride, long offset) {
this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order());
}
/**
* Create the ndarray with
* the specified shape and stride and an offset of 0
*
* @param shape the shape of the ndarray
* @param stride the stride of the ndarray
*/
public BaseNDArray(int[] shape, int[] stride) {
this(shape, stride, 0);
}
/**
*
* @param shape
* @param offset
*/
public BaseNDArray(int[] shape, long offset) {
this(shape, Nd4j.getStrides(shape), offset);
}
/**
*
* @param shape
* @param ordering
*/
public BaseNDArray(int[] shape, char ordering) {
this(shape, 0, ordering);
}
/**
* Creates a new n times mDoubleMatrix.
*
* @param newRows the number of rows (n) of the new matrix.
* @param newColumns the number of columns (m) of the new matrix.
*/
public BaseNDArray(int newRows, int newColumns) {
this(newRows, newColumns, Nd4j.order());
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
*
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape) {
this(slices, shape, Nd4j.order());
}
/**
* Create an ndarray from the specified slices.
* This will go through and merge all of the
* data from each slice in to one ndarray
* which will then take the specified shape
*
* @param slices the slices to merge
* @param shape the shape of the ndarray
*/
public BaseNDArray(List slices, int[] shape, int[] stride) {
this(slices, shape, stride, Nd4j.order());
}
/**
*
* @param data
* @param shape
* @param stride
*/
public BaseNDArray(float[] data, int[] shape, int[] stride) {
this(data, shape, stride, Nd4j.order());
}
/**
*
* @param data
* @param shape
* @param stride
* @param offset
*/
public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) {
this(data, shape, stride, offset, Nd4j.order());
}
/**
*
* @param data
*/
public BaseNDArray(float[] data) {
this(Nd4j.createBuffer(data));
}
/**
* Initialize the ndarray
* with the given data
* @param data
*/
public BaseNDArray(float[][] data) {
this(data, Nd4j.order());
}
/**
*
* @param data
* @param ordering
*/
public BaseNDArray(float[][] data, char ordering) {
this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)),
new int[] {data.length, data[0].length},
Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering);
for (int r = 0; r < rows(); r++) {
assert (data[r].length == columns());
}
}
/**
* Constructor for stride and offset
*
* @param buffer
* @param shape
* @param offset
* @param ordering
*/
public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) {
this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering);
}
public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) {
this(internalCreateBuffer(data), shape, stride, offset);
}
@Override
@Deprecated
public void setWrapAround(boolean wrapAround) {
throw new UnsupportedOperationException();
}
@Override
@Deprecated
public boolean isWrapAround() {
throw new UnsupportedOperationException();
}
/**
* Returns whether the ndarray is valid or not
* @return true if the ndarray is valid
* false otherwise
*/
@Deprecated
public boolean isValid() {
try {
linearIndex(length() - 1);
} catch (Exception e) {
return false;
}
return true;
}
@Override
@Deprecated
public INDArray linearViewColumnOrder() {
return this;
}
protected INDArray create(DataBuffer data, int[] shape, long offset) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(data, shape, offset);
else
return Nd4j.create(data, shape, offset);
}
/**
* Returns a linear view reference of shape
* 1,length(ndarray)
*
* @return the linear view of this ndarray
*/
@Override
public INDArray linearView() {
return reshape(this.ordering(), 1, this.length());
}
@Override
public void resetLinearView() {
}
@Override
public int elementWiseStride() {
/*
if(Shape.elementWiseStride(shapeInfo()) < 0 && !attemptedToFindElementWiseStride) {
INDArray reshapeAttempt = Shape.newShapeNoCopy(this,new int[]{1,length()}, ordering() == 'f');
if(reshapeAttempt != null && reshapeAttempt.elementWiseStride() > 0) {
Shape.setElementWiseStride(shapeInfo(), reshapeAttempt.stride(-1));
this.shapeInformation = Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), offset(),reshapeAttempt.stride(-1), ordering());
}
attemptedToFindElementWiseStride = true;
}
*/
return Shape.elementWiseStride(shapeInfoDataBuffer());
}
@Override
public int elementStride() {
return 1;
}
@Override
@Deprecated
public int majorStride() {
return stride(-1);
}
@Override
@Deprecated
public int secondaryStride() {
return majorStride();
}
@Override
public int tensorssAlongDimension(int... dimension) {
if (dimension == null || dimension.length == 0)
throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");
if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)
return 1;
for (int i = 0; i < dimension.length; i++)
if (dimension[i] < 0)
dimension[i] += rank();
int[] tensorShape = ArrayUtil.keep(shape(), dimension);
long len = ArrayUtil.prodLong(tensorShape);
if (len == 0)
throw new IllegalStateException("Illegal length found after removing index");
int length = length();
if (length / len >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Tensors along dimension can not be >= Integer.MAX_VALUE");
return (int) (length / len);
}
@Override
public INDArray tensorAlongDimension(int index, int... dimension) {
if (dimension == null || dimension.length == 0)
throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");
if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)
return this;
for (int i = 0; i < dimension.length; i++)
if (dimension[i] < 0)
dimension[i] += rank();
//dedup
if (dimension.length > 1)
dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension))));
if (dimension.length > 1) {
Arrays.sort(dimension);
}
int tads = tensorssAlongDimension(dimension);
if (index >= tads)
throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);
if (dimension.length == 1) {
if (dimension[0] == 0 && isColumnVector()) {
return this.transpose();
} else if (dimension[0] == 1 && isRowVector()) {
return this;
}
}
Pair tadInfo =
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
DataBuffer shapeInfo = tadInfo.getFirst();
int[] shape = Shape.shape(shapeInfo);
int[] stride = Shape.stride(shapeInfo).asInt();
long offset = offset() + tadInfo.getSecond().getLong(index);
INDArray toTad = Nd4j.create(data(), shape, stride, offset);
BaseNDArray baseNDArray = (BaseNDArray) toTad;
//preserve immutability
char newOrder = Shape.getOrder(shape, stride, 1);
int ews = baseNDArray.shapeInfoDataBuffer().getInt(baseNDArray.shapeInfoDataBuffer().length() - 2);
//TAD always calls permute. Permute EWS is always -1. This is not true
// for row vector shapes though.
if (!Shape.isRowVectorShape(baseNDArray.shapeInfoDataBuffer()))
ews = -1;
// we create new shapeInfo with possibly new ews & order
/**
* NOTE HERE THAT ZERO IS PRESET FOR THE OFFSET AND SHOULD STAY LIKE THAT.
* Zero is preset for caching purposes.
* We don't actually use the offset defined in the
* shape info data buffer.
* We calculate and cache the offsets separately.
*
*/
baseNDArray.setShapeInformation(
Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, ews, newOrder));
return toTad;
}
/**
* Get the vector along a particular dimension
*
* @param index the index of the vector to getScalar
* @param dimension the dimension to getScalar the vector from
* @return the vector along a particular dimension
*/
@Override
public INDArray javaTensorAlongDimension(int index, int... dimension) {
return doTad(index, dimension);
}
private void setShapeInformation(Pair shapeInfo) {
this.shapeInformation = shapeInfo.getFirst();
this.javaShapeInformation = shapeInfo.getSecond();
}
private INDArray doTad(int index, int... dimension) {
if (dimension == null || dimension.length == 0)
throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");
if (dimension.length >= rank())
return this;
for (int i = 0; i < dimension.length; i++)
if (dimension[i] < 0)
dimension[i] += rank();
if (dimension.length > 1)
Arrays.sort(dimension);
int tads = tensorssAlongDimension(dimension);
if (index >= tads)
throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);
if (dimension.length == 1) {
if (dimension[0] == 0 && isColumnVector()) {
return this.transpose();
} else if (dimension[0] == 1 && isRowVector()) {
return this;
}
}
int[] tensorShape = ArrayUtil.keep(shape(), dimension);
int[] reverseDimensions = ArrayUtil.reverseCopy(dimension);
int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension);
int[] newPermuteDims = Ints.concat(remove, reverseDimensions);
int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length];
INDArray permuted = permute(newPermuteDims);
int sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape);
INDArray ret2 = permuted.slice(sliceIdx);
if (dimension.length == tensorShape.length && ArrayUtil.prod(tensorShape) == ret2.length()) {
if (dimension.length == 1 && ret2.isRowVector())
return ret2;
if (finalPermuteDims.length != ret2.rank()) {
finalPermuteDims = new int[ret2.rank()];
int count = 0;
for (int i = finalPermuteDims.length - 1; i >= 0; i--)
finalPermuteDims[count++] = i;
}
return ret2.permutei(finalPermuteDims);
}
int length = ArrayUtil.prod(tensorShape);
int tensorLength = ArrayUtil.prod(tensorShape);
long offset = index * tensorLength / NDArrayMath.lengthPerSlice(ret2);
if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) {
// FIXME: LONG
ret2 = ret2.slice((int) offset);
if (dimension.length == 1 && ret2.isRowVector())
return ret2;
return ret2.permutei(finalPermuteDims);
}
else if (length == NDArrayMath.lengthPerSlice(ret2)) {
offset -= ret2.slices() * (offset / ret2.slices());
// FIXME: LONG
ret2 = ret2.slice((int) offset);
if (dimension.length == 1 && ret2.isRowVector())
return ret2;
return ret2.permutei(finalPermuteDims);
}
while (ret2.length() > length) {
sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape);
sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices());
ret2 = ret2.slice(sliceIdx);
}
if (dimension.length == 1 && ret2.isRowVector())
return ret2;
return ret2.permutei(finalPermuteDims);
}
/**
* Returns the number of possible vectors for a given dimension
*
* @param dimension the dimension to calculate the number of vectors for
* @return the number of possible vectors along a dimension
*/
@Override
public int vectorsAlongDimension(int dimension) {
if (dimension == 0 && isVector() || isRowVector())
return 1;
if (size(dimension) == 1 && !isVector()) {
for (int i = dimension; i < rank(); i++) {
if (size(i) != 1)
return vectorsAlongDimension(i);
}
return length();
} else if (size(0) == 1 && !isVector()) {
int realDimension = rank() - getLeadingOnes();
int length = length();
if (length / size(realDimension) >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
return (int) (length / size(realDimension));
}
int length = length();
if (dimension >= Shape.rank(javaShapeInformation)) {
if (length / size(Shape.rank(javaShapeInformation) - 1) >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
return (int) (length / size(Shape.rank(javaShapeInformation) - 1));
}
if (length / size(dimension) >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
return (int) (length / size(dimension));
}
/**
* Get the vector along a particular dimension
*
* @param index the index of the vector to get
* @param dimension the dimension to get the vector from
* @return the vector along a particular dimension
*/
@Override
public INDArray vectorAlongDimension(int index, int dimension) {
if (dimension < 0)
dimension = Shape.rank(javaShapeInformation) + dimension;
//return the whole thing
if (dimension == Shape.rank(javaShapeInformation) - 1 && size(dimension) == 1 && rank() > 2
|| rank() > 2 && dimension == 0 && size(dimension) == 1) {
return this;
}
INDArray ret = tensorAlongDimension(index, dimension);
if (isMatrix() && ret.isVector() && dimension == 1 && !ret.isRowVector())
return ret.reshape(ArrayUtil.reverseCopy(ret.shape()));
else if (isMatrix() && ret.isVector() && dimension == 0 && !ret.isColumnVector())
return ret.reshape(ArrayUtil.reverseCopy(ret.shape()));
return ret;
}
@Override
public void setOrder(char order) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), 0,
elementWiseStride(), order));
}
@Override
public void setShape(int... shape) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), 0, elementWiseStride(),
ordering()));
}
@Override
public void setStride(int[] stride) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, 0, elementWiseStride(),
ordering()));
}
@Override
public void setShapeAndStride(int[] shape, int[] stride) {
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, -1, ordering()));
}
/**
* Cumulative sum along a dimension
*
* @param dimension the dimension to perform cumulative sum along
* @return the cumulative sum along the specified dimension
*/
@Override
public INDArray cumsumi(int dimension) {
if (isVector()) {
double s = 0.0;
for (int i = 0; i < length(); i++) {
s += getDouble(i);
putScalar(i, s);
}
} else if (dimension == Integer.MAX_VALUE) {
INDArray flattened = ravel();
double prevVal = flattened.getDouble(0);
for (int i = 1; i < flattened.length(); i++) {
double d = prevVal + flattened.getDouble(i);
flattened.putScalar(i, d);
prevVal = d;
}
return flattened;
} else {
for (int i = 0; i < vectorsAlongDimension(dimension); i++) {
INDArray vec = vectorAlongDimension(i, dimension);
vec.cumsumi(0);
}
}
return this;
}
@Override
public Number normmaxNumber() {
return normmax(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber normmaxComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number norm2Number() {
return norm2(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber norm2Complex() {
throw new UnsupportedOperationException();
}
@Override
public Number norm1Number() {
return norm1(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber norm1Complex() {
throw new UnsupportedOperationException();
}
@Override
public Number stdNumber() {
return std(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber stdComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number prodNumber() {
return prod(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber prodComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number meanNumber() {
return mean(Integer.MAX_VALUE).getDouble(0);
}
@Override
public Number ameanNumber() {
return amean(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber meanComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number varNumber() {
return var(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber varComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number maxNumber() {
return max(Integer.MAX_VALUE).getDouble(0);
}
@Override
public Number amaxNumber() {
return amax(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber maxComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number minNumber() {
return min(Integer.MAX_VALUE).getDouble(0);
}
@Override
public Number aminNumber() {
return amin(Integer.MAX_VALUE).getDouble(0);
}
@Override
public Number scan(Condition condition) {
MatchCondition op = new MatchCondition(this, condition);
return Nd4j.getExecutioner().exec(op, Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber minComplex() {
throw new UnsupportedOperationException();
}
@Override
public Number sumNumber() {
return sum(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns entropy value for this INDArray
* @return
*/
@Override
public Number entropyNumber() {
return entropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns non-normalized Shannon entropy value for this INDArray
* @return
*/
@Override
public Number shannonEntropyNumber() {
return shannonEntropy(Integer.MAX_VALUE).getDouble(0);
}
/**
* Returns log entropy value for this INDArray
* @return
*/
@Override
public Number logEntropyNumber() {
return logEntropy(Integer.MAX_VALUE).getDouble(0);
}
@Override
public IComplexNumber sumComplex() {
throw new UnsupportedOperationException();
}
/**
* Cumulative sum along a dimension (in place)
*
* @param dimension the dimension to perform cumulative sum along
* @return the cumulative sum along the specified dimension
*/
@Override
public INDArray cumsum(int dimension) {
return dup().cumsumi(dimension);
}
/**
* Assign all of the elements in the given
* ndarray to this ndarray
*
* @param arr the elements to assign
* @return this
*/
@Override
public INDArray assign(final INDArray arr) {
Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.Set(this, arr, this, length()));
return this;
}
@Override
public INDArray putScalar(int i, double value) {
if (i < 0)
i += rank();
if (isScalar()) {
autoProcessScalarCall();
data.put(i, value);
return this;
}
if (isRowVector()) {
return putScalar(0, i, value);
} else if (isColumnVector()) {
return putScalar(i, 0, value);
}
int[] indexes = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
return putScalar(indexes, value);
}
@Override
public INDArray putScalar(int i, float value) {
return putScalar(i, (double) value);
}
@Override
public INDArray putScalar(int i, int value) {
return putScalar(i, (double) value);
}
@Override
public INDArray putScalar(int[] indexes, double value) {
Nd4j.getCompressor().autoDecompress(this);
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] < 0)
indexes[i] += rank();
}
if (indexes.length == 1) {
return putScalar(indexes[0], value);
} else if (indexes.length == 2) {
return putScalar(indexes[0], indexes[1], value);
} else if (indexes.length == 3) {
return putScalar(indexes[0], indexes[1], indexes[2], value);
} else if (indexes.length == 4) {
return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value);
} else {
autoProcessScalarCall();
long offset = Shape.getOffset(javaShapeInformation, indexes);
data.put(offset, value);
}
return this;
}
@Override
public INDArray putScalar(int row, int col, double value) {
Nd4j.getCompressor().autoDecompress(this);
autoProcessScalarCall();
if (rank() > 2)
throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray");
long offset = Shape.getOffsetUnsafe(javaShapeInformation, row, col);
data.put(offset, value);
return this;
}
@Override
public INDArray putScalar(int dim0, int dim1, int dim2, double value) {
Nd4j.getCompressor().autoDecompress(this);
autoProcessScalarCall();
if (rank() != 3)
throw new IllegalStateException(
"Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray");
long offset = 0; // Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2);
int size_0 = javaShapeInformation[1];
int size_1 = javaShapeInformation[1 + 1];
int size_2 = javaShapeInformation[1 + 2];
if (size_0 != 1)
offset += dim0 * javaShapeInformation[1 + 0 + 3];
if (size_1 != 1)
offset += dim1 * javaShapeInformation[1 + 1 + 3];
if (size_2 != 1)
offset += dim2 * javaShapeInformation[1 + 2 + 3];
data.put(offset, value);
return this;
}
@Override
public INDArray putScalar(int dim0, int dim1, int dim2, int dim3, double value) {
Nd4j.getCompressor().autoDecompress(this);
autoProcessScalarCall();
if (rank() != 4)
throw new IllegalStateException(
"Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray");
long offset = Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2, dim3);
data.put(offset, value);
return this;
}
@Override
public INDArray putScalar(int[] indexes, float value) {
return putScalar(indexes, (double) value);
}
@Override
public INDArray putScalar(int[] indexes, int value) {
return putScalar(indexes, (double) value);
}
/**
* Returns an ndarray with 1 if the element is epsilon equals
*
* @param other the number to compare
* @return a copied ndarray with the given
* binary conditions
*/
@Override
public INDArray eps(Number other) {
return dup().epsi(other);
}
/**
* Returns an ndarray with 1 if the element is epsilon equals
*
* @param other the number to compare
* @return a ndarray with the given
* binary conditions
*/
@Override
public INDArray epsi(Number other) {
INDArray otherArr = Nd4j.valueArrayOf(shape(), other.doubleValue());
return epsi(otherArr);
}
/**
* epsilon equals than comparison:
* If the given number is less than the
* comparison number the item is 0 otherwise 1
*
* @param other the number to compare
* @return
*/
@Override
public INDArray eps(INDArray other) {
return dup().epsi(other);
}
/**
* In place epsilon equals than comparison:
* If the given number is less than the
* comparison number the item is 0 otherwise 1
*
* @param other the number to compare
* @return
*/
@Override
public INDArray epsi(INDArray other) {
Nd4j.getExecutioner().exec(new Eps(this, other, this, length()));
return this;
}
@Override
public INDArray lt(Number other) {
return dup().lti(other);
}
@Override
public INDArray lte(Number other) {
return dup().ltei(other);
}
@Override
public INDArray lti(Number other) {
Nd4j.getExecutioner().exec(new ScalarLessThan(this, other));
return this;
}
@Override
public INDArray ltei(Number other) {
Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, other));
return this;
}
@Override
public INDArray eq(Number other) {
return dup().eqi(other);
}
@Override
public INDArray eqi(Number other) {
Nd4j.getExecutioner().exec(new ScalarEquals(this, other));
return this;
}
@Override
public INDArray gt(Number other) {
return dup().gti(other);
}
@Override
public INDArray gte(Number other) {
return dup().gtei(other);
}
@Override
public INDArray gtei(Number other) {
Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, other));
return this;
}
@Override
public INDArray gti(Number other) {
Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, other));
return this;
}
@Override
public INDArray lt(INDArray other) {
return dup().lti(other);
}
@Override
public INDArray lti(INDArray other) {
Nd4j.getExecutioner().exec(new OldLessThan(this, other, this, length()));
return this;
}
@Override
public INDArray neq(Number other) {
return dup().neqi(other);
}
@Override
public INDArray neqi(Number other) {
Nd4j.getExecutioner().exec(new ScalarNotEquals(this, other));
return this;
}
@Override
public INDArray neq(INDArray other) {
return dup().neqi(other);
}
@Override
public INDArray neqi(INDArray other) {
Nd4j.getExecutioner().exec(new OldNotEqualTo(this, other, this, length()));
return this;
}
@Override
public INDArray eq(INDArray other) {
return dup().eqi(other);
}
@Override
public INDArray eqi(INDArray other) {
Nd4j.getExecutioner().exec(new OldEqualTo(this, other, this, length()));
return this;
}
@Override
public INDArray gt(INDArray other) {
return dup().gti(other);
}
@Override
public INDArray gti(INDArray other) {
Nd4j.getExecutioner().exec(new OldGreaterThan(this, other, this, length()));
return this;
}
/**
* Negate each element.
*/
@Override
public INDArray neg() {
return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.shape(), this.ordering())))
.z();
}
/**
* Negate each element (in-place).
*/
@Override
public INDArray negi() {
Nd4j.getExecutioner().exec(new Negative(this));
return this;
}
@Override
public INDArray rdiv(Number n, INDArray result) {
return rdivi(n, result);
}
@Override
public INDArray rdivi(Number n, INDArray result) {
if (Double.isNaN(n.doubleValue()))
n = Nd4j.EPS_THRESHOLD;
Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, result.length(), n));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
@Override
public INDArray rsub(Number n, INDArray result) {
return rsubi(n, result);
}
@Override
public INDArray rsubi(Number n, INDArray result) {
if (Double.isNaN(n.doubleValue()))
n = Nd4j.EPS_THRESHOLD;
Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, null, result, result.lengthLong(), n));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
@Override
public INDArray div(Number n, INDArray result) {
return divi(n, result);
}
@Override
public INDArray divi(Number n, INDArray result) {
if (Double.isNaN(n.doubleValue()))
n = Nd4j.EPS_THRESHOLD;
Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, result.lengthLong(), n));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
@Override
public INDArray mul(Number n, INDArray result) {
return muli(n, result);
}
@Override
public INDArray muli(Number n, INDArray result) {
if (Double.isNaN(n.doubleValue()))
n = Nd4j.EPS_THRESHOLD;
Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, result.lengthLong(), n));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
@Override
public INDArray sub(Number n, INDArray result) {
return subi(n, result);
}
@Override
public INDArray subi(Number n, INDArray result) {
if (Double.isNaN(n.doubleValue()))
n = Nd4j.EPS_THRESHOLD;
Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, result.lengthLong(), n));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
@Override
public INDArray add(Number n, INDArray result) {
return addi(n, result);
}
@Override
public INDArray addi(Number n, INDArray result) {
if (Double.isNaN(n.doubleValue()))
n = Nd4j.EPS_THRESHOLD;
Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, result.lengthLong(), n));
return result;
}
/**
* Returns the element at the specified row/column
* This will throw an exception if the
*
* @param row the row of the element to return
* @param column the row of the element to return
* @return a scalar indarray of the element at this index
*/
@Override
public INDArray getScalar(int row, int column) {
return getScalar(new int[] {row, column});
}
@Override
public INDArray dup() {
WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray");
if (this.isCompressed() && this.ordering() == Nd4j.order().charValue()) {
INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer());
ret.markAsCompressed(true);
return ret;
}
Nd4j.getCompressor().autoDecompress(this);
INDArray ret = Shape.toOffsetZeroCopy(this);
return ret;
}
@Override
public INDArray dup(char order) {
WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray");
if (this.isCompressed() && this.ordering() == order) {
INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer());
ret.markAsCompressed(true);
return ret;
}
Nd4j.getCompressor().autoDecompress(this);
return Shape.toOffsetZeroCopy(this, order);
}
/**
* Returns the elements at the the specified indices
*
* @param indices the indices to getScalar
* @return the array with the specified elements
*/
@Override
public int getInt(int... indices) {
return (int) getDouble(indices);
}
/**
* Returns the elements at the the specified indices
*
* @param indices the indices to get
* @return the array with the specified elements
*/
@Override
public double getDouble(int... indices) {
autoProcessScalarCall();
Nd4j.getCompressor().autoDecompress(this);
for (int i = 0; i < indices.length; i++) {
if (indices[i] < 0)
indices[i] += rank();
}
if (indices.length == 1) {
if (rank() == 1)
return Shape.getDouble(this, indices[0]);
else if (isRowVector())
return Shape.getDouble(this, 0, indices[0]);
else if (isColumnVector())
return Shape.getDouble(this, indices[0], 0);
else if (isScalar() && indices[0] == 0)
return data().getDouble(0);
else
throw new IllegalStateException("Indexes length must be > 1 for non vectors and scalars");
}
return Shape.getDouble(this, indices);
}
/**
* Returns the elements at the the specified indices
*
* @param indices the indices to get
* @return the array with the specified elements
*/
@Override
public float getFloat(int... indices) {
return (float) getDouble(indices);
}
/**
* Test whether a matrix is scalar.
*/
@Override
public boolean isScalar() {
if (Shape.rank(javaShapeInformation) == 0) {
return true;
} else if (Shape.rank(javaShapeInformation) > 2) {
return false;
} else if (Shape.rank(javaShapeInformation) == 1) {
return shapeOf().getInt(0) == 1;
} else if (Shape.rank(javaShapeInformation) == 2) {
return shapeOf().getInt(0) == 1 && shapeOf().getInt(1) == 1 || length() == 1;
}
else
return false;
}
/**
* Inserts the element at the specified index
*
* @param indices the indices to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this);
if (!element.isScalar())
throw new IllegalArgumentException("Unable to insert anything but a scalar");
if (isRowVector() && indices[0] == 0 && indices.length == 2) {
int ix = Shape.offset(javaShapeInformation);
for (int i = 1; i < indices.length; i++)
ix += indices[i] * stride(i);
if (ix >= data.length())
throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
data.put(ix, element.getDouble(0));
} else {
int ix = Shape.offset(javaShapeInformation);
for (int i = 0; i < indices.length; i++)
if (size(i) != 1)
ix += indices[i] * stride(i);
if (ix >= data.length())
throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
data.put(ix, element.getDouble(0));
}
return this;
}
@Override
public INDArray match(INDArray comp, Condition condition) {
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp,condition)).z();
}
@Override
public INDArray match(Number comp, Condition condition) {
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp.doubleValue(),condition)).z();
}
@Override
public INDArray getWhere(INDArray comp, Condition condition) {
return BooleanIndexing.chooseFrom(new INDArray[]{this,comp},condition);
}
@Override
public INDArray getWhere(Number comp, Condition condition) {
return BooleanIndexing.chooseFrom(new INDArray[]{this},Arrays.asList(comp.doubleValue()),Collections.emptyList(),condition);
}
@Override
public INDArray putWhere(INDArray comp, INDArray put, Condition condition) {
Nd4j.getCompressor().autoDecompress(this);
MatchConditionTransform matchCondition = new MatchConditionTransform(this,comp,condition);
Nd4j.getExecutioner().exec(matchCondition);
return putWhereWithMask(matchCondition.z(),put);
}
@Override
public INDArray putWhere(Number comp, INDArray put, Condition condition) {
return putWhere(Nd4j.scalar(comp),put,condition);
}
@Override
public INDArray putWhere(Number comp, Number put, Condition condition) {
return putWhere(Nd4j.scalar(comp),Nd4j.scalar(put),condition);
}
@Override
public INDArray putWhereWithMask(INDArray mask, INDArray put) {
INDArray output = dup();
Nd4j.getExecutioner().exec(new Where(new INDArray[]{mask,this,put},new INDArray[]{output}));
return output;
}
@Override
public INDArray putWhereWithMask(INDArray mask, Number put) {
return putWhereWithMask(mask,Nd4j.scalar(put));
}
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, int j, INDArray element) {
return put(new int[] {i, j}, element);
}
/**
* Inserts the element at the specified index
*
* @param i the row insert into
* @param j the column to insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, int j, Number element) {
return putScalar(new int[] {i, j}, element.doubleValue());
}
/**
* Assigns the given matrix (put) to the specified slice
*
* @param slice the slice to assign
* @param put the slice to put
* @return this for chainability
*/
@Override
public INDArray putSlice(int slice, INDArray put) {
Nd4j.getCompressor().autoDecompress(this);
if (isScalar()) {
assert put.isScalar() : "Invalid dimension. Can only insert a scalar in to another scalar";
put(0, put.getScalar(0));
return this;
} else if (isVector()) {
assert put.isScalar() || put.isVector() && put
.length() == length() : "Invalid dimension on insertion. Can only insert scalars input vectors";
if (put.isScalar())
putScalar(slice, put.getDouble(0));
else
for (int i = 0; i < length(); i++)
putScalar(i, put.getDouble(i));
return this;
}
assertSlice(put, slice);
INDArray view = slice(slice);
if (put.length() == 1)
putScalar(slice, put.getDouble(0));
else if (put.isVector())
for (int i = 0; i < put.length(); i++)
view.putScalar(i, put.getDouble(i));
else {
assert Shape.shapeEquals(view.shape(), put.shape());
INDArray linear = view;
INDArray putLinearView = put;
for (int i = 0; i < linear.length(); i++) {
linear.putScalar(i, putLinearView.getDouble(i));
}
}
return this;
}
protected void assertSlice(INDArray put, int slice) {
assert slice <= slices() : "Invalid slice specified " + slice;
int[] sliceShape = put.shape();
if (Shape.isRowVectorShape(sliceShape)) {
return;
} else {
int[] requiredShape = ArrayUtil.removeIndex(shape(), 0);
//no need to compare for scalar; primarily due to shapes either being [1] or length 0
if (put.isScalar())
return;
if (isVector() && put.isVector() && put.length() < length())
return;
//edge case for column vectors
if (Shape.isColumnVectorShape(sliceShape))
return;
if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape)
&& !Shape.isRowVectorShape(sliceShape))
throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ",
Arrays.toString(sliceShape), Arrays.toString(requiredShape)));
}
}
/**
* Returns true if this ndarray is 2d
* or 3d with a singleton element
*
* @return true if the element is a matrix, false otherwise
*/
public boolean isMatrix() {
int rank = rank();
return (rank == 2 && (size(0) != 1 && size(1) != 1));
}
@Override
@Deprecated
public int index(int row, int column) {
if (!isMatrix()) {
if (isColumnVector()) {
int idx = linearIndex(row);
return idx;
} else if (isRowVector()) {
int idx = linearIndex(column);
return idx;
} else
throw new IllegalStateException("Unable to get row/column from a non matrix");
}
return Shape.offset(javaShapeInformation) + (row * stride(0) + column * stride(1));
}
protected INDArray newShape(int[] newShape, char ordering) {
return create(data(), newShape, stride(), Shape.offset(javaShapeInformation));
}
protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset, char ordering) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(data, newShape, newStrides, offset, ordering);
else
return Nd4j.create(data, newShape, newStrides, offset, ordering);
}
protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(data, newShape, newStrides, offset);
else
return Nd4j.create(data, newShape, newStrides, offset);
}
protected INDArray create(int[] shape) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(shape, getStrides(shape, Nd4j.order()), 0);
else
return Nd4j.create(shape, getStrides(shape, Nd4j.order()), 0);
}
protected INDArray create(int[] shape, int[] strides, long offset) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(shape, strides, offset);
else
return Nd4j.create(shape, strides, offset);
}
protected int[] getStrides(int[] shape, char ordering) {
return Nd4j.getStrides(shape, ordering);
}
/**
* Returns the square of the Euclidean distance.
*/
@Override
public double squaredDistance(INDArray other) {
double d2 = distance2(other);
return d2 * d2;
}
/**
* Returns the (euclidean) distance.
*/
@Override
public double distance2(INDArray other) {
Nd4j.getCompressor().autoDecompress(this);
return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue();
}
/**
* Returns the (1-norm) distance.
*/
@Override
public double distance1(INDArray other) {
Nd4j.getCompressor().autoDecompress(this);
return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue();
}
@Override
public INDArray get(INDArray indices) {
if(indices.rank() > 2) {
throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
}
if(indices.rows() == rank()) {
INDArray ret = Nd4j.create(indices.columns());
for(int i = 0; i < indices.columns(); i++) {
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
ret.putScalar(i,getDouble(specifiedIndex));
}
return ret;
}
else {
List arrList = new ArrayList<>();
if(indices.isMatrix() || indices.isColumnVector()
|| (indices.isScalar() && indices.rank() == 2)) { // we need this for compatibility with legacy code
for(int i = 0; i < indices.rows(); i++) {
if(i == 0) {
INDArray row = indices.getRow(i);
for(int j = 0; j < row.length(); j++) {
arrList.add(slice(row.getInt(j)));
}
}
else {
INDArray row = indices.slice(i);
for(int j = 0; j < row.length(); j++) {
INDArray put = arrList.get(j).slice(row.getInt(j));
put = put.reshape(Ints.concat(new int[]{1},put.shape()));
arrList.set(j,put);
}
}
}
}
else if(indices.isRowVector()) {
for(int i = 0; i < indices.length(); i++) {
INDArray add = slice(indices.getInt(i));
add = add.reshape(Ints.concat(new int[] {1,},add.shape()));
arrList.add(add);
}
}
return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()]));
}
}
@Override
public INDArray get(List> indices) {
INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
for(int i = 0; i < indArrayIndices.length; i++) {
indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
}
boolean hasNext = true;
Generator>> iterate = SpecifiedIndex.iterate(indArrayIndices);
List resultList = new ArrayList<>();
while(hasNext) {
try {
List> next = iterate.next();
int[][] nextArr = new int[next.size()][];
for(int i = 0; i < next.size(); i++) {
nextArr[i] = Ints.toArray(next.get(i));
}
int[] curr = Ints.concat(nextArr);
INDArray currSlice = this;
for(int j = 0; j < curr.length; j++) {
currSlice = currSlice.slice(curr[j]);
}
//slice drops the first dimension, this adds a 1 to match normal numpy behavior
currSlice = currSlice.reshape(Ints.concat(new int[]{1},currSlice.shape()));
resultList.add(currSlice);
}
catch(NoSuchElementException e) {
hasNext = false;
}
}
return Nd4j.concat(0,resultList.toArray(new INDArray[resultList.size()]));
}
@Override
public INDArray put(List> indices, INDArray element) {
INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
for(int i = 0; i < indArrayIndices.length; i++) {
indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
}
boolean hasNext = true;
Generator>> iterate = SpecifiedIndex.iterate(indArrayIndices);
if(indices.size() == rank()) {
NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
while(hasNext) {
try {
List> next = iterate.next();
int[][] nextArr = new int[next.size()][];
for(int i = 0; i < next.size(); i++) {
nextArr[i] = Ints.toArray(next.get(i));
}
int[] curr = Ints.concat(nextArr);
putScalar(curr,element.getDouble(ndIndexIterator.next()));
}
catch(NoSuchElementException e) {
hasNext = false;
}
}
}
else {
if(indices.size() >= 2) {
while(hasNext) {
try {
List> next = iterate.next();
int[][] nextArr = new int[next.size()][];
for(int i = 0; i < next.size(); i++) {
nextArr[i] = Ints.toArray(next.get(i));
}
int[] curr = Ints.concat(nextArr);
INDArray currSlice = this;
for(int j = 0; j < curr.length; j++) {
currSlice = currSlice.slice(curr[j]);
}
Nd4j.getExecutioner().exec(new Assign(new INDArray[]{currSlice,element},new INDArray[]{currSlice}));
}
catch(NoSuchElementException e) {
hasNext = false;
}
}
}
}
return this;
}
@Override
public INDArray put(INDArray indices, INDArray element) {
if(indices.rank() > 2) {
throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
}
if(indices.rows() == rank()) {
NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
for(int i = 0; i < indices.columns(); i++) {
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next()));
}
}
else {
List arrList = new ArrayList<>();
if(indices.isMatrix() || indices.isColumnVector()) {
for(int i = 0; i < indices.rows(); i++) {
INDArray row = indices.getRow(i);
for(int j = 0; j < row.length(); j++) {
INDArray slice = slice(row.getInt(j));
Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice,element},new INDArray[]{slice}));
arrList.add(slice(row.getInt(j)));
}
}
}
else if(indices.isRowVector()) {
for(int i = 0; i < indices.length(); i++) {
arrList.add(slice(indices.getInt(i)));
}
}
}
return this;
}
@Override
public INDArray put(INDArrayIndex[] indices, INDArray element) {
Nd4j.getCompressor().autoDecompress(this);
if (indices[0] instanceof SpecifiedIndex && element.isVector()) {
indices[0].reset();
int cnt = 0;
while (indices[0].hasNext()) {
long idx = indices[0].next();
// FIXME: LONG
putScalar((int) idx, element.getDouble(cnt));
cnt++;
}
return this;
} else {
return get(indices).assign(element);
}
}
@Override
public INDArray put(INDArrayIndex[] indices, Number element) {
Nd4j.getCompressor().autoDecompress(this);
INDArray get = get(indices);
for (int i = 0; i < get.length(); i++)
get.putScalar(i, element.doubleValue());
return this;
}
/**
* Mainly here for people coming from numpy.
* This is equivalent to a call to permute
*
* @param dimension the dimension to swap
* @param with the one to swap it with
* @return the swapped axes view
*/
@Override
public INDArray swapAxes(int dimension, int with) {
int[] shape = ArrayUtil.range(0, shape().length);
shape[dimension] = with;
shape[with] = dimension;
return permute(shape);
}
@Override
public boolean isView() {
/*
We don't really use Shape offset value anywhere
And it's possible to be not a view, and have non-empty originalBuffer
*/
// length/data.length can be different in case of Threshold conversion
return Shape.offset(javaShapeInformation) > 0
|| (length() < data().length() && data.dataType() != DataBuffer.Type.INT)
|| data().originalDataBuffer() != null;
}
@Override
public boolean isSparse() {
return false;
}
@Override
public DataBuffer data() {
return data;
}
@Override
public void setData(DataBuffer data) {
this.data = data;
}
/**
* Number of slices: aka shape[0]
*
* @return the number of slices
* for this nd array
*/
@Override
public int slices() {
if (isRowVector())
return length();
return size(0);
}
@Override
public INDArray subArray(ShapeOffsetResolution resolution) {
Nd4j.getCompressor().autoDecompress(this);
long[] offsets = resolution.getOffsets();
int[] shape = LongUtils.toInts(resolution.getShapes());
int[] stride = LongUtils.toInts(resolution.getStrides());
// if (offset() + resolution.getOffset() >= Integer.MAX_VALUE)
// throw new IllegalArgumentException("Offset of array can not be >= Integer.MAX_VALUE");
long offset = (offset() + resolution.getOffset());
int n = shape.length;
// FIXME: shapeInfo should be used here
if (shape.length < 1)
return create(Nd4j.createBufferDetached(shape));
if (offsets.length != n)
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
if (stride.length != n)
throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride));
if (shape.length == rank() && Shape.contentEquals(shape, shapeOf())) {
if (ArrayUtil.isZero(offsets)) {
return this;
} else {
throw new IllegalArgumentException("Invalid subArray offsets");
}
}
char newOrder = Shape.getOrder(shape, stride, 1);
return create(data, Arrays.copyOf(shape, shape.length), stride, offset, newOrder);
}
@Override
public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
Nd4j.getCompressor().autoDecompress(this);
int n = shape.length;
// FIXME: shapeInfo should be used here
if (shape.length < 1)
return create(Nd4j.createBufferDetached(shape));
if (offsets.length != n)
throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets));
if (stride.length != n)
throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride));
if (Shape.contentEquals(shape, shapeOf())) {
if (ArrayUtil.isZero(offsets)) {
return this;
} else {
throw new IllegalArgumentException("Invalid subArray offsets");
}
}
long[] dotProductOffsets = offsets;
int[] dotProductStride = stride;
long offset = Shape.offset(javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
if (offset >= data().length())
offset = ArrayUtil.sumLong(offsets);
return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering());
}
protected INDArray create(DataBuffer buffer) {
return Nd4j.create(buffer);
}
@Override
public INDArray cond(Condition condition) {
return dup().condi(condition);
}
@Override
public INDArray condi(Condition condition) {
Nd4j.getCompressor().autoDecompress(this);
INDArray linear = this;
for (int i = 0; i < length(); i++) {
boolean met = condition.apply(linear.getDouble(i));
linear.putScalar(i, met ? 1 : 0);
}
return this;
}
protected void init(int[] shape, int[] stride) {
//default row vector
if (shape.length == 1) {
init(new int[] {1, shape[0]}, new int[] {1, stride[0]});
}
//null character
if (ordering() == '\u0000') {
Shape.setOrder(shapeInfo(), Nd4j.order());
throw new IllegalStateException("setOrder() shouldn't ever happen here");
}
}
@Override
public INDArray getScalar(int i) {
return Nd4j.scalar(getDouble(i));
}
/**
* Do a row wise op (a,s,m,d)
* a : add
* s : subtract
* m : multiply
* d : divide
* h : reverse subtraction
* t : reverse division
*
* @param columnVector the column vector
* @param operation the operation
* @return
*/
protected INDArray doColumnWise(INDArray columnVector, char operation) {
Nd4j.getCompressor().autoDecompress(this);
if(columnVector.isScalar()) {
switch (operation) {
case 'a':
addi(columnVector.getDouble(0));
break;
case 'p':
assign(columnVector.getDouble(0));
break;
case 's':
subi(columnVector.getDouble(0));
break;
case 'm':
muli(columnVector.getDouble(0));
break;
case 'd':
divi(columnVector.getDouble(0));
break;
case 'h':
rsubi(columnVector.getDouble(0));
break;
case 't':
rdivi(columnVector.getDouble(0));
break;
}
return this;
}
//Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0)
if (!columnVector.isColumnVector() || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) {
throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape())
+ ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")");
}
if (columnVector.data().sameUnderlyingData(data()))
return doColumnWise(columnVector.dup(), operation);
if (isVector()) {
switch (operation) {
case 'a':
addi(columnVector);
break;
case 'p':
assign(columnVector);
break;
case 's':
subi(columnVector);
break;
case 'm':
muli(columnVector);
break;
case 'd':
divi(columnVector);
break;
case 'h':
rsubi(columnVector);
break;
case 't':
rdivi(columnVector);
break;
}
return this;
}
if (rows() == 1 && columnVector.isScalar()) {
applyScalarOp(columnVector, operation);
} else {
// special optimization case, broadcast turns into ScalarOp Along Dimension
if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' && columnVector.elementWiseStride() == 1) {
switch (operation) {
case 'a': {
ScalarAdd op = new ScalarAdd(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
case 'p': {
ScalarSet op = new ScalarSet(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
case 's': {
ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
case 'm': {
ScalarMultiplication op =
new ScalarMultiplication(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
case 'd': {
ScalarDivision op = new ScalarDivision(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
case 'h': {
ScalarReverseSubtraction op =
new ScalarReverseSubtraction(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
case 't': {
ScalarReverseDivision op =
new ScalarReverseDivision(this, columnVector, this, this.length(), 0.0);
op.setDimension(1);
Nd4j.getExecutioner().exec(op);
break;
}
}
} else {
applyBroadcastOp(columnVector, operation);
}
}
return this;
}
@Override
@Deprecated
public boolean isCleanedUp() {
return false;
}
@Override
@Deprecated
public void cleanup() {
if (Nd4j.shouldInstrument)
Nd4j.getInstrumentation().log(this, Instrumentation.DESTROYED);
}
/**
* Do a row wise op (a,s,m,d)
* a : add
* s : subtract
* m : multiply
* d : divide
* h : reverse subtraction
* t : reverse division
*
* @param rowVector the row vector
* @param operation the operation
* @return
*/
protected INDArray doRowWise(INDArray rowVector, final char operation) {
Nd4j.getCompressor().autoDecompress(this);
if(rowVector.isScalar()) {
switch (operation) {
case 'a':
addi(rowVector.getDouble(0));
break;
case 'p':
assign(rowVector.getDouble(0));
break;
case 's':
subi(rowVector.getDouble(0));
break;
case 'm':
muli(rowVector.getDouble(0));
break;
case 'd':
divi(rowVector.getDouble(0));
break;
case 'h':
rsubi(rowVector.getDouble(0));
break;
case 't':
rdivi(rowVector.getDouble(0));
break;
}
return this;
}
//Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1)
if (!rowVector.isRowVector() || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) || rowVector.length() <= 1) {
throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape())
+ ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")");
}
if (rowVector.data().sameUnderlyingData(data()))
return doRowWise(rowVector.dup(), operation);
if (isVector()) {
switch (operation) {
case 'a':
addi(rowVector);
break;
case 'p':
assign(rowVector);
break;
case 's':
subi(rowVector);
break;
case 'm':
muli(rowVector);
break;
case 'd':
divi(rowVector);
break;
case 'h':
rsubi(rowVector);
break;
case 't':
rdivi(rowVector);
break;
}
return this;
}
if (rank() == 2 && columns() == 1 && rowVector.isScalar()) {
if (this instanceof IComplexNDArray) {
applyScalarOp(rowVector, operation);
}
} else {
// special optimization case, broadcast turns into ScalarOp Along Dimension
if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'f' && rowVector.elementWiseStride() == 1) {
switch (operation) {
case 'a': {
ScalarAdd op = new ScalarAdd(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
case 'p': {
ScalarSet op = new ScalarSet(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
case 's': {
ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
case 'm': {
ScalarMultiplication op = new ScalarMultiplication(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
case 'd': {
ScalarDivision op = new ScalarDivision(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
case 'h': {
ScalarReverseSubtraction op =
new ScalarReverseSubtraction(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
case 't': {
ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, this.length(), 0.0);
op.setDimension(0);
Nd4j.getExecutioner().exec(op);
break;
}
}
} else {
applyBroadcastOp(rowVector, operation);
}
}
return this;
}
private void applyBroadcastOp(INDArray vector, final char operation) {
Nd4j.getCompressor().autoDecompress(this);
int alongDimension = Shape.isRowVectorShape(vector.shape()) ? 1 : 0;
// FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics
if (this.data() == vector.data())
vector = vector.dup();
switch (operation) {
case 'a':
Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension), alongDimension);
return;
case 's':
Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension), alongDimension);
return;
case 'm':
Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension), alongDimension);
return;
case 'd':
Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension), alongDimension);
return;
case 'h':
Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension), alongDimension);
return;
case 't':
Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension), alongDimension);
return;
case 'p':
Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension), alongDimension);
return;
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
}
private void applyScalarOp(INDArray vector, char operation) {
Nd4j.getCompressor().autoDecompress(this);
if (this instanceof IComplexNDArray) {
IComplexNDArray row = (IComplexNDArray) vector;
switch (operation) {
case 'a':
addi(row.getComplex(0));
break;
case 's':
subi(row.getComplex(0));
break;
case 'm':
muli(row.getComplex(0));
break;
case 'd':
divi(row.getComplex(0));
break;
case 'h':
rsubi(row.getComplex(0));
break;
case 't':
rdivi(row.getComplex(0));
break;
}
} else {
switch (operation) {
case 'a':
addi(vector.getDouble(0));
break;
case 's':
subi(vector.getDouble(0));
break;
case 'm':
muli(vector.getDouble(0));
break;
case 'd':
divi(vector.getDouble(0));
break;
case 'h':
rsubi(vector.getDouble(0));
break;
case 't':
rdivi(vector.getDouble(0));
break;
}
}
}
protected DataBuffer shapeOf() {
// if (shape == null)
// shape = Shape.shapeOf(shapeInfoDataBuffer());
// return shape;
return Shape.shapeOf(shapeInfoDataBuffer());
}
protected DataBuffer strideOf() {
// if (stride == null)
// stride = Shape.stride(shapeInfoDataBuffer());
// return stride;
return Shape.stride(shapeInfoDataBuffer());
}
@Override
public int stride(int dimension) {
int rank = Shape.rank(javaShapeInformation);
if (dimension < 0)
return strideOf().getInt(dimension + rank);
return strideOf().getInt(dimension);
}
@Override
public INDArray rdiviColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 't');
}
@Override
public INDArray rdivColumnVector(INDArray columnVector) {
return dup().rdiviColumnVector(columnVector);
}
@Override
public INDArray rdiviRowVector(INDArray rowVector) {
return doRowWise(rowVector, 't');
}
@Override
public INDArray rdivRowVector(INDArray rowVector) {
return dup().rdiviRowVector(rowVector);
}
@Override
public INDArray rsubiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 'h');
}
@Override
public INDArray rsubColumnVector(INDArray columnVector) {
return dup().rsubiColumnVector(columnVector);
}
@Override
public INDArray rsubiRowVector(INDArray rowVector) {
return doRowWise(rowVector, 'h');
}
@Override
public INDArray rsubRowVector(INDArray rowVector) {
return dup().rsubiRowVector(rowVector);
}
/**
* Inserts the element at the specified index
*
* @param i the index insert into
* @param element a scalar ndarray
* @return a scalar ndarray of the element at this index
*/
@Override
public INDArray put(int i, INDArray element) {
if (!element.isScalar())
throw new IllegalArgumentException("Element must be a scalar");
return putScalar(i, element.getDouble(0));
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray diviColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 'd');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray divColumnVector(INDArray columnVector) {
return dup().diviColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray diviRowVector(INDArray rowVector) {
return doRowWise(rowVector, 'd');
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray divRowVector(INDArray rowVector) {
return dup().diviRowVector(rowVector);
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray muliColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 'm');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray mulColumnVector(INDArray columnVector) {
return dup().muliColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray muliRowVector(INDArray rowVector) {
return doRowWise(rowVector, 'm');
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray mulRowVector(INDArray rowVector) {
return dup().muliRowVector(rowVector);
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray subiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 's');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray subColumnVector(INDArray columnVector) {
return dup().subiColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray subiRowVector(INDArray rowVector) {
return doRowWise(rowVector, 's');
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray subRowVector(INDArray rowVector) {
return dup().subiRowVector(rowVector);
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray addiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 'a');
}
@Override
public INDArray putiColumnVector(INDArray columnVector) {
return doColumnWise(columnVector, 'p');
}
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override
public INDArray addColumnVector(INDArray columnVector) {
return dup().addiColumnVector(columnVector);
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray addiRowVector(INDArray rowVector) {
return doRowWise(rowVector, 'a');
}
@Override
public INDArray putiRowVector(INDArray rowVector) {
return doRowWise(rowVector, 'p');
}
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override
public INDArray addRowVector(INDArray rowVector) {
return dup().addiRowVector(rowVector);
}
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmul(INDArray other, INDArray result,MMulTranspose mMulTranspose) {
MMulTranspose mMulTranspose1 = MMulTranspose.builder()
.a(this)
.b(other)
.transposeA(mMulTranspose.isTransposeA())
.transposeB(mMulTranspose.isTransposeB())
.transposeResult(mMulTranspose.isTransposeResult())
.build();
return mMulTranspose1.getA().mmul(mMulTranspose1.getB(),result);
}
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
MMulTranspose mMulTranspose1 = MMulTranspose.builder()
.a(this)
.b(other)
.transposeA(mMulTranspose.isTransposeA())
.transposeB(mMulTranspose.isTransposeB())
.transposeResult(mMulTranspose.isTransposeResult())
.build();
return mMulTranspose1.getA().mmul(mMulTranspose1.getB());
}
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmul(INDArray other) {
// FIXME: for 1D case, we probably want vector output here?
int[] shape = {rows(), other.rank() == 1 ? 1 : other.columns()};
INDArray result = createUninitialized(shape, 'f');
if (result.isScalar())
return Nd4j.scalar(Nd4j.getBlasWrapper().dot(this, other));
return mmuli(other, result);
}
protected INDArray create(int[] shape, char ordering) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(shape, ordering);
else
return Nd4j.create(shape, ordering);
}
@Override
public double[][] toDoubleMatrix() {
if(!isMatrix()) {
throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix!");
}
double[][] ret = new double[rows()][columns()];
for(int i = 0; i < ret.length; i++) {
ret[i] = getRow(i).dup().data().asDouble();
}
return ret;
}
@Override
public double[] toDoubleVector() {
if(!isVector()) {
throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
return dup().data().asDouble();
}
@Override
public float[] toFloatVector() {
if(!isVector()) {
throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
return dup().data().asFloat();
}
@Override
public float[][] toFloatMatrix() {
if(!isMatrix()) {
throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix!");
}
float[][] ret = new float[rows()][columns()];
for(int i = 0; i < ret.length; i++) {
ret[i] = getRow(i).dup().data().asFloat();
}
return ret;
}
@Override
public int[] toIntVector() {
if(!isVector()) {
throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector!");
}
return dup().data().asInt();
}
@Override
public int[][] toIntMatrix() {
if(!isMatrix()) {
throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix!");
}
int[][] ret = new int[rows()][columns()];
for(int i = 0; i < ret.length; i++) {
ret[i] = getRow(i).dup().data().asInt();
}
return ret;
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmul(INDArray other, INDArray result) {
return mmuli(other, result);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override
public INDArray div(INDArray other) {
return divi(other, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
/**
* copy (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override
public INDArray div(INDArray other, INDArray result) {
return divi(other, result);
}
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the addition
*/
@Override
public INDArray mul(INDArray other) {
return muli(other, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
/**
* copy (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override
public INDArray mul(INDArray other, INDArray result) {
return muli(other, result);
}
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override
public INDArray sub(INDArray other) {
return subi(other, Nd4j.createUninitialized(other.shape(), other.ordering()));
}
/**
* copy subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override
public INDArray sub(INDArray other, INDArray result) {
return subi(other, result);
}
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override
public INDArray add(INDArray other) {
return dup().addi(other);
}
/**
* copy addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override
public INDArray add(INDArray other, INDArray result) {
return dup().addi(other, result);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param transpose the transpose status of each ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, MMulTranspose transpose) {
return dup().mmuli(other, this,transpose);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other) {
return dup().mmuli(other, this);
}
/**
* Perform an in place matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
MMulTranspose mMulTranspose = MMulTranspose.builder()
.a(this)
.b(other)
.transposeA(transpose.isTransposeA())
.transposeB(transpose.isTransposeB())
.transposeResult(transpose.isTransposeResult())
.build();
return mMulTranspose.getA().mmuli(mMulTranspose.getB(),result);
}
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @return the result of the matrix multiplication
*/
@Override
public INDArray mmuli(INDArray other, INDArray result) {
LinAlgExceptions.assertMultiplies(this, other);
if (other.isScalar()) {
return muli(other.getDouble(0), result);
}
if (isScalar()) {
return other.muli(getDouble(0), result);
}
/* check sizes and resize if necessary */
if (result == this || result == other) {
/* actually, blas cannot do multiplications in-place. Therefore, we will fake by
* allocating a temporary object on the side and copy the result later.
*/
INDArray temp = create(result.shape(), Nd4j.getStrides(result.shape(), 'f'));
temp.setOrder('f');
if (other.columns() == 1 || other.rank() == 1) {
Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result),
BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp);
}
else {
Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result),
BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0,
this, other, 0.0, temp);
Nd4j.getBlasWrapper().level3().gemm(
BlasBufferUtil.getCharForTranspose(result),
BlasBufferUtil.getCharForTranspose(this),
BlasBufferUtil.getCharForTranspose(temp),
1.0,
this,
other,
0.0,
temp);
}
result.assign(temp);
} else {
//We require that the result array is 'f' (fortran) order
// However, user might have called mmuli with a c order array for the result
// In which case, we need to allocate a temporary f order array, and later do an assign to the real result array
boolean requiresTemp = result.ordering() == 'c';
INDArray gemmResultArr;
if (requiresTemp) {
//Can use createUninitialized due to beta==0.0 parameter in gemm
gemmResultArr = Nd4j.createUninitialized(result.shape(), 'f');
} else {
gemmResultArr = result;
}
if (other.columns() == 1 || other.rank() == 1) {
Nd4j.getBlasWrapper().level2().gemv(
ordering(),
BlasBufferUtil.getCharForTranspose(other),
1.0,
this,
other,
0.0,
gemmResultArr);
} else {
//gemm doesn't support strides so vectors and views
//don't work
Nd4j.getBlasWrapper().level3().gemm(ordering(),
BlasBufferUtil.getCharForTranspose(other),
BlasBufferUtil.getCharForTranspose(gemmResultArr),
1.0,
this,
other,
0.0,
gemmResultArr);
}
if (requiresTemp) {
result.assign(gemmResultArr);
}
}
// 1D edge case: reshape back to vector
if (other.rank() == 1)
result = result.reshape(result.length());
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
private INDArray create(int[] shape, int[] stride) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(shape, stride);
else
return Nd4j.create(shape, stride);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @return the result of the divide
*/
@Override
public INDArray divi(INDArray other) {
return divi(other, this);
}
/**
* in place (element wise) division of two matrices
*
* @param other the second ndarray to divide
* @param result the result ndarray
* @return the result of the divide
*/
@Override
public INDArray divi(INDArray other, INDArray result) {
if (other.isScalar()) {
return divi(other.getDouble(0), result);
}
if (isScalar()) {
return other.rdivi(getDouble(0), result);
}
if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
Nd4j.getExecutioner().exec(new BroadcastDivOp(this,other,result,broadcastDimensions),broadcastDimensions);
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldDivOp(this, other, result, length()));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @return the result of the multiplication
*/
@Override
public INDArray muli(INDArray other) {
return muli(other, this);
}
/**
* in place (element wise) multiplication of two matrices
*
* @param other the second ndarray to multiply
* @param result the result ndarray
* @return the result of the multiplication
*/
@Override
public INDArray muli(INDArray other, INDArray result) {
if (other.isScalar()) {
return muli(other.getDouble(0), result);
}
if (isScalar()) {
return other.muli(getDouble(0), result);
}
if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
Nd4j.getExecutioner().exec(new BroadcastMulOp(this,other,result,broadcastDimensions),broadcastDimensions);
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldMulOp(this, other, result, length()));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @return the result of the addition
*/
@Override
public INDArray subi(INDArray other) {
return subi(other, this);
}
/**
* in place subtraction of two matrices
*
* @param other the second ndarray to subtract
* @param result the result ndarray
* @return the result of the subtraction
*/
@Override
public INDArray subi(INDArray other, INDArray result) {
if (other.isScalar()) {
return subi(other.getDouble(0), result);
}
if (isScalar()) {
return other.rsubi(getDouble(0), result);
}
if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
Nd4j.getExecutioner().exec(new BroadcastSubOp(this,other,result,broadcastDimensions),broadcastDimensions);
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldSubOp(this, other,result));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @return the result of the addition
*/
@Override
public INDArray addi(INDArray other) {
return addi(other, this);
}
/**
* in place addition of two matrices
*
* @param other the second ndarray to add
* @param result the result ndarray
* @return the result of the addition
*/
@Override
public INDArray addi(INDArray other, INDArray result) {
if (other.isScalar()) {
return result.addi(other.getDouble(0), result);
}
if (isScalar()) {
return other.addi(getDouble(0), result);
}
if(!Shape.shapeEquals(this.shape(),other.shape())) {
int[] broadcastDimensions = Shape.getBroadcastDimensions(this.shape(),other.shape());
result = Nd4j.createUninitialized(Shape.broadcastOutputShape(this.shape(),other.shape()));
Nd4j.getExecutioner().exec(new BroadcastAddOp(this,other,result,broadcastDimensions),broadcastDimensions);
return result;
}
LinAlgExceptions.assertSameShape(other, result);
Nd4j.getExecutioner().exec(new OldAddOp(this, other, result, length()));
if (Nd4j.ENFORCE_NUMERICAL_STABILITY)
Nd4j.clearNans(result);
return result;
}
/**
* Returns the normmax along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override
public INDArray normmax(int... dimension) {
return Nd4j.getExecutioner().exec(new NormMax(this), dimension);
}
/**
* Reverse division
*
* @param other the matrix to divide from
* @return
*/
@Override
public INDArray rdiv(INDArray other) {
return dup().rdivi(other);
}
/**
* Reverse divsion (in place)
*
* @param other
* @return
*/
@Override
public INDArray rdivi(INDArray other) {
return rdivi(other, this);
}
/**
* Reverse division
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override
public INDArray rdiv(INDArray other, INDArray result) {
return dup().rdivi(other, result);
}
/**
* Reverse division (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override
public INDArray rdivi(INDArray other, INDArray result) {
return other.divi(this, result);
}
/**
* Reverse subtraction
*
* @param other the matrix to subtract from
* @param result the result ndarray
* @return
*/
@Override
public INDArray rsub(INDArray other, INDArray result) {
return dup().rsubi(other, result);
}
/**
* @param other
* @return
*/
@Override
public INDArray rsub(INDArray other) {
return dup().rsubi(other);
}
/**
* @param other
* @return
*/
@Override
public INDArray rsubi(INDArray other) {
return rsubi(other, this);
}
/**
* Reverse subtraction (in-place)
*
* @param other the other ndarray to subtract
* @param result the result ndarray
* @return the ndarray with the operation applied
*/
@Override
public INDArray rsubi(INDArray other, INDArray result) {
return other.subi(this, result);
}
/**
* Set the value of the ndarray to the specified value
*
* @param value the value to assign
* @return the ndarray with the values
*/
@Override
public INDArray assign(Number value) {
Nd4j.getExecutioner().exec(new ScalarSet(this, value));
return this;
}
/**
* Assign all elements from given ndarray that are matching given condition,
* ndarray to this ndarray
*
* @param arr the elements to assign
* @param condition
* @return this
*/
@Override
public INDArray assignIf(INDArray arr, Condition condition) {
BooleanIndexing.assignIf(this, arr, condition);
return this;
}
/**
* Replaces all elements in this ndarray that are matching give condition, with corresponding elements from given array
*
* @param arr
* @param condition
* @return
*/
@Override
public INDArray replaceWhere(INDArray arr, Condition condition) {
Nd4j.getCompressor().autoDecompress(this);
BooleanIndexing.replaceWhere(this, arr, condition);
return this;
}
@Override
@Deprecated
public int linearIndex(int i) {
int idx = i;
for (int j = 0; j < Shape.rank(javaShapeInformation) - 1; j++) {
if (size(i) == 1)
continue;
idx += i * stride(j);
}
return Shape.offset(javaShapeInformation) + (idx);
}
/**
* Returns the specified slice of this matrix.
* In matlab, this would be equivalent to (given a 2 x 2 x 2):
* A(:,:,x) where x is the slice you want to return.
*
* The slice is always relative to the final dimension of the matrix.
*
* @param slice the slice to return
* @return the specified slice of this matrix
*/
@Override
public INDArray slice(int slice) {
Nd4j.getCompressor().autoDecompress(this);
int slices = slices();
if (slice >= slices)
throw new IllegalArgumentException("Illegal slice " + slice);
if (Shape.rank(javaShapeInformation) == 0 || isVector()) {
if (slice == 0 || isVector()) {
return createScalarForIndex(slice, true);
}
else {
throw new IllegalArgumentException("Can't slice a 0-d NDArray");
}
}
if (slice < 0)
slice += rank();
INDArrayIndex[] indexes = new INDArrayIndex[rank()];
indexes[0] = NDArrayIndex.point(slice);
for (int i = 1; i < rank(); i++) {
indexes[i] = NDArrayIndex.all();
}
return get(indexes);
}
protected INDArray createScalarForIndex(int i, boolean applyOffset) {
if(isVector())
return getScalar(i);
return create(data(), new int[] {1, 1}, new int[] {1, 1}, i);
}
protected INDArray createScalar(double d) {
return Nd4j.scalar(d);
}
@Override
public int getTrailingOnes() {
int numLeadingOnes = 0;
for (int i = rank() - 1; i > 0; i--) {
if (size(i) == 1)
numLeadingOnes++;
}
return numLeadingOnes;
}
@Override
public int getLeadingOnes() {
int numLeadingOnes = 0;
for (int i = 0; i < rank(); i++) {
if (size(i) == 1)
numLeadingOnes++;
}
return numLeadingOnes;
}
/**
* Returns the slice of this from the specified dimension
*
* @param slice the dimension to return from
* @param dimension the dimension of the slice to return
* @return the slice of this matrix from the specified dimension
* and dimension
*/
@Override
public INDArray slice(int slice, int dimension) {
Nd4j.getCompressor().autoDecompress(this);
int slices = size(dimension);
if (slice >= slices)
throw new IllegalArgumentException("Illegal slice " + slice);
if (Shape.rank(javaShapeInformation) == 0) {
if (slice == 0)
return createScalarForIndex(slice, true);
else
throw new IllegalArgumentException("Can't slice a 0-d NDArray");
}
if (slice < 0)
slice += rank();
INDArrayIndex[] indexes = new INDArrayIndex[rank()];
indexes[dimension] = NDArrayIndex.point(slice);
for (int i = 0; i < rank(); i++) {
if (i != dimension)
indexes[i] = NDArrayIndex.all();
}
return get(indexes);
}
/**
* Fetch a particular number on a multi dimensional scale.
*
* @param indexes the indexes to get a number from
* @return the number at the specified indices
*/
@Override
public INDArray getScalar(int... indexes) {
return Nd4j.scalar(getDouble(indexes));
}
@Override
public INDArray rdiv(Number n) {
//return dup().rdivi(n);
return rdivi(n, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
@Override
public INDArray rdivi(Number n) {
return rdivi(n, this);
}
@Override
public INDArray rsub(Number n) {
//return dup().rsubi(n);
return rsubi(n, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
@Override
public INDArray rsubi(Number n) {
return rsubi(n, this);
}
@Override
public INDArray div(Number n) {
//return dup().divi(n);
return divi(n, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
@Override
public INDArray divi(Number n) {
return divi(n, this);
}
@Override
public INDArray mul(Number n) {
// return dup().muli(n);
return muli(n, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
@Override
public INDArray muli(Number n) {
return muli(n, this);
}
@Override
public INDArray sub(Number n) {
//return dup().subi(n);
return subi(n, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
@Override
public INDArray subi(Number n) {
return subi(n, this);
}
@Override
public INDArray add(Number n) {
//return dup().addi(n);
return addi(n, Nd4j.createUninitialized(this.shape(), this.ordering()));
}
@Override
public INDArray addi(Number n) {
return addi(n, this);
}
/**
* Replicate and tile array to fill out to the given shape
* See:
* https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358
* @param shape the new shape of this ndarray
* @return the shape to fill out to
*/
@Override
public INDArray repmat(int[] shape) {
Nd4j.getCompressor().autoDecompress(this);
int rows = rows() * shape[0];
int cols = columns() * shape[1];
INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()).repeat(0, shape[1]);
return ret.reshape(rows, cols);
}
@Override
public INDArray repeat(int dimension, int... repeats) {
Nd4j.getCompressor().autoDecompress(this);
if (dimension < 0)
dimension += rank();
if (repeats.length < rank()) {
if (dimension > 0)
repeats = Ints.concat(ArrayUtil.nTimes(rank() - repeats.length, 1), repeats);
//append rather than prepend for dimension == 0
else
repeats = Ints.concat(repeats, ArrayUtil.nTimes(rank() - repeats.length, 1));
}
int[] newShape = new int[rank()];
for (int i = 0; i < newShape.length; i++)
newShape[i] = size(i) * repeats[i];
INDArray ret = create(newShape);
//number of times to repeat each value
int repeatDelta = ArrayUtil.prod(newShape) / length();
for (int i = 0; i < tensorssAlongDimension(dimension); i++) {
INDArray thisTensor = tensorAlongDimension(i, dimension);
INDArray retTensor = ret.tensorAlongDimension(i, dimension);
int retIdx = 0;
for (int k = 0; k < thisTensor.length(); k++) {
for (int j = 0; j < repeatDelta; j++) {
retTensor.putScalar(retIdx++, thisTensor.getDouble(k));
}
}
}
return ret;
}
/**
* Insert a row in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param row the row insert into
* @param toPut the row to insert
* @return this
*/
@Override
public INDArray putRow(int row, INDArray toPut) {
if (isRowVector() && toPut.isVector()) {
return assign(toPut);
}
return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut);
}
/**
* Insert a column in to this array
* Will throw an exception if this
* ndarray is not a matrix
*
* @param column the column to insert
* @param toPut the array to put
* @return this
*/
@Override
public INDArray putColumn(int column, INDArray toPut) {
Nd4j.getCompressor().autoDecompress(this);
if (isColumnVector() && toPut.isVector()) {
return assign(toPut);
}
return put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut);
}
@Override
public double getDouble(int i) {
Nd4j.getCompressor().autoDecompress(this);
if (i >= length()) {
throw new IllegalArgumentException("Unable to get linear index >= " + length());
}
autoProcessScalarCall();
if (i == 0)
return data().getDouble(i);
int[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
Shape.assertShapeLessThan(dimensions, shape());
return getDouble(dimensions);
}
@Override
public double getDouble(int i, int j) {
return getDouble(new int[] {i, j});
}
@Override
public float getFloat(int i) {
return (float) getDouble(i);
}
@Override
public float getFloat(int i, int j) {
return (float) getDouble(i, j);
}
/**
* Return transposed copy of this matrix.
*/
@Override
public INDArray transpose() {
return transposei();
}
/**
*
* Return transposed version of this matrix.
*
* PLEASE NOTE: This method is NOT in place, it will return transposed copy instead.
*/
@Override
public INDArray transposei() {
return permute(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank())));
}
protected INDArray create(DataBuffer data, int[] shape, int[] strides) {
if (this instanceof IComplexNDArray)
return Nd4j.createComplex(data, shape, strides, 0, ordering());
else
return Nd4j.create(data, shape, strides, 0, ordering());
}
@Override
public INDArray reshape(char order, int... newShape) {
Nd4j.getCompressor().autoDecompress(this);
if (newShape == null || newShape.length < 1)
throw new ND4JIllegalStateException(
"Can't reshape(int...) without shape arguments. Got empty shape instead.");
// TODO: maybe toFlatten() makes more sense here?
// reshape(-1) special case
if (newShape.length == 1 && newShape[0] == -1)
newShape[0] = this.length();
int numberNegativesOnes = 0;
int[] shape = ArrayUtil.copy(newShape);
for (int i = 0; i < shape.length; i++) {
if (shape[i] < 0) {
if (numberNegativesOnes >= 1)
throw new IllegalArgumentException("Only one dimension can be negative ones. Got shape "
+ Arrays.toString(newShape));
numberNegativesOnes++;
int shapeLength = 1;
for (int j = 0; j < shape.length; j++)
if (shape[j] >= 1)
shapeLength *= shape[j];
int realShape = Math.abs(length() / shapeLength);
int[] thisNewShape = new int[shape.length];
for (int j = 0; j < shape.length; j++) {
if (i != j) {
thisNewShape[j] = shape[j];
} else
thisNewShape[j] = realShape;
}
shape = thisNewShape;
break;
}
}
long prod = ArrayUtil.prodLong(shape);
if (prod != this.lengthLong()){
throw new ND4JIllegalStateException("New shape length doesn't match original length: [" + prod + "] vs [" + this.lengthLong() + "]. Original shape: "+Arrays.toString(this.shape())+" New Shape: "+Arrays.toString(newShape));
}
INDArray reshapeAttempt = Shape.newShapeNoCopy(this, shape, order == 'f');
if (reshapeAttempt != null) {
// kinda strange get/set usage
// reshapeAttempt.setOrder(Shape.getOrder(reshapeAttempt));
return reshapeAttempt;
}
INDArray ret = Nd4j.createUninitialized(shape, order);
if (order != ordering()) {
ret.setData(dup(order).data());
} else
ret.assign(this);
return ret;
}
@Override
public double getDoubleUnsafe(long offset) {
return data().getDouble(offset);
}
@Override
public INDArray putScalarUnsafe(long offset, double value) {
autoProcessScalarCall();
data().put(offset, value);
return this;
}
@Override
public int innerMostStride() {
if (ordering() == 'c')
return stride(-1);
return stride(0);
}
@Override
public INDArray reshape(char order, int rows, int columns) {
return reshape(order, new int[] {rows, columns});
}
/**
* Reshape the ndarray in to the specified dimensions,
* possible errors being thrown for invalid shapes
*
* Note here that one dimension can be -1.
* The dimension that is -1 will be inferred from the shape and
* the length of the ndarray
*
* @param shape the shape of the ndarray.
* @return the new reshaped nd array
*/
@Override
public INDArray reshape(int... shape) {
return reshape(Nd4j.order(), shape);
}
@Override
public void checkDimensions(INDArray other) {
assert Shape.contentEquals(other.shape(),
Shape.shapeOf(shapeInformation)) : " Other array should have been shape: "
+ Shape.toString(Shape.shapeOf(shapeInformation)) + " but was "
+ Arrays.toString(other.shape());
assert Shape.contentEquals(other.stride(),
Shape.stride(shapeInformation)) : " Other array should have been stride: "
+ Shape.toString(Shape.stride(shapeInformation)) + " but was "
+ Arrays.toString(other.stride());
assert Shape.offset(javaShapeInformation) == other.offset() : "Offset of this array is "
+ Shape.offset(javaShapeInformation) + " but other was " + other.offset();
}
/**
* Returns the product along a given dimension
*
* @param dimension the dimension to getScalar the product along
* @return the product along the specified dimension
*/
@Override
public INDArray prod(int... dimension) {
return Nd4j.getExecutioner().exec(new Prod(this), dimension);
}
/**
* Returns the overall mean of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray mean(int... dimension) {
return Nd4j.getExecutioner().exec(new Mean(this), dimension);
}
@Override
public INDArray amean(int... dimension) {
return Nd4j.getExecutioner().exec(new AMean(this), dimension);
}
@Override
public INDArray mean(@NonNull INDArray result, int... dimension) {
return Nd4j.getExecutioner().exec(new Mean(this, null, result), dimension);
}
/**
* Returns the overall variance of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray var(int... dimension) {
return Nd4j.getExecutioner().exec(new Variance(this), dimension);
}
/**
* Returns the overall variance of this ndarray
*
* @param biasCorrected boolean on whether to apply corrected bias
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray var(boolean biasCorrected, int... dimension) {
return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected), dimension);
}
/**
* Returns the overall max of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray max(int... dimension) {
return Nd4j.getExecutioner().exec(new Max(this), dimension);
}
@Override
public INDArray amax(int... dimension) {
return Nd4j.getExecutioner().exec(new AMax(this), dimension);
}
/**
* Returns the overall min of this ndarray
*
* @param dimension the dimension to getScalar the mean along
* @return the mean along the specified dimension of this ndarray
*/
@Override
public INDArray min(int... dimension) {
return Nd4j.getExecutioner().exec(new Min(this), dimension);
}
@Override
public INDArray amin(int... dimension) {
return Nd4j.getExecutioner().exec(new AMin(this), dimension);
}
/**
* Returns the sum along the last dimension of this ndarray
*
* @param dimension the dimension to getScalar the sum along
* @return the sum along the specified dimension of this ndarray
*/
@Override
public INDArray sum(int... dimension) {
return Nd4j.getExecutioner().exec(new Sum(this), dimension);
}
/**
* Returns entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray entropy(int... dimension) {
return Nd4j.getExecutioner().exec(new Entropy(this), dimension);
}
/**
* Returns non-normalized Shannon entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray shannonEntropy(int... dimension) {
return Nd4j.getExecutioner().exec(new ShannonEntropy(this), dimension);
}
/**
* Returns log entropy along dimension
* @param dimension
* @return
*/
@Override
public INDArray logEntropy(int... dimension) {
return Nd4j.getExecutioner().exec(new LogEntropy(this), dimension);
}
@Override
public INDArray sum(@NonNull INDArray result, int... dimension) {
return Nd4j.getExecutioner().exec(new Sum(this, null, result), dimension);
}
/**
* Returns the norm1 along the specified dimension
*
* @param dimension the dimension to getScalar the norm1 along
* @return the norm1 along the specified dimension
*/
@Override
public INDArray norm1(int... dimension) {
return Nd4j.getExecutioner().exec(new Norm1(this), dimension);
}
/**
* Standard deviation of an ndarray along a dimension
*
* @param dimension the dimension to getScalar the std along
* @return the standard deviation along a particular dimension
*/
@Override
public INDArray std(int... dimension) {
return Nd4j.getExecutioner().exec(new StandardDeviation(this), dimension);
}
@Override
public INDArray std(boolean biasCorrected, int... dimension) {
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected),
dimension);
}
@Override
public Number stdNumber(boolean biasCorrected) {
return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected),
new int[] {Integer.MAX_VALUE})
.getDouble(0);
}
/**
* Returns the norm2 along the specified dimension
*
* @param dimension the dimension to getScalar the norm2 along
* @return the norm2 along the specified dimension
*/
@Override
public INDArray norm2(int... dimension) {
return Nd4j.getExecutioner().exec(new Norm2(this), dimension);
}
/**
* Number of columns (shape[1]), throws an exception when
* called when not 2d
*
* @return the number of columns in the array (only 2d)
*/
@Override
public int columns() {
if (isMatrix())
return size(1);
else if (Shape.isColumnVectorShape(shape())) {
return 1;
} else if (Shape.isRowVectorShape(shape())) {
return length();
}
throw new IllegalStateException("Rank is " + rank() + " columns() call is not valid");
}
/**
* Returns the number of rows
* in the array (only 2d) throws an exception when
* called when not 2d
*
* @return the number of rows in the matrix
*/
@Override
public int rows() {
if (isMatrix())
return size(0);
else if (Shape.isRowVectorShape(shape())) {
return 1;
} else if (Shape.isColumnVectorShape(shape())) {
return length();
}
throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid");
}
/**
* Flattens the array for linear indexing
*
* @return the flattened version of this array
*/
@Override
public INDArray ravel(char ordering) {
Nd4j.getCompressor().autoDecompress(this);
if (length() >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Length can not be >= Integer.MAX_VALUE");
INDArray ret = create(new int[] {1, (int) length()}, ordering);
NDArrayIndex index = new NDArrayIndex(this.shape());
for (int i = 0; i < length(); i++) {
// FIXME: LONG
double val = getDouble((int) index.next());
ret.putScalar(new int[] {0, i}, val);
}
return ret;
}
/**
* Flattens the array for linear indexing
*
* @return the flattened version of this array
*/
@Override
public INDArray ravel() {
return reshape(1, length());
}
/**
* Flattens the array for linear indexing
*
* @return the flattened version of this array
*/
@Override
public void sliceVectors(List list) {
if (isVector())
list.add(this);
else {
for (int i = 0; i < slices(); i++) {
slice(i).sliceVectors(list);
}
}
}
/**
* Reshape the matrix. Number of elements must not change.
*
* @param newRows
* @param newColumns
*/
@Override
public INDArray reshape(int newRows, int newColumns) {
return reshape(new int[] {newRows, newColumns});
}
/**
* Get the specified column
*
* @param c
*/
@Override
public INDArray getColumn(int c) {
Nd4j.getCompressor().autoDecompress(this);
if (isColumnVector() && c == 0)
return this;
else if (isColumnVector() && c > 0)
throw new IllegalArgumentException("Illegal index for row");
else if(isRowVector()) {
return Nd4j.scalar(getDouble(c));
}
return get(NDArrayIndex.all(), NDArrayIndex.point(c));
}
/**
* Get whole rows from the passed indices.
*
* @param rindices
*/
@Override
public INDArray getRows(int[] rindices) {
Nd4j.getCompressor().autoDecompress(this);
if (!isMatrix() && !isVector())
throw new IllegalArgumentException("Unable to get columns from a non matrix or vector");
if (isVector())
return Nd4j.pullRows(this, 1, rindices);
else {
INDArray ret = Nd4j.create(rindices.length, columns());
for (int i = 0; i < rindices.length; i++)
ret.putRow(i, getRow(rindices[i]));
return ret;
}
}
/**
* Returns a subset of this array based on the specified
* indexes
*
* @param indexes the indexes in to the array
* @return a view of the array with the specified indices
*/
@Override
public INDArray get(INDArrayIndex... indexes) {
Nd4j.getCompressor().autoDecompress(this);
if(indexes.length > rank()) {
int numNonNewAxis = 0;
for(int i = 0; i < indexes.length; i++) {
if(!(indexes[i] instanceof NewAxis))
numNonNewAxis++;
}
if(numNonNewAxis > rank()) {
throw new IllegalArgumentException("Too many indices for array. Number of indexes must be <= rank()");
}
}
//check for row/column vector and point index being 0
if (indexes.length == 1 && indexes[0] instanceof NDArrayIndexAll || (indexes.length == 2 && (isRowVector()
&& indexes[0] instanceof PointIndex && indexes[0].offset() == 0
&& indexes[1] instanceof NDArrayIndexAll
|| isColumnVector() && indexes[1] instanceof PointIndex && indexes[0].offset() == 0
&& indexes[0] instanceof NDArrayIndexAll)))
return this;
indexes = NDArrayIndex.resolve(shapeInfoDataBuffer(), indexes);
ShapeOffsetResolution resolution = new ShapeOffsetResolution(this);
resolution.exec(indexes);
if (indexes.length < 1)
throw new IllegalStateException("Invalid index found of zero length");
// FIXME: LONG
int[] shape = LongUtils.toInts(resolution.getShapes());
int numSpecifiedIndex = 0;
for (int i = 0; i < indexes.length; i++)
if (indexes[i] instanceof SpecifiedIndex)
numSpecifiedIndex++;
if (shape != null && numSpecifiedIndex > 0) {
Generator>> gen = SpecifiedIndex.iterate(indexes);
INDArray ret = Nd4j.create(shape, 'c');
int count = 0;
while (true) {
try {
List> next = gen.next();
List coordsCombo = new ArrayList<>();
for (int i = 0; i < next.size(); i++) {
if (next.get(i).size() > 1)
throw new IllegalStateException("Illegal entry returned");
coordsCombo.add(next.get(i).get(0));
}
ret.putScalar(count++, getDouble(Ints.toArray(coordsCombo)));
} catch (NoSuchElementException e) {
break;
}
if (count >= ret.length())
break;
}
return ret;
}
INDArray ret = subArray(resolution);
return ret;
}
/**
* Get whole columns
* from the passed indices.
*
* @param cindices
*/
@Override
public INDArray getColumns(int... cindices) {
if (!isMatrix() && !isVector())
throw new IllegalArgumentException("Unable to get columns from a non matrix or vector");
if (isVector()) {
return Nd4j.pullRows(this, 0, cindices, this.ordering());
} else {
INDArray ret = Nd4j.create(rows(), cindices.length);
for (int i = 0; i < cindices.length; i++)
ret.putColumn(i, getColumn(cindices[i]));
return ret;
}
}
protected INDArray create(int rows, int length) {
return create(new int[] {rows, length});
}
/**
* Get a copy of a row.
*
* @param r the row to get
*/
@Override
public INDArray getRow(int r) {
if (isRowVector() && r == 0)
return this;
else if (isRowVector() && r > 0)
throw new IllegalArgumentException("Illegal index for row");
INDArray result = get(NDArrayIndex.point(r), NDArrayIndex.all());
// FIXME: this is bad
if (!this.isView() && this.ordering() == 'c' && result.elementWiseStride() == 1 && result.ordering() != 'c') {
((BaseNDArray) result).setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(result.shape(), result.stride(), 0, 1, 'c'));
}
return result;
}
/**
* This method allows you to compare INDArray against other INDArray, with variable eps
*
* @param o
* @param eps
* @return
*/
public boolean equalsWithEps(Object o, double eps) {
Nd4j.getCompressor().autoDecompress(this);
if (o == null)
return false;
if (!(o instanceof INDArray))
return false;
INDArray n = (INDArray) o;
if (n.isSparse()) {
return n.equals(this);
}
if (this.lengthLong() != n.lengthLong())
return false;
//epsilon equals
if (isScalar() && n.isScalar()) {
if (data.dataType() == DataBuffer.Type.FLOAT) {
double val = getDouble(0);
double val2 = n.getDouble(0);
if (Double.isNaN(val) != Double.isNaN(val2))
return false;
return Math.abs(val - val2) < eps;
} else {
double val = getDouble(0);
double val2 = n.getDouble(0);
if (Double.isNaN(val) != Double.isNaN(val2))
return false;
return Math.abs(val - val2) < eps;
}
} else if (isVector() && n.isVector()) {
EqualsWithEps op = new EqualsWithEps(this, n, eps);
Nd4j.getExecutioner().exec(op);
double diff = op.getFinalResult().doubleValue();
return diff < 0.5;
}
if (!Arrays.equals(this.shape(), n.shape()))
return false;
if (!Shape.shapeEquals(shape(), n.shape())) {
return false;
}
if (slices() != n.slices())
return false;
if (n.ordering() == ordering()) {
EqualsWithEps op = new EqualsWithEps(this, n, eps);
Nd4j.getExecutioner().exec(op);
double diff = op.getFinalResult().doubleValue();
return diff < 0.5;
} else {
EqualsWithEps op = new EqualsWithEps(this, n, eps);
Nd4j.getExecutioner().exec(op);
double diff = op.getFinalResult().doubleValue();
return diff < 0.5;
}
}
/**
* Compare two matrices. Returns true if and only if other is also a
* DoubleMatrix which has the same size and the maximal absolute
* difference in matrix elements is smaller than 1e-5.
*
* @param o
*/
@Override
public boolean equals(Object o) {
return equalsWithEps(o, Nd4j.EPS_THRESHOLD);
}
@Override
public DataBuffer shapeInfoDataBuffer() {
return shapeInformation;
}
@Override
public IntBuffer shapeInfo() {
return shapeInformation.asNioInt();
}
/**
* Returns the shape(dimensions) of this array
*
* @return the shape of this matrix
*/
public int[] shape() {
return Shape.shape(javaShapeInformation);
}
/**
* Returns the shape information debugging
* information
*
* @return the shape information debugging information
*/
@Override
public String shapeInfoToString() {
return Shape.shapeToString(this);
}
/**
* Returns the stride(indices along the linear index for which each slice is accessed) of this array
*
* @return the stride of this array
*/
@Override
public int[] stride() {
return Shape.stride(javaShapeInformation);
}
@Override
public long offset() {
if (data().offset() >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Offset of buffer can not be >= Integer.MAX_VALUE");
// return Shape.offset(shapeInfo());
return data().offset();
}
@Override
public char ordering() {
return Shape.order(javaShapeInformation);
}
/**
* Returns the size of this array
* along a particular dimension
*
* @param dimension the dimension to return from
* @return the shape of the specified dimension
*/
@Override
public int size(int dimension) {
if (isScalar()) {
if (dimension == 0 || dimension == 1 || dimension < 0)
return (int) lengthLong();
else
throw new IllegalArgumentException("Illegal dimension for scalar " + dimension);
}
if (dimension < 0) {
return shapeOf().getInt(dimension + Shape.rank(javaShapeInformation));
}
if (dimension >= rank())
throw new IllegalArgumentException("Invalid size: cannot get size of dimension " + dimension + " for rank "
+ rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")");
return shapeOf().getInt(dimension);
}
@Override
public int rank() {
return Shape.rank(javaShapeInformation);
}
/**
* Returns the total number of elements in the ndarray
*
* @return the number of elements in the ndarray
*/
@Override
public int length() {
return (int) Shape.length(javaShapeInformation);
}
/**
* Returns the total number of elements in the ndarray
*
* @return the number of elements in the ndarray
*/
@Override
public long lengthLong() {
return Shape.length(javaShapeInformation);
}
@Override
public INDArray broadcast(INDArray result) {
Nd4j.getCompressor().autoDecompress(this);
int[] shape = result.shape();
if (Shape.shapeEquals(shape, shape()))
return this;
// if we're on scalar, we can just create new array
if (this.isScalar())
return Nd4j.createUninitialized(shape).assign(this.getDouble(0));
boolean compatible = true;
int count = shape.length - 1;
int thisCount = Shape.rank(javaShapeInformation) - 1;
for (int i = shape.length - 1; i > 0; i--) {
if (count < 0 || thisCount < 0)
break;
if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) {
compatible = false;
break;
}
count--;
thisCount--;
}
if (!compatible)
throw new IllegalArgumentException("Incompatible broadcast from " + Arrays.toString(shape()) + " to "
+ Arrays.toString(shape));
int[] retShape = new int[shape.length];
List broadCastDimensions = new ArrayList<>();
List nonBroadCastDimensions = new ArrayList<>();
for (int i = 0; i < retShape.length; i++) {
if (shape().length == 1) {
if (i == 0) {
if (i < shape().length)
retShape[i] = Math.max(1, shape[i]);
else
retShape[i] = shape[i];
} else {
if (i < shape().length)
retShape[i] = Math.max(shape[i], size(i));
else
retShape[i] = shape[i];
}
} else {
if (i < rank() && size(i) == 1)
broadCastDimensions.add(i);
else
nonBroadCastDimensions.add(i);
if (i < shape().length)
retShape[i] = Math.max(shape[i], size(i));
else
retShape[i] = shape[i];
}
}
if (isRowVector()) {
//number of times to repeat each value
for (int i = 0; i < result.slices(); i++) {
result.putSlice(i, this);
}
} else if (isColumnVector()) {
for (int i = 0; i < result.columns(); i++) {
result.putColumn(i, this);
}
}
else {
int[] repeat = new int[shape.length];
for(int i = 0; i < shape.length; i++) {
if(i < rank()) {
if(size(i) == 1)
repeat[i] = shape[i];
else {
repeat[i] = 1;
}
}
else {
repeat[i] = shape[i];
}
}
if (this.isView()) {
Nd4j.getExecutioner().exec(new Tile(new INDArray[]{this.dup(this.ordering())},new INDArray[]{result},repeat));
} else
Nd4j.getExecutioner().exec(new Tile(new INDArray[]{this},new INDArray[]{result},repeat));
//result = Nd4j.tile(this,repeat);
}
return result;
}
/**
* Broadcasts this ndarray to be the specified shape
*
* @param shape the new shape of this ndarray
* @return the broadcasted ndarray
*/
@Override
public INDArray broadcast(int... shape) {
return broadcast(Nd4j.createUninitialized(shape));
}
/**
* Dimshuffle: an extension of permute that adds the ability
* to broadcast various dimensions.
*
* See theano for more examples.
* This will only accept integers and xs.
*
* An x indicates a dimension should be broadcasted rather than permuted.
*
* @param rearrange the dimensions to swap to
* @return the newly permuted array
*/
@Override
public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) {
Nd4j.getCompressor().autoDecompress(this);
if (broadCastable.length != Shape.rank(javaShapeInformation))
throw new IllegalArgumentException(
"The broadcastable dimensions must be the same length as the current shape");
boolean broadcast = false;
Set