All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.ndarray.BaseNDArray Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ndarray;


import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.graph.ByteOrder;
import org.nd4j.graph.FlatArray;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.*;
import org.nd4j.linalg.api.iter.FirstAxisIterator;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
import org.nd4j.linalg.api.ops.impl.reduce.same.*;
import org.nd4j.linalg.api.ops.impl.reduce3.EqualsWithEps;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.reduce.same.Max;
import org.nd4j.linalg.api.ops.impl.reduce.same.Min;
import org.nd4j.linalg.api.ops.impl.broadcast.*;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.scalar.*;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.*;
import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.exception.*;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.string.NDArrayStrings;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.nd4j.linalg.util.NDArrayMath;
import org.nd4j.linalg.workspace.WorkspaceUtils;

import java.io.*;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

import static org.nd4j.linalg.factory.Nd4j.*;


/**
 * NDArray: (think numpy)
 * <p>
 * A few things of note:
 * <ul>
 * <li>An NDArray can have any number of dimensions.</li>
 * <li>An NDArray is accessed via strides.</li>
 * <li>Strides describe how to index over a contiguous block of data.</li>
 * </ul>
 * <p>
* This block of data has 2 orders (as of right now):
 * fortran ({@code 'f'}) and c ({@code 'c'}).
 *
 * @author Adam Gibson
 */
@Slf4j
public abstract class BaseNDArray implements INDArray, Iterable {

private static final long serialVersionUID = 3285982317165542614L;

/** Shape information buffer: the first element of the pair produced by the shape info provider. */
protected transient volatile DataBuffer shapeInformation;

/** The buffer backing this array. */
protected transient volatile DataBuffer data;

/** Compression flag, set through {@link #markAsCompressed(boolean)}. */
protected transient boolean compressed = false;

protected transient boolean released = false;

/** JVM-side copy of the shape information. */
protected transient JvmShapeInfo jvmShapeInfo;

/** Shared counter from which every new instance draws its {@link #arrayId}. */
private static final AtomicLong arrayCounter = new AtomicLong(0);

/** Per-instance id, taken from {@link #arrayCounter} when the instance is created. */
protected transient final long arrayId = arrayCounter.getAndIncrement();

/**
 * Precalculated reversed index arrays ({@code [1, 0]}, {@code [2, 1, 0]}, {@code [3, 2, 1, 0]}, ...) for use in
 * TAD, to avoid creating the same int[]s over and over. For {@code 2 <= i < 32} entry {@code i} lists
 * {@code i - 1} down to {@code 0}; entry 1 is the special case {@code [1, 0]} and entry 0 is empty.
 */
private static final int[][] tadFinalPermuteDimensions;

static {
    tadFinalPermuteDimensions = new int[32][0];
    // Edge case for 1d tensors: selectively apply to column vectors
    tadFinalPermuteDimensions[1] = new int[] {1, 0};
    for (int i = 2; i < 32; i++) {
        tadFinalPermuteDimensions[i] = new int[i];
        for (int k = i - 1, j = 0; k >= 0; k--, j++)
            tadFinalPermuteDimensions[i][j] = k;
    }
}

/** No-arg constructor; assigns neither a data buffer nor shape information. */
public BaseNDArray() {
}

@Override
public boolean isCompressed() {
    return compressed;
}

@Override
public void markAsCompressed(boolean reallyCompressed) {
    this.compressed = reallyCompressed;
}

/**
 * Wraps {@code buffer} as a {@code [1, length]} array with element-wise stride 1, in the default order
 * ({@link Nd4j#order()}) and with the buffer's data type.
 *
 * @param buffer the buffer to wrap
 * @throws IllegalArgumentException if {@code buffer.length() >= Integer.MAX_VALUE}
 */
public BaseNDArray(DataBuffer buffer) {
    if (buffer.length() >= Integer.MAX_VALUE)
        throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE");
    this.data = buffer;
    long[] shape = {1, buffer.length()};
    long[] stride = Nd4j.getStrides(shape);
    setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 1, Nd4j.order(),
            buffer.dataType(), false));
    init(shape, stride);
}

/**
 * Creates an array over {@code buffer} with the given shape, stride, offset and order; the data type is taken
 * from the buffer.
 *
 * @param buffer   the backing buffer; when {@code offset > 0} the region of it starting at {@code offset} is used
 * @param shape    the shape of the array
 * @param stride   the stride of the array
 * @param offset   offset into {@code buffer}
 * @param ordering the array order, checked with {@code Shape.assertValidOrder}
 */
public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) { Shape.assertValidOrder(ordering); this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, buffer.dataType(), false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); } public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering) { this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering); } public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering) { Shape.assertValidOrder(ordering); this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, buffer.dataType(), false )); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); } public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering, DataType dataType) { this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType); } public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering, DataType dataType) { this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, dataType, false)); init(shape, stride); } public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type) { this.data = buffer; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); } public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type, MemoryWorkspace workspace) { this.data = buffer; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); } public BaseNDArray(DataBuffer buffer, DataType dataType, long[] shape, long[] stride, long offset, char ordering) { this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType, false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); } /** * Initialize the ndarray as a matrix * with the given data (indices preserved) * @param data */ public BaseNDArray(double[][] data) { this(data, Nd4j.order()); } /** * * @param data * @param ordering */ public BaseNDArray(double[][] data, char ordering) { this(internalCreateBuffer(ordering == 'c' ? 
ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), new int[] {data.length, data[0].length}, Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering); int c = columns(); for (int r = 0; r < rows(); r++) { Preconditions.checkState(data[r].length == c, "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c ); } } /** * Create with the specified shape and buffer * * @param shape the shape * @param buffer the buffer */ public BaseNDArray(int[] shape, DataBuffer buffer) { this.data = buffer; init(shape, Nd4j.getStrides(shape)); } /** * Create this ndarray with the given data and shape and 0 offset * * @param data the data to use * @param shape the shape of the ndarray */ public BaseNDArray(float[] data, int[] shape, char ordering) { this(data, shape, 0, ordering); } /** * @param data the data to use * @param shape the shape of the ndarray * @param offset the desired offset * @param ordering the ordering of the ndarray */ public BaseNDArray(float[] data, int[] shape, long offset, char ordering) { this(data, shape, Nd4j.getStrides(shape, ordering), offset); } public BaseNDArray(double[] data, long[] shape, long offset, char ordering) { this(data, shape, Nd4j.getStrides(shape, ordering), offset); } public BaseNDArray(float[] data, long[] shape, long offset, char ordering) { this(data, shape, Nd4j.getStrides(shape, ordering), offset); } /** * Construct an ndarray of the specified shape * with an empty data array * * @param shape the shape of the ndarray * @param stride the stride of the ndarray * @param offset the desired offset * @param ordering the ordering of the ndarray */ public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) { this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); } public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) { this(Nd4j.createBuffer(shape.length == 0 ? 
1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); } /** * Construct an ndarray of the specified shape. * * @param shape the shape of the ndarray * @param stride the stride of the ndarray * @param offset the desired offset * @param ordering the ordering of the ndarray * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't. */ public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) { this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); } public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) { this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); } public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize) { this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering); } public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize, MemoryWorkspace workspace) { this(Nd4j.createBuffer(type, shape.length == 0 ? 
1 : ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering); } /** * Create the ndarray with * the specified shape and stride and an offset of 0 * * @param shape the shape of the ndarray * @param stride the stride of the ndarray * @param ordering the ordering of the ndarray */ public BaseNDArray(int[] shape, int[] stride, char ordering) { this(shape, stride, 0, ordering); } /** * * @param shape * @param offset * @param ordering */ public BaseNDArray(int[] shape, long offset, char ordering) { this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); } public BaseNDArray(long[] shape, long offset, char ordering) { this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); } /** * Create an ndarray * with the given shape * @param shape */ public BaseNDArray(int[] shape) { this(shape, 0, Nd4j.order()); } public BaseNDArray(long[] shape) { this(shape, 0, Nd4j.order()); } /** * Creates a new n times m DoubleMatrix. * * @param newRows the number of rows (n) of the new matrix. * @param newColumns the number of columns (m) of the new matrix. 
*/ public BaseNDArray(int newRows, int newColumns, char ordering) { Shape.assertValidOrder(ordering); this.data = Nd4j.createBuffer((long) newRows * newColumns); val shape = new long[] {newRows, newColumns}; val stride = Nd4j.getStrides(shape, ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); init(shape, stride); } public BaseNDArray(long newRows, long newColumns, char ordering) { Shape.assertValidOrder(ordering); this.data = Nd4j.createBuffer((long) newRows * newColumns); long[] shape = new long[] {newRows, newColumns}; long[] stride = Nd4j.getStrides(shape, ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); init(shape, stride); } /** * Create an ndarray from the specified slices. * This will go through and merge all of the * data from each slice in to one ndarray * which will then take the specified shape * * @param slices the slices to merge * @param shape the shape of the ndarray */ public BaseNDArray(List slices, int[] shape, char ordering) { this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); } public BaseNDArray(List slices, long[] shape, char ordering) { this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); } /** * Create an ndarray from the specified slices. * This will go through and merge all of the * data from each slice in to one ndarray * which will then take the specified shape * * @param slices the slices to merge * @param shape the shape of the ndarray */ public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) { Shape.assertValidOrder(ordering); DataBuffer ret = slices.get(0).data().dataType() == (DataType.FLOAT) ? 
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); this.data = ret; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); if (slices.get(0).isScalar()) { for (int i = 0; i < length(); i++) { putScalar(i, slices.get(i).getDouble(0)); } } else { for (int i = 0; i < slices(); i++) { putSlice(i, slices.get(i)); } } } public BaseNDArray(List slices, long[] shape, long[] stride, char ordering) { DataBuffer ret = Nd4j.createBuffer(slices.get(0).dataType(), Shape.lengthOf(shape), false); /*slices.get(0).data().dataType() == (DataType.FLOAT) ? Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); */ this.data = ret; setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); if (slices.get(0).isScalar()) { for (int i = 0; i < length(); i++) { putScalar(i, slices.get(i).getDouble(0)); } } else { for (int i = 0; i < slices(); i++) { putSlice(i, slices.get(i)); } } } /** * * @param data * @param shape * @param stride * @param ordering */ public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) { this(data, shape, stride, 0, ordering); } /** * * @param data * @param shape * @param stride * @param offset * @param ordering */ public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) { Shape.assertValidOrder(ordering); 
setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data != null && data.length > 0 ? false : true)); if (data != null && data.length > 0) { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); this.data = internalCreateBuffer(data, offset); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); if (offset >= data.length) throw new IllegalArgumentException("invalid offset: must be < data.length"); } init(shape, stride); } public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) { Shape.assertValidOrder(ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data != null && data.length > 0 ? false : true)); if (data != null && data.length > 0) { this.data = Nd4j.createTypedBuffer(data, DataType.FLOAT); if (offset >= data.length) throw new IllegalArgumentException("invalid offset: must be < data.length"); } init(shape, stride); } public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) { Shape.assertValidOrder(ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, data != null && data.length > 0 ? 
false : true)); if (data != null && data.length > 0) { this.data = Nd4j.createBuffer(data, offset); if (offset >= data.length) throw new IllegalArgumentException("invalid offset: must be < data.length"); } init(shape, stride); } /** * * @param data * @param shape * @param stride * @param offset */ public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) { this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape)); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order(), data.dataType(), false)); init(shape, stride); // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f')); } /** * * @param data * @param shape * @param strides */ public BaseNDArray(int[] data, int[] shape, int[] strides) { this(internalCreateBuffer(data), shape, strides); } /** * * @param data * @param shape */ public BaseNDArray(DataBuffer data, int[] shape) { this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); } public BaseNDArray(DataBuffer data, long[] shape) { this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); } /** * * @param buffer * @param shape * @param offset */ public BaseNDArray(DataBuffer buffer, int[] shape, long offset) { this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, Nd4j.getStrides(shape), offset, Nd4j.order()); } /** * * @param buffer * @param shape * @param ordering */ public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) { this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); } public BaseNDArray(DataBuffer buffer, long[] shape, char ordering) { this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); } /** * * @param data * @param shape * @param ordering */ public BaseNDArray(double[] data, int[] shape, char ordering) { 
this(Nd4j.createBuffer(data), shape, ordering); } public BaseNDArray(double[] data, long[] shape, char ordering) { this(Nd4j.createBuffer(data), shape, ordering); } public BaseNDArray(float[] data, long[] shape, char ordering) { this(Nd4j.createBuffer(data), shape, ordering); } /** * * @param data * @param shape * @param stride * @param offset * @param ordering */ public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) { this(internalCreateBuffer(data, offset), shape, stride, offset, ordering); } /** * * @param data * @param order */ public BaseNDArray(float[] data, char order) { this(internalCreateBuffer(data), order); } protected static DataBuffer internalCreateBuffer(float[] data) { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } protected static DataBuffer internalCreateBuffer(double[] data) { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } protected static DataBuffer internalCreateBuffer(int[] data) { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } protected static DataBuffer internalCreateBuffer(float[] data, long offset) { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data, offset); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * 
Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } protected static DataBuffer internalCreateBuffer(double[] data, long offset) { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data, offset); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } /** * * @param floatBuffer * @param order */ public BaseNDArray(DataBuffer floatBuffer, char order) { this(floatBuffer, new int[] {(int) floatBuffer.length()}, Nd4j.getStrides(new int[] {(int) floatBuffer.length()}, order), 0, order); Shape.assertValidOrder(order); if (floatBuffer.length() >= Integer.MAX_VALUE) throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); } /** * * @param buffer * @param shape * @param strides */ public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) { this(buffer, shape, strides, 0, Nd4j.order()); } /** * Create this ndarray with the given data and shape and 0 offset * * @param data the data to use * @param shape the shape of the ndarray */ public BaseNDArray(float[] data, int[] shape) { this(data, shape, 0); } /** * * @param data * @param shape * @param offset */ public BaseNDArray(float[] data, int[] shape, long offset) { this(data, shape, offset, Nd4j.order()); } /** * Construct an ndarray of the specified shape * with an empty data array * * @param shape the shape of the ndarray * @param stride the stride of the ndarray * @param offset the desired offset */ public BaseNDArray(int[] shape, int[] stride, long offset) { this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); } public BaseNDArray(long[] shape, long[] stride, long offset) { this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); } /** * Create the ndarray with * the specified shape and stride and an offset of 0 * * @param shape the shape of 
the ndarray * @param stride the stride of the ndarray */ public BaseNDArray(int[] shape, int[] stride) { this(shape, stride, 0); } /** * * @param shape * @param offset */ public BaseNDArray(int[] shape, long offset) { this(shape, Nd4j.getStrides(shape), offset); } /** * * @param shape * @param ordering */ public BaseNDArray(int[] shape, char ordering) { this(shape, 0, ordering); } /** * Creates a new n times m DoubleMatrix. * * @param newRows the number of rows (n) of the new matrix. * @param newColumns the number of columns (m) of the new matrix. */ public BaseNDArray(int newRows, int newColumns) { this(newRows, newColumns, Nd4j.order()); } public BaseNDArray(long newRows, long newColumns) { this(newRows, newColumns, Nd4j.order()); } /** * Create an ndarray from the specified slices. * This will go through and merge all of the * data from each slice in to one ndarray * which will then take the specified shape * * @param slices the slices to merge * @param shape the shape of the ndarray */ public BaseNDArray(List slices, int[] shape) { this(slices, shape, Nd4j.order()); } public BaseNDArray(List slices, long[] shape) { this(slices, shape, Nd4j.order()); } /** * Create an ndarray from the specified slices. 
* This will go through and merge all of the * data from each slice in to one ndarray * which will then take the specified shape * * @param slices the slices to merge * @param shape the shape of the ndarray */ public BaseNDArray(List slices, int[] shape, int[] stride) { this(slices, shape, stride, Nd4j.order()); } public BaseNDArray(List slices, long[] shape, long[] stride) { this(slices, shape, stride, Nd4j.order()); } /** * * @param data * @param shape * @param stride */ public BaseNDArray(float[] data, int[] shape, int[] stride) { this(data, shape, stride, Nd4j.order()); } /** * * @param data * @param shape * @param stride * @param offset */ public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) { this(data, shape, stride, offset, Nd4j.order()); } public BaseNDArray(double[] data, long[] shape, long[] stride, long offset) { this(data, shape, stride, offset, Nd4j.order()); } public BaseNDArray(float[] data, long[] shape, long[] stride, long offset) { this(data, shape, stride, offset, Nd4j.order()); } /** * * @param data */ public BaseNDArray(float[] data) { this(Nd4j.createBuffer(data)); } /** * Initialize the ndarray * with the given data * @param data */ public BaseNDArray(float[][] data) { this(data, Nd4j.order()); } /** * * @param data * @param ordering */ public BaseNDArray(float[][] data, char ordering) { this(internalCreateBuffer(ordering == 'c' ? 
ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), new int[] {data.length, data[0].length}, Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering); int c = columns(); for (int r = 0; r < rows(); r++) { Preconditions.checkState(data[r].length == c, "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c ); } } /** * Constructor for stride and offset * * @param buffer * @param shape * @param offset * @param ordering */ public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) { this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering); } public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) { this(data, shape, stride, offset, Nd4j.order()); } /** * Returns whether the ndarray is valid or not * @return true if the ndarray is valid * false otherwise */ @Deprecated public boolean isValid() { try { linearIndex(length() - 1); } catch (Exception e) { return false; } return true; } protected INDArray create(DataBuffer data, int[] shape, long offset) { return Nd4j.create(data, shape, offset); } @Override public int elementWiseStride() { return Shape.elementWiseStride(shapeInfoDataBuffer()); } @Override public long tensorsAlongDimension(int... 
dimension) { if (dimension == null || dimension.length == 0) throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) return 1; for (int i = 0; i < dimension.length; i++) if (dimension[i] < 0) dimension[i] += rank(); long[] tensorShape = ArrayUtil.keep(shape(), dimension); long len = ArrayUtil.prodLong(tensorShape); if (len == 0) throw new IllegalStateException("Illegal length found after removing index"); long length = length(); if (length / len >= Integer.MAX_VALUE) throw new IllegalArgumentException("Tensors along dimension can not be >= Integer.MAX_VALUE"); return length / len; } @Override public INDArray tensorAlongDimension(long index, int... dimension) { if (dimension == null || dimension.length == 0) throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); Preconditions.checkArgument(!this.isEmpty(), "tensorAlongDimension(...) 
can't be used on empty tensors"); if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) return this; for (int i = 0; i < dimension.length; i++) if (dimension[i] < 0) dimension[i] += rank(); //dedup if (dimension.length > 1) dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension)))); if (dimension.length > 1) { Arrays.sort(dimension); } long tads = tensorsAlongDimension(dimension); if (index >= tads) throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); if (dimension.length == 1) { if (dimension[0] == 0 && isColumnVector()) { return this.transpose(); } else if (dimension[0] == 1 && isRowVector()) { return this; } } Pair tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); DataBuffer shapeInfo = tadInfo.getFirst(); val jShapeInfo = shapeInfo.asLong(); val shape = Shape.shape(jShapeInfo); val stride = Shape.stride(jShapeInfo); long offset = offset() + tadInfo.getSecond().getLong(index); val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); return toTad; } private void setShapeInformation(Pair shapeInfo) { this.shapeInformation = shapeInfo.getFirst(); this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); } private INDArray doTad(int index, int... 
dimension) { if (dimension == null || dimension.length == 0) throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); if (dimension.length >= rank()) return this; for (int i = 0; i < dimension.length; i++) if (dimension[i] < 0) dimension[i] += rank(); if (dimension.length > 1) Arrays.sort(dimension); long tads = tensorsAlongDimension(dimension); if (index >= tads) throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); if (dimension.length == 1) { if (dimension[0] == 0 && isColumnVector()) { return this.transpose(); } else if (dimension[0] == 1 && isRowVector()) { return this; } } long[] tensorShape = ArrayUtil.keep(shape(), dimension); int[] reverseDimensions = ArrayUtil.reverseCopy(dimension); int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension); int[] newPermuteDims = Ints.concat(remove, reverseDimensions); int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length]; INDArray permuted = permute(newPermuteDims); long sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape); INDArray ret2 = permuted.slice(sliceIdx); if (dimension.length == tensorShape.length && ArrayUtil.prodLong(tensorShape) == ret2.length()) { if (dimension.length == 1 && ret2.isRowVector()) return ret2; if (finalPermuteDims.length != ret2.rank()) { finalPermuteDims = new int[ret2.rank()]; int count = 0; for (int i = finalPermuteDims.length - 1; i >= 0; i--) finalPermuteDims[count++] = i; } return ret2.permutei(finalPermuteDims); } int length = ArrayUtil.prod(tensorShape); int tensorLength = ArrayUtil.prod(tensorShape); long offset = index * tensorLength / NDArrayMath.lengthPerSlice(ret2); if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { if (offset > Integer.MAX_VALUE) throw new ND4JArraySizeException(); ret2 = ret2.slice((int) offset); if (dimension.length == 1 && ret2.isRowVectorOrScalar()) return ret2; return ret2.permutei(finalPermuteDims); } else if 
(length == NDArrayMath.lengthPerSlice(ret2)) { offset -= ret2.slices() * (offset / ret2.slices()); if (offset > Integer.MAX_VALUE) throw new ND4JArraySizeException(); ret2 = ret2.slice((int) offset); if (dimension.length == 1 && ret2.isRowVectorOrScalar()) return ret2; return ret2.permutei(finalPermuteDims); } while (ret2.length() > length) { sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape); sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices()); ret2 = ret2.slice(sliceIdx); } if (dimension.length == 1 && ret2.isRowVectorOrScalar()) return ret2; return ret2.permutei(finalPermuteDims); } @Override public long vectorsAlongDimension(int dimension) { if (dimension == 0 && isVector() || isRowVectorOrScalar()) return 1; if (size(dimension) == 1 && !isVector()) { for (int i = dimension; i < rank(); i++) { if (size(i) != 1) return vectorsAlongDimension(i); } return length(); } else if (size(0) == 1 && !isVectorOrScalar()) { int realDimension = rank() - getLeadingOnes(); long length = length(); if (length / size(realDimension) >= Integer.MAX_VALUE) throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); return length / size(realDimension); } long length = length(); if (dimension >= jvmShapeInfo.rank) { if (length / size(jvmShapeInfo.rank - 1) >= Integer.MAX_VALUE) throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); return (int) (length / size(jvmShapeInfo.rank - 1)); } if (length / size(dimension) >= Integer.MAX_VALUE) throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); return length / size(dimension); } @Override public INDArray vectorAlongDimension(int index, int dimension) { if (dimension < 0) { dimension = jvmShapeInfo.getRank() + dimension; } //return the whole thing if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 || rank() > 2 && dimension == 0 && size(dimension) == 1) { return this; } 
return tensorAlongDimension(index, dimension); } @Override public void setOrder(char order) { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty())); } @Override public void setShapeAndStride(int[] shape, int[] stride) { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); } @Override public INDArray cumsumi(int dimension) { validateNumericalArray("cumsumi", true); if(isScalar() || isEmpty()) return this; if (isVector()) { double s = 0.0; for (int i = 0; i < length(); i++) { s += getDouble(i); putScalar(i, s); } } else if (dimension == Integer.MAX_VALUE) { INDArray flattened = ravel(); double prevVal = flattened.getDouble(0); for (int i = 1; i < flattened.length(); i++) { double d = prevVal + flattened.getDouble(i); flattened.putScalar(i, d); prevVal = d; } return flattened; } else { for (int i = 0; i < vectorsAlongDimension(dimension); i++) { INDArray vec = vectorAlongDimension(i, dimension); vec.cumsumi(0); } } return this; } @Override public Number normmaxNumber() { return normmax(Integer.MAX_VALUE).getDouble(0); } @Override public Number norm2Number() { return norm2(Integer.MAX_VALUE).getDouble(0); } @Override public Number norm1Number() { return norm1(Integer.MAX_VALUE).getDouble(0); } @Override public Number stdNumber() { return std(Integer.MAX_VALUE).getDouble(0); } @Override public Number prodNumber() { if(isScalar()) return getNumber(0); return prod(Integer.MAX_VALUE).getDouble(0); } @Override public Number meanNumber() { validateNumericalArray("meanNumber", false); if(isScalar()) return getNumber(0); return mean(Integer.MAX_VALUE).getDouble(0); } @Override public Number ameanNumber() { return amean(Integer.MAX_VALUE).getDouble(0); } @Override public Number varNumber() { return var(Integer.MAX_VALUE).getDouble(0); } @Override public Number 
maxNumber() { if(isScalar()) return getNumber(0); return max(Integer.MAX_VALUE).getDouble(0); } @Override public Number amaxNumber() { return amax(Integer.MAX_VALUE).getDouble(0); } @Override public Number minNumber() { if(isScalar()) return getNumber(0); return min(Integer.MAX_VALUE).getDouble(0); } @Override public Number aminNumber() { return amin(Integer.MAX_VALUE).getDouble(0); } @Override public Number scan(Condition condition) { MatchCondition op = new MatchCondition(this, condition); return Nd4j.getExecutioner().exec(op).getDouble(0); } @Override public Number sumNumber() { validateNumericalArray("sum", false); if(isScalar()) return getNumber(0); val scalar = sum(Integer.MAX_VALUE); Nd4j.getExecutioner().commit(); return scalar.getDouble(0); } @Override public Number entropyNumber() { return entropy(Integer.MAX_VALUE).getDouble(0); } @Override public Number shannonEntropyNumber() { return shannonEntropy(Integer.MAX_VALUE).getDouble(0); } @Override public Number logEntropyNumber() { return logEntropy(Integer.MAX_VALUE).getDouble(0); } @Override public INDArray cumsum(int dimension) { validateNumericalArray("cumsum", true); return dup().cumsumi(dimension); } @Override public INDArray assign(final INDArray arr) { Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), "Cannot assign arrays: arrays must both be scalars, both vectors, or shapes must be equal other than size 1 dimensions. 
Attempting to do x.assign(y)" + " with x.shape=%ndShape and y.shape=%ndShape", this, arr ); Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal"); Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); return this; } @Override public INDArray putScalar(long i, double value) { Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); if (i < 0) i += rank(); // TODO: i'm not sure that rank == 1 has fair shortcut here if (isScalar()) { autoProcessScalarCall(); data.put(i, value); return this; } else if (rank() == 1) { data.put(i * stride(0), value); return this; } // we cant raise rank here, if original rank is 1 if (isRowVector() && rank() == 2) { return putScalar(0, i, value); } else if (isColumnVector() && rank() == 2) { return putScalar(i, 0, value); } long[] indexes = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); return putScalar(indexes, value); } @Override public INDArray putScalar(long i, float value) { return putScalar(i, (double) value); } @Override public INDArray putScalar(long i, int value) { return putScalar(i, (double) value); } @Override public INDArray putScalar(int[] indexes, double value) { Nd4j.getCompressor().autoDecompress(this); Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); for (int i = 0; i < indexes.length; i++) { if (indexes[i] < 0) indexes[i] += this.size(i); } if (indexes.length == 1) { return putScalar(indexes[0], value); } else if (indexes.length == 2) { return putScalar(indexes[0], indexes[1], value); } else if (indexes.length == 3) { return putScalar(indexes[0], indexes[1], indexes[2], value); } else if (indexes.length == 4) { return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); } else { autoProcessScalarCall(); long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); data.put(offset, value); } return this; } @Override public INDArray putScalar(long[] indexes, double value) { Nd4j.getCompressor().autoDecompress(this); Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); for (int i = 0; i < indexes.length; i++) { if (indexes[i] < 0) indexes[i] += size(i); } if (indexes.length == 1) { return putScalar(indexes[0], value); } else if (indexes.length == 2) { return putScalar(indexes[0], indexes[1], value); } else if (indexes.length == 3) { return putScalar(indexes[0], indexes[1], indexes[2], value); } else if (indexes.length == 4) { return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); } else { autoProcessScalarCall(); long offset = 
Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); data.put(offset, value); } return this; } @Override public INDArray putScalar(long[] indexes, float value) { return putScalar(indexes, (double) value); } @Override public INDArray putScalar(long row, long col, double value) { Nd4j.getCompressor().autoDecompress(this); autoProcessScalarCall(); Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); if (rank() > 2) throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray"); long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col); data.put(offset, value); return this; } @Override public INDArray putScalar(long dim0, long dim1, long dim2, double value) { Nd4j.getCompressor().autoDecompress(this); autoProcessScalarCall(); Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); if (rank() != 3) throw new IllegalStateException( "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray"); long offset = 0; // Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2); long size_0 = jvmShapeInfo.javaShapeInformation[1]; long size_1 = jvmShapeInfo.javaShapeInformation[1 + 1]; long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2]; if (size_0 != 1) offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 0 + 3]; if (size_1 != 1) offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3]; if (size_2 != 1) offset += dim2 * jvmShapeInfo.javaShapeInformation[1 + 2 + 3]; data.put(offset, value); return this; } @Override public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) { Nd4j.getCompressor().autoDecompress(this); autoProcessScalarCall(); 
Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); if (rank() != 4) throw new IllegalStateException( "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray"); long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3); data.put(offset, value); return this; } @Override public INDArray putScalar(int[] indexes, float value) { return putScalar(indexes, (double) value); } @Override public INDArray putScalar(int[] indexes, int value) { return putScalar(indexes, (double) value); } @Override public INDArray putScalar(long[] indexes, int value) { return putScalar(indexes, (double) value); } @Override public INDArray eps(Number other) { validateNumericalArray("eps", true); return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); } @Override public INDArray eps(INDArray other) { validateNumericalArray("eps", true); return Nd4j.getExecutioner().exec(new Eps(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); } @Override public INDArray lt(Number other) { validateNumericalArray("less than (lt)", false); return Nd4j.getExecutioner().exec(new ScalarLessThan(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); } @Override public INDArray lte(Number other) { validateNumericalArray("less than or equals (lte)", false); return Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); } @Override public INDArray eq(Number other) { Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar equality on boolean arrays can only be applied with values 0 or 1: got value %s",other); 
return Nd4j.getExecutioner().exec(new ScalarEquals(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
    }

    @Override
    public INDArray gt(Number other) {
        validateNumericalArray("greater than (gt)", false);
        return Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
    }

    @Override
    public INDArray gte(Number other) {
        validateNumericalArray("greater than or equals (gte)", false);
        return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
    }

    /**
     * Element-wise {@code this < other}.
     * Equal shapes use an output shaped like this array; broadcastable shapes use the broadcast output shape.
     *
     * @throws IllegalArgumentException if the shapes are not broadcastable
     */
    @Override
    public INDArray lt(INDArray other) {
        validateNumericalArray("less than (lt)", false);
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            return Nd4j.getExecutioner().exec(new LessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
        } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
        } else
            throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    @Override
    public INDArray neq(Number other) {
        Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar non-equality on boolean arrays can only be applied with values 0 or 1: got value %s",other);
        Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array");
        return Nd4j.getExecutioner().exec(new ScalarNotEquals(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other));
    }

    /**
     * Element-wise {@code this != other}, using the same dispatch as {@link #eq(INDArray)}.
     * Equal shapes use an output shaped like this array; broadcastable shapes use the broadcast output shape.
     * The previous implementation always sized the output like {@code this}, even when broadcasting.
     *
     * @throws IllegalArgumentException if the shapes are not broadcastable
     */
    @Override
    public INDArray neq(INDArray other) {
        Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array");
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
        } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return Nd4j.exec(new NotEqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
        } else
            throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    @Override
    public INDArray eq(INDArray other) {
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            return Nd4j.getExecutioner().exec(new EqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
        } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
        } else
            throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    @Override
    public INDArray gt(INDArray other) {
        validateNumericalArray("greater than (gt)", false);
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            return Nd4j.getExecutioner().exec(new GreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0];
        } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0];
        } else
            throw new IllegalArgumentException("Shapes must be broadcastable");
    }

    /** BOOL mask of infinite elements; an empty array yields an empty BOOL array. */
    @Override
    public INDArray isInfinite(){
        validateNumericalArray("isInfinite", true);
        if(isEmpty())
            return Nd4j.empty(DataType.BOOL);
        return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isInfinite()));
    }

    /** BOOL mask of NaN elements; an empty array yields an empty BOOL array. */
    @Override
    public INDArray isNaN(){
        validateNumericalArray("isNaN", true);
        if(isEmpty())
            return Nd4j.empty(DataType.BOOL);
        return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan()));
    }

    /** Negated copy; an empty array is returned as-is. */
    @Override
    public INDArray neg() {
        validateNumericalArray("negative (neg)", true);
        if(isEmpty())
            return this;
        return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())));
    }

    /** In-place negation; an empty array is returned as-is. */
    @Override
    public INDArray negi() {
        validateNumericalArray("negative (negi)", true);
        if(isEmpty())
            return this;
        Nd4j.getExecutioner().exec(new Negative(this));
        return this;
    }

    // Scalar ops writing into a caller-supplied result. A NaN operand is replaced by Nd4j.EPS_THRESHOLD
    // before the op runs. NOTE(review): the reason for that substitution is not visible here.

    @Override
    public INDArray rdiv(Number n, INDArray result) {
        return rdivi(n, result);
    }

    @Override
    public INDArray rdivi(Number n, INDArray result) {
        validateNumericalArray("rdivi", false);
        if (Double.isNaN(n.doubleValue()))
            n = Nd4j.EPS_THRESHOLD;
        Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, n));
        return result;
    }

    @Override
    public INDArray rsub(Number n, INDArray result) {
        return rsubi(n, result);
    }

    @Override
    public INDArray rsubi(Number n, INDArray result) {
        validateNumericalArray("rsubi", false);

        if (Double.isNaN(n.doubleValue()))
            n = Nd4j.EPS_THRESHOLD;

        Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, result, n));
        return result;
    }

    @Override
    public INDArray div(Number n, INDArray result) {
        return divi(n, result);
    }

    @Override
    public INDArray divi(Number n, INDArray result) {
        validateNumericalArray("divi", false);

        if (Double.isNaN(n.doubleValue()))
            n = Nd4j.EPS_THRESHOLD;
        Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, n));
        return result;
    }

    @Override
    public INDArray mul(Number n, INDArray result) {
        return muli(n, result);
    }

    @Override
    public INDArray muli(Number n, INDArray result) {
        validateNumericalArray("muli", false);
        if (Double.isNaN(n.doubleValue()))
            n = Nd4j.EPS_THRESHOLD;
        Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, n));
        return result;
    }

    @Override
    public INDArray sub(Number n, INDArray result) {
        return subi(n, result);
    }

    @Override
    public INDArray subi(Number n, INDArray result) {
        validateNumericalArray("subi", false);

        if (Double.isNaN(n.doubleValue()))
            n = Nd4j.EPS_THRESHOLD;

        Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, n));
        return result;
    }

    @Override
    public INDArray add(Number n, INDArray
result) { return addi(n, result); } @Override public INDArray addi(Number n, INDArray result) { validateNumericalArray("addi", false); if (Double.isNaN(n.doubleValue())) n = Nd4j.EPS_THRESHOLD; Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, n)); return result; } @Override public INDArray getScalar(long row, long column) { return getScalar(new long[] {row, column}); } @Override public INDArray dup() { return dup(Nd4j.order()); } @Override public INDArray dup(char order) { WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray"); if (this.isCompressed() && this.ordering() == order) { INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer()); ret.markAsCompressed(true); return ret; } if(isEmpty()) return this; Nd4j.getCompressor().autoDecompress(this); // fixme: eventually it would be nice to have this in native code if (isS()) { val list = new ArrayList(); for (int e = 0; e < this.length(); e++) list.add(this.getString(e)); return Nd4j.create(list, this.shape(), this.ordering()); } val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order); z.assign(this); return z; } @Override public int getInt(int... indices) { return (int) getDouble(indices); } @Override public long getLong(long index) { Nd4j.getCompressor().autoDecompress(this); Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); if (index >= length()) { throw new IllegalArgumentException("Unable to get linear index " + index + ": values is greater than length (" + length() + ")"); } autoProcessScalarCall(); if (index == 0) return data().getLong(index); long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, index) : Shape.ind2sub(this, index); Shape.assertShapeLessThan(dimensions, shape()); return getLong(dimensions); } @Override public long getLong(long... indices) { if(isScalar()) return data().getLong(0); return Shape.getLong(this, indices); } @Override public double getDouble(int... 
indices) { autoProcessScalarCall(); Nd4j.getCompressor().autoDecompress(this); Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); for (int i = 0; i < indices.length; i++) { if (indices[i] < 0) indices[i] += rank(); } if (indices.length == 1) { if (rank() == 1) return Shape.getDouble(this, indices[0]); else if (isRowVector()) return Shape.getDouble(this, 0, indices[0]); else if (isColumnVector()) return Shape.getDouble(this, indices[0], 0); else if ((isScalar() || length() == 1) && indices[0] == 0) return data().getDouble(0); } return Shape.getDouble(this, indices); } @Override public double getDouble(long... indices) { autoProcessScalarCall(); Nd4j.getCompressor().autoDecompress(this); Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); for (int i = 0; i < indices.length; i++) { if (indices[i] < 0) indices[i] += rank(); } if (indices.length == 1) { if (rank() == 1) return Shape.getDouble(this, indices[0]); else if (isRowVector()) return Shape.getDouble(this, 0, indices[0]); else if (isColumnVector()) return Shape.getDouble(this, indices[0], 0); else if (isScalar() && indices[0] == 0) return data().getDouble(0); else throw new IllegalStateException("Indexes length must be > 1 for non vectors and scalars"); } return Shape.getDouble(this, indices); } @Override public float getFloat(int... indices) { return (float) getDouble(indices); } @Override public float getFloat(long... 
indices) { return (float) getDouble(indices); } @Override public boolean isScalar() { if (isEmpty()) return false; if (jvmShapeInfo.rank == 0) { return true; } else if (jvmShapeInfo.rank > 2) { return false; } else if (jvmShapeInfo.rank == 1) { return shape()[0] == 1; } else if (jvmShapeInfo.rank == 2) { return shape()[0] == 1 && shape()[1] == 1 || length() == 1; } else return false; } @Override public INDArray put(int[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); if (!element.isScalar()) throw new IllegalArgumentException("Unable to insert anything but a scalar"); if (isRowVector() && indices[0] == 0 && indices.length == 2) { int ix = 0; for (int i = 1; i < indices.length; i++) ix += indices[i] * stride(i); if (ix >= data.length()) throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); data.put(ix, element.getDouble(0)); } else { int ix = 0; for (int i = 0; i < indices.length; i++) if (size(i) != 1) ix += indices[i] * stride(i); if (ix >= data.length()) throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); data.put(ix, element.getDouble(0)); } return this; } @Override public INDArray match(INDArray comp, Condition condition) { // TODO: obviously, we can make this broadcastable, eventually. 
But this will require new CustomOp based on MatchCondition Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal"); Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal"); return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition)); } @Override public INDArray match(Number comp, Condition condition) { return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp.doubleValue(), condition)); } @Override public INDArray getWhere(INDArray comp, Condition condition) { return BooleanIndexing.chooseFrom(new INDArray[]{this,comp},condition); } @Override public INDArray getWhere(Number comp, Condition condition) { return BooleanIndexing.chooseFrom(new INDArray[]{this},Arrays.asList(comp.doubleValue()),Collections.emptyList(),condition); } @Override public INDArray putWhere(INDArray comp, INDArray put, Condition condition) { Nd4j.getCompressor().autoDecompress(this); MatchConditionTransform matchCondition = new MatchConditionTransform(this,comp,condition); Nd4j.getExecutioner().exec(matchCondition); return putWhereWithMask(matchCondition.z(),put); } @Override public INDArray putWhere(Number comp, INDArray put, Condition condition) { return putWhere(Nd4j.scalar(comp),put,condition); } @Override public INDArray putWhere(Number comp, Number put, Condition condition) { return putWhere(Nd4j.scalar(comp),Nd4j.scalar(put),condition); } @Override public INDArray putWhereWithMask(INDArray mask, INDArray put) { INDArray output = dup(); Nd4j.getExecutioner().execAndReturn(new Where(new INDArray[]{mask,this,put},new INDArray[]{output})); return output; } @Override public INDArray putWhereWithMask(INDArray mask, Number put) { return putWhereWithMask(mask,Nd4j.scalar(put)); } @Override public INDArray put(int i, int j, INDArray element) { return put(new int[] {i, j}, element); } @Override public INDArray put(int i, int j, 
Number element) { return putScalar(new int[] {i, j}, element.doubleValue()); } @Override public INDArray putSlice(int slice, INDArray put) { Nd4j.getCompressor().autoDecompress(this); if (isScalar()) { Preconditions.checkState(put.isScalar(), "Invalid dimension. Can only insert a scalar in to another scalar"); put(0, put.getScalar(0)); return this; } else if (isVector()) { Preconditions.checkState(put.isVectorOrScalar() && put.length() == length(), "Invalid dimension on insertion. Can only insert scalars/vectors into other scalar/vectors"); if (put.isScalar()) putScalar(slice, put.getDouble(0)); else for (int i = 0; i < length(); i++) putScalar(i, put.getDouble(i)); return this; } assertSlice(put, slice); INDArray view = slice(slice); if (put.length() == 1) { putScalar(slice, put.getDouble(0)); } else { if(!(view.isVector() && put.isVector() && view.length() == put.length()) && !view.equalShapes(put)){ throw new IllegalStateException("Cannot put slice: array to be put (" + Arrays.toString(put.shape()) + ") and slice array (" + Arrays.toString(view.shape()) + ") have different shapes"); } view.assign(put); } return this; } protected void assertSlice(INDArray put, long slice) { Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices()); long[] sliceShape = put.shape(); if (Shape.isRowVectorShape(sliceShape)) { return; } else { long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); //no need to compare for scalar; primarily due to shapes either being [1] or length 0 if (put.isScalar()) return; if (isVector() && put.isVector() && put.length() < length()) return; //edge case for column vectors if (Shape.isColumnVectorShape(sliceShape)) return; if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) && !Shape.isRowVectorShape(sliceShape)) throw new IllegalStateException(String.format("Invalid shape size of %s . 
Should have been %s ", Arrays.toString(sliceShape), Arrays.toString(requiredShape))); } } public boolean isMatrix() { return rank() == 2; } protected INDArray newShape(long[] newShape, char ordering) { return Nd4j.create(data(), newShape, stride(), 0, ordering); } protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset, char ordering) { return Nd4j.create(data, newShape, newStrides, offset, ordering); } protected INDArray create(DataBuffer data, long[] newShape, long[] newStrides, long offset, char ordering) { return Nd4j.create(data, newShape, newStrides, offset, ordering); } protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset) { return Nd4j.create(data, newShape, newStrides, offset); } protected INDArray create(int[] shape) { return Nd4j.create(shape, getStrides(shape, Nd4j.order()), 0); } protected INDArray create(int[] shape, int[] strides, long offset) { return Nd4j.create(shape, strides, offset); } protected int[] getStrides(int[] shape, char ordering) { return Nd4j.getStrides(shape, ordering); } @Override public double squaredDistance(INDArray other) { validateNumericalArray("squaredDistance", false); double d2 = distance2(other); return d2 * d2; } @Override public double distance2(INDArray other) { validateNumericalArray("distance2", false); Nd4j.getCompressor().autoDecompress(this); return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue(); } @Override public double distance1(INDArray other) { validateNumericalArray("distance1", false); Nd4j.getCompressor().autoDecompress(this); return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue(); } @Override public INDArray get(INDArray indices) { if(indices.rank() > 2) { throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); } if (rank() == 1) { Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must 
be either scalar or vector as well"); val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); for (int e = 0; e < indices.length(); e++) { val idx = indices.getLong(e); val value = getDouble(idx); ret.putScalar(e, value); } return ret; } else if(indices.rows() == rank()) { INDArray ret = Nd4j.create(this.dataType(), indices.columns()); for(int i = 0; i < indices.columns(); i++) { int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); val v = getDouble(specifiedIndex); ret.putScalar(i, v); } return ret; } else { List arrList = new ArrayList<>(); if(indices.isMatrix() || indices.isColumnVector() || (indices.isScalar() && indices.rank() == 2)) { // we need this for compatibility with legacy code for(int i = 0; i < indices.rows(); i++) { if(i == 0) { INDArray row = indices.getRow(i); for(int j = 0; j < row.length(); j++) { arrList.add(slice(row.getInt(j))); } } else { INDArray row = indices.slice(i); for(int j = 0; j < row.length(); j++) { INDArray put = arrList.get(j).slice(row.getInt(j)); put = put.reshape(Longs.concat(new long[]{1},put.shape())); arrList.set(j,put); } } } } else if(indices.isRowVector()) { for(int i = 0; i < indices.length(); i++) { INDArray add = slice(indices.getInt(i)); add = add.reshape(Longs.concat(new long[] {1,},add.shape())); arrList.add(add); } } return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()])); } } @Override public INDArray put(INDArray indices, INDArray element) { if(indices.rank() > 2) { throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); } if(indices.rows() == rank()) { NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape()); for(int i = 0; i < indices.columns(); i++) { int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next())); } } else { List arrList = new ArrayList<>(); if(indices.isMatrix() || indices.isColumnVector()) { for(int i = 0; i < indices.rows(); i++) { INDArray 
row = indices.getRow(i); for(int j = 0; j < row.length(); j++) { INDArray slice = slice(row.getInt(j)); Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice})); arrList.add(slice(row.getInt(j))); } } } else if(indices.isRowVector()) { for(int i = 0; i < indices.length(); i++) { arrList.add(slice(indices.getInt(i))); } } } return this; } @Override public INDArray put(INDArrayIndex[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); boolean isSpecifiedIndex = false; for(INDArrayIndex idx : indices){ if(idx instanceof SpecifiedIndex){ isSpecifiedIndex = true; break; } } if(!isSpecifiedIndex){ return get(indices).assign(element); } else { //Can't get a view, so we'll do it in subsets instead // This is inefficient, but it is correct... int numSpecified = 0; List specifiedIdxs = new ArrayList<>(); List specifiedIdxDims = new ArrayList<>(); INDArrayIndex[] destinationIndices = indices.clone(); //Shallow clone INDArrayIndex[] sourceIndices = indices.clone(); for( int i=0; i can't use point(1) on [1,x,y] sourceIndices[i] = NDArrayIndex.point(0); } } int[] counts = new int[specifiedIdxs.size()]; int[] dims = new int[specifiedIdxDims.size()]; for( int i=0; i= this.length()) throw new ND4JIllegalStateException("Index can't be greater then array length"); if (i < 0) i += this.length(); long idx = this.isScalar() ? 
0 : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i)); val buffer = Nd4j.createBuffer( this.data(), this.data().originalOffset() + idx, 1); val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', dataType(), false); return Nd4j.createArrayFromShapeBuffer(buffer, shape); } /** * Do a row wise op (a,s,m,d) * a : add * s : subtract * m : multiply * d : divide * h : reverse subtraction * t : reverse division * * @param columnVector the column vector * @param operation the operation * @return */ protected INDArray doColumnWise(INDArray columnVector, char operation) { Nd4j.getCompressor().autoDecompress(this); if(columnVector.isScalar()) { switch (operation) { case 'a': addi(columnVector.getDouble(0)); break; case 'p': assign(columnVector.getDouble(0)); break; case 's': subi(columnVector.getDouble(0)); break; case 'm': muli(columnVector.getDouble(0)); break; case 'd': divi(columnVector.getDouble(0)); break; case 'h': rsubi(columnVector.getDouble(0)); break; case 't': rdivi(columnVector.getDouble(0)); break; } return this; } else if(isScalar()) { switch (operation) { case 'a': return columnVector.addi(getDouble(0)); case 'p': return columnVector.assign(getDouble(0)); case 's': return columnVector.subi(getDouble(0)); case 'm': return columnVector.muli(getDouble(0)); case 'd': return columnVector.divi(getDouble(0)); case 'h': return columnVector.rsubi(getDouble(0)); case 't': return columnVector.rdivi(getDouble(0)); } } //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0) //Or, simply require it to be a rank 1 vector if ((!columnVector.isColumnVector() && columnVector.rank() > 1) || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")"); } if 
(columnVector.data().sameUnderlyingData(data())) return doColumnWise(columnVector.dup(), operation); if (equalShapes(columnVector)) { switch (operation) { case 'a': addi(columnVector); break; case 'p': assign(columnVector); break; case 's': subi(columnVector); break; case 'm': muli(columnVector); break; case 'd': divi(columnVector); break; case 'h': rsubi(columnVector); break; case 't': rdivi(columnVector); break; } return this; } if (rows() == 1 && columnVector.isScalar()) { applyScalarOp(columnVector, operation); } else { // special optimization case, broadcast turns into ScalarOp Along Dimension if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' && columnVector.elementWiseStride() == 1) { switch (operation) { case 'a': { ScalarAdd op = new ScalarAdd(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } case 'p': { ScalarSet op = new ScalarSet(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } case 's': { ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } case 'm': { ScalarMultiplication op = new ScalarMultiplication(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } case 'd': { ScalarDivision op = new ScalarDivision(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } case 'h': { ScalarReverseSubtraction op = new ScalarReverseSubtraction(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } case 't': { ScalarReverseDivision op = new ScalarReverseDivision(this, columnVector, this, 0.0); op.setDimension(1); Nd4j.getExecutioner().exec(op); break; } } } else { applyBroadcastOp(columnVector, operation); } } return this; } /** * Do a row wise op (a,s,m,d) * a : add * s : subtract * m : multiply * d : divide * h : reverse subtraction * t : reverse division * * @param 
rowVector the row vector * @param operation the operation * @return */ protected INDArray doRowWise(INDArray rowVector, final char operation) { Nd4j.getCompressor().autoDecompress(this); if(rowVector.isScalar()) { switch (operation) { case 'a': addi(rowVector.getDouble(0)); break; case 'p': assign(rowVector.getDouble(0)); break; case 's': subi(rowVector.getDouble(0)); break; case 'm': muli(rowVector.getDouble(0)); break; case 'd': divi(rowVector.getDouble(0)); break; case 'h': rsubi(rowVector.getDouble(0)); break; case 't': rdivi(rowVector.getDouble(0)); break; } return this; } else if(isScalar()) { switch (operation) { case 'a': return rowVector.addi(getDouble(0)); case 'p': return rowVector.assign(getDouble(0)); case 's': return rowVector.subi(getDouble(0)); case 'm': return rowVector.muli(getDouble(0)); case 'd': return rowVector.divi(getDouble(0)); case 'h': return rowVector.rsubi(getDouble(0)); case 't': return rowVector.rdivi(getDouble(0)); } } //Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1) if (!rowVector.isRowVector() || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) || rowVector.length() <= 1) { throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")"); } if (rowVector.data().sameUnderlyingData(data())) return doRowWise(rowVector.dup(), operation); if (isVector()) { switch (operation) { case 'a': addi(rowVector); break; case 'p': assign(rowVector); break; case 's': subi(rowVector); break; case 'm': muli(rowVector); break; case 'd': divi(rowVector); break; case 'h': rsubi(rowVector); break; case 't': rdivi(rowVector); break; } return this; } if (rank() == 2 && columns() == 1 && rowVector.isScalar()) { applyScalarOp(rowVector, operation); } else { // special optimization case, broadcast turns into ScalarOp Along Dimension if (rank() == 2 && elementWiseStride() 
== 1 && ordering() == 'f' && rowVector.elementWiseStride() == 1) { switch (operation) { case 'a': { ScalarAdd op = new ScalarAdd(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } case 'p': { ScalarSet op = new ScalarSet(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } case 's': { ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } case 'm': { ScalarMultiplication op = new ScalarMultiplication(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } case 'd': { ScalarDivision op = new ScalarDivision(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } case 'h': { ScalarReverseSubtraction op = new ScalarReverseSubtraction(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } case 't': { ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, 0.0); op.setDimension(0); Nd4j.getExecutioner().exec(op); break; } } } else { applyBroadcastOp(rowVector, operation); } } return this; } private void applyBroadcastOp(INDArray vector, final char operation) { Nd4j.getCompressor().autoDecompress(this); int alongDimension = Shape.isRowVectorShape(vector.shape()) ? 
1 : 0; // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics if (this.data() == vector.data()) vector = vector.dup(); switch (operation) { case 'a': Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension)); return; case 's': Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension)); return; case 'm': Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension)); return; case 'd': Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension)); return; case 'h': Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension)); return; case 't': Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension)); return; case 'p': Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension)); return; default: throw new UnsupportedOperationException("Unknown operation: " + operation); } } private void applyScalarOp(INDArray vector, char operation) { Nd4j.getCompressor().autoDecompress(this); switch (operation) { case 'a': addi(vector.getDouble(0)); break; case 's': subi(vector.getDouble(0)); break; case 'm': muli(vector.getDouble(0)); break; case 'd': divi(vector.getDouble(0)); break; case 'h': rsubi(vector.getDouble(0)); break; case 't': rdivi(vector.getDouble(0)); break; } } protected DataBuffer shapeOf() { // if (shape == null) // shape = Shape.shapeOf(shapeInfoDataBuffer()); // return shape; return Shape.shapeOf(shapeInfoDataBuffer()); } protected DataBuffer strideOf() { // if (stride == null) // stride = Shape.stride(shapeInfoDataBuffer()); // return stride; return Shape.stride(shapeInfoDataBuffer()); } @Override public int stride(int dimension) { int rank = jvmShapeInfo.rank; Preconditions.checkArgument(dimension < rank, "Cannot get stride for dimension %s from rank %s array: " + "dimension indices must be in range -rank <= dimension < rank", dimension, rank); if 
(dimension < 0) return (int) stride()[dimension + rank]; return (int) stride()[dimension]; } @Override public INDArray rdiviColumnVector(INDArray columnVector) { validateNumericalArray("rdiviColumnVector", false); return doColumnWise(columnVector, 't'); } @Override public INDArray rdivColumnVector(INDArray columnVector) { validateNumericalArray("rdivColumnVector", false); return dup().rdiviColumnVector(columnVector); } @Override public INDArray rdiviRowVector(INDArray rowVector) { validateNumericalArray("rdiviRowVector", false); return doRowWise(rowVector, 't'); } @Override public INDArray rdivRowVector(INDArray rowVector) { validateNumericalArray("rdivRowVector", false); return dup().rdiviRowVector(rowVector); } @Override public INDArray rsubiColumnVector(INDArray columnVector) { validateNumericalArray("rsubiColumnVector", false); return doColumnWise(columnVector, 'h'); } @Override public INDArray rsubColumnVector(INDArray columnVector) { validateNumericalArray("rsubColumnVector", false); return dup().rsubiColumnVector(columnVector); } @Override public INDArray rsubiRowVector(INDArray rowVector) { validateNumericalArray("rsubiRowVector", false); return doRowWise(rowVector, 'h'); } @Override public INDArray rsubRowVector(INDArray rowVector) { validateNumericalArray("rsubRowVector", false); return dup().rsubiRowVector(rowVector); } @Override public INDArray put(int i, INDArray element) { Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element); return putScalar(i, element.getDouble(0)); } @Override public INDArray diviColumnVector(INDArray columnVector) { validateNumericalArray("diviColumnVector", false); return doColumnWise(columnVector, 'd'); } @Override public INDArray divColumnVector(INDArray columnVector) { validateNumericalArray("divColumnVector", false); return dup().diviColumnVector(columnVector); } @Override public INDArray diviRowVector(INDArray rowVector) { validateNumericalArray("diviRowVector", 
false); return doRowWise(rowVector, 'd'); } @Override public INDArray divRowVector(INDArray rowVector) { validateNumericalArray("divRowVector", false); return dup().diviRowVector(rowVector); } @Override public INDArray muliColumnVector(INDArray columnVector) { validateNumericalArray("muliColumnVector", false); return doColumnWise(columnVector, 'm'); } @Override public INDArray mulColumnVector(INDArray columnVector) { validateNumericalArray("mulColumnVector", false); return dup().muliColumnVector(columnVector); } @Override public INDArray muliRowVector(INDArray rowVector) { validateNumericalArray("muliRowVector", false); return doRowWise(rowVector, 'm'); } @Override public INDArray mulRowVector(INDArray rowVector) { validateNumericalArray("mulRowVector", false); return dup().muliRowVector(rowVector); } @Override public INDArray subiColumnVector(INDArray columnVector) { validateNumericalArray("subiColumnVector", false); return doColumnWise(columnVector, 's'); } @Override public INDArray subColumnVector(INDArray columnVector) { validateNumericalArray("subColumnVector", false); return dup().subiColumnVector(columnVector); } @Override public INDArray subiRowVector(INDArray rowVector) { validateNumericalArray("subiRowVector", false); return doRowWise(rowVector, 's'); } @Override public INDArray subRowVector(INDArray rowVector) { validateNumericalArray("subRowVector", false); return dup().subiRowVector(rowVector); } @Override public INDArray addiColumnVector(INDArray columnVector) { validateNumericalArray("addiColumnVector", false); return doColumnWise(columnVector, 'a'); } @Override public INDArray putiColumnVector(INDArray columnVector) { return doColumnWise(columnVector, 'p'); } @Override public INDArray addColumnVector(INDArray columnVector) { validateNumericalArray("addColumnVector", false); return dup().addiColumnVector(columnVector); } @Override public INDArray addiRowVector(INDArray rowVector) { validateNumericalArray("addiRowVector", false); return 
doRowWise(rowVector, 'a'); } @Override public INDArray putiRowVector(INDArray rowVector) { validateNumericalArray("putiRowVector", false); return doRowWise(rowVector, 'p'); } @Override public INDArray addRowVector(INDArray rowVector) { validateNumericalArray("addRowVector", false); return dup().addiRowVector(rowVector); } @Override public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { return mMulTranspose.exec(this, other, result); } @Override public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { return mMulTranspose.exec(this, other, null); } @Override public INDArray mmul(INDArray other, char resultOrder) { Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); // FIXME: add support for 3D+ here? long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; INDArray result = createUninitialized(this.dataType(), shape, resultOrder); if (result.isScalar()) return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); return mmuli(other, result); } @Override public INDArray mmul(INDArray other) { return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c'); } protected INDArray create(int[] shape, char ordering) { return Nd4j.create(shape, ordering); } @Override public double[][] toDoubleMatrix() { if(!isMatrix()) { throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort(this)); } if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); double[][] ret = new double[rows()][columns()]; for(int i = 0; i < ret.length; i++) { ret[i] = getRow(i).dup().data().asDouble(); } return ret; } @Override public double[] toDoubleVector() { if(!isVectorOrScalar()) { throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); } return dup().data().asDouble(); } @Override public float[] toFloatVector() { if(!isVectorOrScalar()) { throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); } return dup().data().asFloat(); } @Override public float[][] toFloatMatrix() { if(!isMatrix()) { throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); } if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); float[][] ret = new float[(int) rows()][ (int) columns()]; for(int i = 0; i < ret.length; i++) { ret[i] = getRow(i).dup().data().asFloat(); } return ret; } @Override public int[] toIntVector() { if (isEmpty()) return new int[0]; if(!isVectorOrScalar()) { throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); } if(isView() || elementWiseStride() != 1){ return dup().data().asInt(); } return data().asInt(); } @Override public long[] toLongVector() { if(!isVectorOrScalar()) { throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); } if(isView() || elementWiseStride() != 1){ return dup().data().asLong(); } return data().asLong(); } @Override public long[][] toLongMatrix() { if(!isMatrix()) { throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort(this)); } if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); long[][] ret = new long[(int) rows()][(int) columns()]; for(int i = 0; i < ret.length; i++) { ret[i] = getRow(i).dup().data().asLong(); } return ret; } @Override public int[][] toIntMatrix() { if(!isMatrix()) { throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); } if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); int[][] ret = new int[(int) rows()][(int) columns()]; for(int i = 0; i < ret.length; i++) { ret[i] = getRow(i).dup().data().asInt(); } return ret; } /** * Perform an copy matrix multiplication * * @param other the other matrix to perform matrix multiply with * @param result the result ndarray * @return the result of the matrix multiplication */ @Override public INDArray mmul(INDArray other, INDArray result) { return mmuli(other, result); } @Override public INDArray div(INDArray other) { if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); } else { return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); } } @Override public INDArray div(INDArray other, INDArray result) { validateNumericalArray("div", true); return divi(other, result); } @Override public INDArray mul(INDArray other) { validateNumericalArray("mul", false); if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { return muli(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); } else { val z = 
Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering()); return muli(other, z); } } @Override public INDArray mul(INDArray other, INDArray result) { return muli(other, result); } @Override public INDArray sub(INDArray other) { validateNumericalArray("sub", false); if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { return subi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); } else { return subi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); } } @Override public INDArray sub(INDArray other, INDArray result) { return subi(other, result); } @Override public INDArray add(INDArray other) { validateNumericalArray("add", false); if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { return addi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); } else { return addi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); } } @Override public INDArray add(INDArray other, INDArray result) { validateNumericalArray("add", false); return addi(other, result); } @Override public INDArray mmuli(INDArray other, MMulTranspose transpose) { validateNumericalArray("mmuli", false); return dup().mmuli(other, this,transpose); } @Override public INDArray mmuli(INDArray other) { validateNumericalArray("mmuli", false); return dup().mmuli(other, this); } @Override public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { return transpose.exec(this, other, result); } @Override public INDArray mmuli(INDArray other, INDArray result) { validateNumericalArray("mmuli", false); 
LinAlgExceptions.assertMultiplies(this, other); if(other.rank() == 1){ //GEMV edge case Preconditions.checkState(result.length() == this.size(0) && this.size(1) == other.size(0), "Invalid matrix multiplication: %ndShape x %ndShape with result shape %ndShape", this, other, result); } else { //Standard case Preconditions.checkState( result.rank() == 2 && result.size(0) == this.size(0) && result.size(1) == other.size(1), "Invalid result array shape: expected shape [%s,%s], got shape %ndShape result array for %ndShape x %ndShape", this.size(0), other.size(1), result, this, other); } if (other.isScalar()) { return muli(other.getDouble(0), result); } if (isScalar()) { return other.muli(getDouble(0), result); } /* check sizes and resize if necessary */ if (result == this || result == other) { /* actually, blas cannot do multiplications in-place. Therefore, we will fake by * allocating a temporary object on the side and copy the result later. */ INDArray temp = Nd4j.create(result.dataType(), result.shape(), Nd4j.getStrides(result.shape(), 'f'), 'f'); if (other.columns() == 1 || other.rank() == 1) { Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result), BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp); } else { Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, this, other, 0.0, temp); } result.assign(temp); } else { //We require that the result array is 'f' (fortran) order // However, user might have called mmuli with a c order array for the result // In which case, we need to allocate a temporary f order array, and later do an assign to the real result array boolean requiresTemp = result.ordering() != 'f' || result.isView() || !Shape.hasDefaultStridesForShape(result); INDArray gemmResultArr; if (requiresTemp) { //Can use createUninitialized due to beta==0.0 parameter in gemm gemmResultArr = 
Nd4j.createUninitialized(result.dataType(), result.shape(), 'f'); } else { gemmResultArr = result; } if (other.columns() == 1 || other.rank() == 1) { Nd4j.getBlasWrapper().level2().gemv( ordering(), BlasBufferUtil.getCharForTranspose(other), 1.0, this, other, 0.0, gemmResultArr); } else { //gemm doesn't support strides so vectors and views //don't work Nd4j.getBlasWrapper().level3().gemm(ordering(), BlasBufferUtil.getCharForTranspose(other), BlasBufferUtil.getCharForTranspose(gemmResultArr), 1.0, this, other, 0.0, gemmResultArr); } if (requiresTemp) { result.assign(gemmResultArr); } } // 1D edge case: reshape back to vector if (other.rank() == 1) result = result.reshape(result.length()); return result; } private INDArray create(int[] shape, int[] stride) { return Nd4j.create(shape, stride); } @Override public INDArray divi(INDArray other) { return divi(other, this); } @Override public INDArray divi(INDArray other, INDArray result) { validateNumericalArray("divi", false); Shape.assertBroadcastable("divi", this, other, result); Nd4j.exec(new DivOp(this, other, result)); return result; } @Override public INDArray muli(INDArray other) { return muli(other, this); } @Override public INDArray muli(INDArray other, INDArray result) { validateNumericalArray("muli", false); Shape.assertBroadcastable("muli", this, other, result); Nd4j.exec(new MulOp(this, other, result)); return result; } @Override public INDArray subi(INDArray other) { return subi(other, this); } /** * in place subtraction of two matrices * * @param other the second ndarray to subtract * @param result the result ndarray * @return the result of the subtraction */ @Override public INDArray subi(INDArray other, INDArray result) { validateNumericalArray("subi", false); Shape.assertBroadcastable("subi", this, other, result); Nd4j.exec(new SubOp(this, other, result)); return result; } @Override public INDArray addi(INDArray other) { return addi(other, this); } @Override public INDArray addi(INDArray other, INDArray 
result) { validateNumericalArray("addi", false); Shape.assertBroadcastable("addi", this, other, result); Nd4j.exec(new AddOp(this, other, result)); return result; } @Override public INDArray normmax(boolean keepDims, int... dimension) { validateNumericalArray("normmax", false); return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); } @Override public INDArray normmax(int... dimension) { return normmax(false, dimension); } @Override public INDArray rdiv(INDArray other) { validateNumericalArray("rdiv", false); if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { return rdivi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); } else { return rdivi(other, this.ulike()); } } @Override public INDArray rdivi(INDArray other) { return rdivi(other, this); } @Override public INDArray rdiv(INDArray other, INDArray result) { validateNumericalArray("rdiv", false); return dup().rdivi(other, result); } @Override public INDArray rdivi(INDArray other, INDArray result) { validateNumericalArray("rdivi", false); Shape.assertBroadcastable("rdivi", this, other, result); Nd4j.exec(new RDivOp(this, other, result)); return result; } @Override public INDArray rsub(INDArray other, INDArray result) { validateNumericalArray("rsub", false); return rsubi(other, result); } @Override public INDArray rsub(INDArray other) { validateNumericalArray("rsub", false); if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { return rsubi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); } else { return rsubi(other, this.ulike()); } } @Override public INDArray rsubi(INDArray other) { return rsubi(other, this); } @Override public INDArray rsubi(INDArray other, INDArray result) { validateNumericalArray("rsubi", false); 
Shape.assertBroadcastable("rsubi", this, other, result); Nd4j.exec(new RSubOp(this, other, result)); return result; } @Override public INDArray assign(Number value) { Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " + "assign on boolean arrays: got value %s on to assign to boolean array with shape %ndShape", value, this); Nd4j.getExecutioner().exec(new ScalarSet(this, value)); return this; } @Override public INDArray assign(boolean value) { return assign(value ? 1 : 0); } @Override public INDArray assignIf(INDArray arr, Condition condition) { BooleanIndexing.assignIf(this, arr, condition); return this; } @Override public INDArray replaceWhere(INDArray arr, Condition condition) { Nd4j.getCompressor().autoDecompress(this); BooleanIndexing.replaceWhere(this, arr, condition); return this; } @Override @Deprecated //TODO: Investigate. Not deprecated in the base interface. public long linearIndex(long i) { long idx = i; for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { if (size((int) i) == 1) continue; idx += i * stride(j); } return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx); } @Override public INDArray slice(long slice) { Nd4j.getCompressor().autoDecompress(this); long slices = slices(); if (slice >= slices) throw new IllegalArgumentException("Illegal slice " + slice); if (jvmShapeInfo.rank == 0 ) { throw new IllegalArgumentException("Can't slice a 0-d NDArray"); } if (slice < 0) slice += rank(); INDArrayIndex[] indexes = new INDArrayIndex[rank()]; indexes[0] = NDArrayIndex.point(slice); for (int i = 1; i < rank(); i++) { indexes[i] = NDArrayIndex.all(); } return get(indexes); } protected INDArray createScalarForIndex(long i, boolean applyOffset) { if(isVector()) return getScalar(i); return Nd4j.create(data(), new long[] {1, 1}, new long[] {1, 1}, i); } protected INDArray createScalar(double d) { return Nd4j.scalar(d); } @Override public int 
getTrailingOnes() { int numLeadingOnes = 0; for (int i = rank() - 1; i > 0; i--) { if (size(i) == 1) numLeadingOnes++; } return numLeadingOnes; } @Override public int getLeadingOnes() { int numLeadingOnes = 0; for (int i = 0; i < rank(); i++) { if (size(i) == 1) numLeadingOnes++; } return numLeadingOnes; } @Override public INDArray slice(long slice, int dimension) { Nd4j.getCompressor().autoDecompress(this); long slices = size(dimension); if (slice >= slices) throw new IllegalArgumentException("Illegal slice " + slice); if (jvmShapeInfo.rank == 0) { if (slice == 0) return createScalarForIndex(slice, true); else throw new IllegalArgumentException("Can't slice a 0-d NDArray"); } if (slice < 0) slice += rank(); INDArrayIndex[] indexes = new INDArrayIndex[rank()]; indexes[dimension] = NDArrayIndex.point(slice); for (int i = 0; i < rank(); i++) { if (i != dimension) indexes[i] = NDArrayIndex.all(); } return get(indexes); } @Override public INDArray getScalar(int[] indexes) { if (indexes.length > rank()) throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); for (int i = 0; i < indexes.length; i++) { if (indexes[i] < 0) indexes[i] += this.size(i); } long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); val buffer = Nd4j.createBuffer(this.data(), idx, 1); val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1, 'c', this.dataType(), false); return Nd4j.createArrayFromShapeBuffer(buffer, shape); } @Override public INDArray getScalar(long... 
indexes) { if (indexes.length > rank()) throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); for (int i = 0; i < indexes.length; i++) { if (indexes[i] < 0) indexes[i] += this.size(i); } long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); val buffer = Nd4j.createBuffer(this.data(), idx, 1); val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', this.dataType(), false); return Nd4j.createArrayFromShapeBuffer(buffer, shape); } @Override public INDArray rdiv(Number n) { //return dup().rdivi(n); return rdivi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering())); } @Override public INDArray rdivi(Number n) { return rdivi(n, this); } @Override public INDArray rsub(Number n) { validateNumericalArray("rsub", false); return rsubi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); } @Override public INDArray rsubi(Number n) { return rsubi(n, this); } @Override public INDArray div(Number n) { validateNumericalArray("div", false); return divi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); } @Override public INDArray divi(Number n) { return divi(n, this); } @Override public INDArray mul(Number n) { validateNumericalArray("mul", false); return muli(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering())); } @Override public INDArray muli(Number n) { return muli(n, this); } @Override public INDArray sub(Number n) { validateNumericalArray("sub", false); return subi(n, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())); } @Override public INDArray subi(Number n) { return subi(n, this); } @Override public INDArray add(Number n) { validateNumericalArray("add", false); return addi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), 
n),this.shape(), this.ordering())); } @Override public INDArray addi(Number n) { return addi(n, this); } @Override public INDArray repmat(long[] shape) { Nd4j.getCompressor().autoDecompress(this); long rows = rows() * shape[0]; long cols = columns() * shape[1]; INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()).repeat(0, shape[1]); return ret.reshape(rows, cols); } @Deprecated @Override public INDArray repmat(int[] shape) { long[] longShape = ArrayUtil.toLongArray(shape); return repmat(longShape); } @Override public INDArray repeat(int dimension, long... repeats) { Nd4j.getCompressor().autoDecompress(this); CustomOp op = DynamicCustomOp.builder("repeat") .addInputs(this) .addIntegerArguments(ArrayUtil.toInts(repeats)) //TODO int cast .build(); op.addIArgument(dimension); //Native op: last iarg is dimension LongShapeDescriptor l = op.calculateOutputShape().get(0); INDArray out = Nd4j.create(l); op.addOutputArgument(out); Nd4j.exec(op); return out; } @Override public INDArray putRow(long row, INDArray toPut) { if (isRowVector() && toPut.isVector()) { return assign(toPut); } return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); } @Override public INDArray putColumn(int column, INDArray toPut) { Nd4j.getCompressor().autoDecompress(this); if (isColumnVector() && toPut.isVector()) { return assign(toPut); } return put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut); } @Override public Number getNumber(long i){ switch (dataType()){ case DOUBLE: case FLOAT: case HALF: case BFLOAT16: return getDouble(i); case LONG: case INT: case SHORT: case UBYTE: case BYTE: case BOOL: case UINT64: case UINT32: case UINT16: return getLong(i); case UTF8: case COMPRESSED: case UNKNOWN: default: throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType()); } } @Override public Number getNumber(long... 
idx){ switch (dataType()){ case DOUBLE: case FLOAT: case HALF: return getDouble(idx); case LONG: case INT: case SHORT: case UBYTE: case BYTE: case BOOL: return getLong(idx); case UTF8: case COMPRESSED: case UNKNOWN: default: throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType()); } } @Override public double getDouble(long i) { Nd4j.getCompressor().autoDecompress(this); Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); if (i >= length()) { throw new IllegalArgumentException("Unable to get linear index " + i + ": values is greater than length (" + length() + ")"); } autoProcessScalarCall(); if (i == 0) return data().getDouble(i); long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i); Shape.assertShapeLessThan(dimensions, shape()); return getDouble(dimensions); } @Override public double getDouble(long i, long j) { return getDouble(new long[] {i, j}); } @Override public float getFloat(long i) { return (float) getDouble(i); } @Override public float getFloat(long i, long j) { return (float) getDouble(i, j); } @Override public INDArray transpose() { Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this); return permute(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); } /** * * Return transposed version of this matrix. * * PLEASE NOTE: This method is NOT in place, it will return transposed copy instead. */ @Override public INDArray transposei() { Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this); return permutei(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); } protected INDArray create(DataBuffer data, int[] shape, int[] strides) { return Nd4j.create(data, shape, strides, 0, ordering()); } @Deprecated @Override public INDArray reshape(char order, int... 
newShape) { return reshape(order, ArrayUtil.toLongArray(newShape)); } @Override public INDArray reshape(char order, long... newShape) { return reshape(order, false, newShape); } @Override public INDArray reshape(char order, boolean enforceView, long... newShape){ Nd4j.getCompressor().autoDecompress(this); // special case for empty reshape if (this.length() == 1 && (newShape == null || newShape.length == 0) && this.elementWiseStride() == 1) { return Nd4j.create(this.data(), new int[0], new int[0], 0); } if (newShape == null || newShape.length < 1) throw new ND4JIllegalStateException( "Can't reshape(long...) without shape arguments. Got empty shape instead."); // TODO: maybe toFlatten() makes more sense here? // reshape(-1) special case if (newShape.length == 1 && newShape[0] == -1) newShape[0] = this.length(); int numberNegativesOnes = 0; long[] shape = ArrayUtil.copy(newShape); for (int i = 0; i < shape.length; i++) { if (shape[i] < 0) { if (numberNegativesOnes >= 1) throw new IllegalArgumentException("Only one dimension can be negative ones. Got shape " + Arrays.toString(newShape)); numberNegativesOnes++; int shapeLength = 1; for (int j = 0; j < shape.length; j++) if (shape[j] >= 1) shapeLength *= shape[j]; long realShape = Math.abs(length() / shapeLength); long[] thisNewShape = new long[shape.length]; for (int j = 0; j < shape.length; j++) { if (i != j) { thisNewShape[j] = shape[j]; } else thisNewShape[j] = realShape; } shape = thisNewShape; break; } } long prod = ArrayUtil.prodLong(shape); if (prod != this.length()){ throw new ND4JIllegalStateException("New shape length doesn't match original length: [" + prod + "] vs [" + this.length() + "]. 
Original shape: "+Arrays.toString(this.shape())+" New Shape: "+Arrays.toString(newShape)); } INDArray reshapeAttempt = Shape.newShapeNoCopy(this, shape, order == 'f'); if (reshapeAttempt != null) { // kinda strange get/set usage // reshapeAttempt.setOrder(Shape.getOrder(reshapeAttempt)); return reshapeAttempt; } if(enforceView){ throw new ND4JIllegalStateException("Unable to reshape array as view, called with enforceView=true. " + "Use enforceView=false to return a copy instead, or call reshape on a non-strided array. Array shape info: " + this.shapeInfoToString().replaceAll("\n","")); } if (order != ordering()) { INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order); ret.setData(dup(order).data()); return ret; } else if (this.isEmpty()) { return Nd4j.create(this.dataType(), shape); } else { INDArray ret = this.dup(order); return Nd4j.create(ret.data(), shape); } } @Override public double getDoubleUnsafe(long offset) { return data().getDouble(offset); } @Override public INDArray putScalarUnsafe(long offset, double value) { autoProcessScalarCall(); data().put(offset, value); return this; } @Override public INDArray reshape(char order, int rows, int columns) { return reshape(order, new long[] {rows, columns}); } /** * Reshape the ndarray in to the specified dimensions, * possible errors being thrown for invalid shapes * * Note here that one dimension can be -1. * The dimension that is -1 will be inferred from the shape and * the length of the ndarray * * @param shape the shape of the ndarray. * @return the new reshaped nd array */ @Override public INDArray reshape(int[] shape) { return reshape(Nd4j.order(), shape); } @Override public INDArray reshape(long... shape) { return reshape(Nd4j.order(), shape); } @Override public INDArray prod(boolean keepDims, int... dimension) { validateNumericalArray("prod", false); return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); } @Override public INDArray prod(int... 
dimension) {
        final boolean keepDims = false;
        return prod(keepDims, dimension);
    }

    // ---------------------------------------------------------------------------------------
    // Reductions. Each validates the array, builds the corresponding reduce op over
    // `dimension`, and returns the executioner's output.
    // ---------------------------------------------------------------------------------------

    /** Mean over {@code dimension}; {@code keepDims} is forwarded to the Mean op. */
    @Override
    public INDArray mean(boolean keepDims, int... dimension) {
        validateNumericalArray("mean", false);
        final Mean reduction = new Mean(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Mean over {@code dimension} without keepDims. */
    @Override
    public INDArray mean(int... dimension) {
        final boolean keepDims = false;
        return mean(keepDims, dimension);
    }

    /** Mean of absolute values over {@code dimension} (AMean op). */
    @Override
    public INDArray amean(int... dimension) {
        validateNumericalArray("amean", false);
        final AMean reduction = new AMean(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Mean over {@code dimension} written into {@code result}; {@code keepDims} forwarded to the op. */
    @Override
    public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) {
        validateNumericalArray("mean", false);
        final Mean reduction = new Mean(this, result, keepDims, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Mean over {@code dimension} written into {@code result}, without keepDims. */
    @Override
    public INDArray mean(@NonNull INDArray result, int... dimension) {
        final boolean keepDims = false;
        return mean(result, keepDims, dimension);
    }

    /** Variance over {@code dimension} using the Variance op's default bias correction. */
    @Override
    public INDArray var(int... dimension) {
        validateNumericalArray("var", false);
        final Variance reduction = new Variance(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Variance over {@code dimension} with explicit bias correction flag. */
    @Override
    public INDArray var(boolean biasCorrected, int... dimension) {
        validateNumericalArray("var", false);
        final Variance reduction = new Variance(this, biasCorrected, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Maximum over {@code dimension}; {@code keepDims} is forwarded to the Max op. */
    @Override
    public INDArray max(boolean keepDims, int... dimension) {
        validateNumericalArray("max", false);
        final Max reduction = new Max(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Maximum over {@code dimension} without keepDims. */
    @Override
    public INDArray max(int... dimension) {
        final boolean keepDims = false;
        return max(keepDims, dimension);
    }

    /** Maximum absolute value over {@code dimension} (AMax op). */
    @Override
    public INDArray amax(int... dimension) {
        validateNumericalArray("amax", false);
        final AMax reduction = new AMax(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Minimum over {@code dimension}; {@code keepDims} is forwarded to the Min op. */
    @Override
    public INDArray min(boolean keepDims, int... dimension) {
        validateNumericalArray("min", false);
        final Min reduction = new Min(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Minimum over {@code dimension} without keepDims. */
    @Override
    public INDArray min(int... dimension) {
        final boolean keepDims = false;
        return min(keepDims, dimension);
    }

    /** Minimum absolute value over {@code dimension} (AMin op). */
    @Override
    public INDArray amin(int... dimension) {
        validateNumericalArray("amin", false);
        final AMin reduction = new AMin(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Sum over {@code dimension}; validation is called with {@code true} for sum. */
    @Override
    public INDArray sum(int... dimension) {
        validateNumericalArray("sum", true);
        final Sum reduction = new Sum(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Sum over {@code dimension}, output allocated by the op; {@code keepDim} forwarded. */
    @Override
    public INDArray sum(boolean keepDim, int... dimension) {
        validateNumericalArray("sum", true);
        final Sum reduction = new Sum(this, null, keepDim, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Entropy over {@code dimension} (Entropy op). */
    @Override
    public INDArray entropy(int... dimension) {
        validateNumericalArray("entropy", false);
        final Entropy reduction = new Entropy(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Shannon entropy over {@code dimension} (ShannonEntropy op). */
    @Override
    public INDArray shannonEntropy(int... dimension) {
        validateNumericalArray("shannonEntropy", false);
        final ShannonEntropy reduction = new ShannonEntropy(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Log entropy over {@code dimension} (LogEntropy op). */
    @Override
    public INDArray logEntropy(int... dimension) {
        validateNumericalArray("logEntropy", false);
        final LogEntropy reduction = new LogEntropy(this, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Sum over {@code dimension} written into {@code result}; {@code keepDims} forwarded. */
    @Override
    public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) {
        validateNumericalArray("sum", true);
        final Sum reduction = new Sum(this, result, keepDims, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Sum over {@code dimension} written into {@code result}, without keepDims. */
    @Override
    public INDArray sum(@NonNull INDArray result, int... dimension) {
        final boolean keepDims = false;
        return sum(result, keepDims, dimension);
    }

    /** L1 norm over {@code dimension} without keepDims. */
    @Override
    public INDArray norm1(int... dimension) {
        final boolean keepDims = false;
        return norm1(keepDims, dimension);
    }

    /** L1 norm over {@code dimension}; {@code keepDims} forwarded to the Norm1 op. */
    @Override
    public INDArray norm1(boolean keepDims, int... dimension) {
        validateNumericalArray("norm1", false);
        final Norm1 reduction = new Norm1(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(reduction);
    }

    /** Bias-corrected standard deviation over {@code dimension}. */
    @Override
    public INDArray std(int... dimension) {
        final boolean biasCorrected = true;
        return std(biasCorrected, dimension);
    }

    /** Standard deviation over {@code dimension} without keepDims. */
    @Override
    public INDArray std(boolean biasCorrected, int... dimension) {
        final boolean keepDims = false;
        return std(biasCorrected, keepDims, dimension);
    }

    /** Standard deviation over {@code dimension}; both flags forwarded to the op. */
    @Override
    public INDArray std(boolean biasCorrected, boolean keepDims, int...
dimension) { validateNumericalArray("std", false); return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, keepDims, dimension)); } @Override public Number stdNumber(boolean biasCorrected) { validateNumericalArray("stdNumber", false); return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); } @Override public INDArray norm2(boolean keepDims, int... dimension) { validateNumericalArray("norm2", false); return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); } @Override public INDArray norm2(int... dimension) { return norm2(false, dimension); } @Override public int columns() { if (isMatrix()) return (int) size(1); else if (Shape.isColumnVectorShape(shape())) { return 1; } else if (Shape.isRowVectorShape(shape())) { return (int) length(); } throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid"); } @Override public int rows() { if (isMatrix()) return (int) size(0); else if (Shape.isRowVectorShape(shape())) { return 1; } else if (Shape.isColumnVectorShape(shape())) { return (int) length(); } throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid"); } @Override public INDArray ravel(char ordering) { Nd4j.getCompressor().autoDecompress(this); if(ordering == this.ordering() && Shape.hasDefaultStridesForShape(this)){ return reshape(ordering, length()); } return dup(ordering).reshape(ordering, length()); } @Override public INDArray ravel() { return reshape(length()); } @Override public void sliceVectors(List list) { if (isVector()) list.add(this); else { for (int i = 0; i < slices(); i++) { slice(i).sliceVectors(list); } } } @Override public INDArray reshape(long newRows, long newColumns) { return reshape(new long[] {newRows, newColumns}); } @Override public INDArray getColumn(long c) { Nd4j.getCompressor().autoDecompress(this); if (isColumnVector() && c == 0) return this; else if (isColumnVector() && c > 0) throw new 
IllegalArgumentException("Illegal index for column"); Preconditions.checkArgument(this.rank() == 2, "getColumn() can be called on 2D arrays only"); return tensorAlongDimension(c, 0); } @Override public INDArray getColumn(long c, boolean keepDim) { INDArray col = getColumn(c); if(!keepDim) return col; return col.reshape(col.length(), 1); } @Override public INDArray getRows(int[] rindices) { Nd4j.getCompressor().autoDecompress(this); if (!isMatrix() && !isVector()) throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); if (isVector()) return Nd4j.pullRows(this, 1, rindices); else { INDArray ret = Nd4j.createUninitialized(this.dataType(), new long[] {rindices.length, columns()}); for (int i = 0; i < rindices.length; i++) ret.putRow(i, getRow(rindices[i])); return ret; } } @Override public INDArray get(INDArrayIndex... indexes) { Nd4j.getCompressor().autoDecompress(this); int numPoint = 0; int numInterval = 0; int numAll = 0; int numNewAxis = 0; int numSpecified = 0; for(INDArrayIndex i : indexes){ if(i instanceof PointIndex){ numPoint++; } else if(i instanceof NDArrayIndexAll){ numAll++; } else if(i instanceof IntervalIndex){ numInterval++; } else if(i instanceof NewAxis){ numNewAxis++; } else if(i instanceof SpecifiedIndex){ numSpecified++; } else { throw new IllegalStateException("Unknown index: " + i); } } // Padding remaining dimensions with all() index if too few indices provided if (indexes.length - numNewAxis < this.rank()) { val newIndexes = new INDArrayIndex[this.rank() + numNewAxis]; for (int e = 0; e < indexes.length; e++) newIndexes[e] = indexes[e]; for (int e = indexes.length; e < newIndexes.length; e++) { numAll++; newIndexes[e] = NDArrayIndex.all(); } indexes = newIndexes; } Preconditions.checkState((numPoint + numInterval + numAll + numSpecified) == rank(), "Illegal set of indices for array: need at least" + " %s point/interval/all/specified indices for rank %s array (%ndShape), got indices %s", rank(), rank(), this, 
indexes); int outRank = rank() + numNewAxis - numPoint; Preconditions.checkState(outRank >= 0, "Illegal set of indices for array: %ndShape, %s", this, indexes); //To work out sub-array, we need to work out 3 things: offset, shape and strides. We calculate all of these long[] outShape = new long[outRank]; long[] outStrides = new long[outRank]; long offset = offset(); //Start with existing offset if view int outIdx = 0; //Axis number counter for output array int inIdx = 0; //Axis number counter for input array for( int i=0; i= size(inIdx)) { throw new IllegalStateException("Indices are out of range: Cannot get interval index " + indexes[i] + " on array with size(" + inIdx + ")=" + size(inIdx) + ". Array shape: " + Arrays.toString(shape()) + ", indices: " + Arrays.toString(indexes)); } long stride = ii.stride(); long length = (endInc - start)/stride + 1; offset += ii.offset() * stride(inIdx); outShape[outIdx] = length; outStrides[outIdx] = ii.stride() * stride(inIdx); inIdx++; outIdx++; } else if(indexes[i] instanceof NewAxis) { //New axis: appends a 1 in shape. Axis not present in input, but is present in output outShape[outIdx] = 1; if (outIdx > 0) { //Stride doesn't matter for 1 size axis anyway... outStrides[outIdx] = outStrides[outIdx - 1]; } else { outStrides[outIdx] = 1; } outIdx++; } else if(indexes[i] instanceof SpecifiedIndex){ //Specified index: axis present in both input and output SpecifiedIndex si = (SpecifiedIndex)indexes[i]; outShape[outIdx++] = si.length(); inIdx++; //Don't care about strides for specified index, as result won't be a view } else { throw new IllegalStateException("Unknown index type: " + i); //Should never happen } } //Note: If we have specified indices, we can't return a view. Instead, we copy the specified sub-arrays from // the input array to the output array. //How? 
Create the output array, then do loop over the specified indices only, and copy sub-arrays for all other axes if (numSpecified > 0) { INDArray out = Nd4j.create(dataType(), outShape); //Need to copy subsets here long[] specifiedSizes = new long[numSpecified]; SpecifiedIndex[] si = new SpecifiedIndex[numSpecified]; int j=0; for( int i=0; i replace with loop + point // ii. new axis indices -> ignore/exclude (don't appear in input) // iii. interval indices -> replace with all //(2) Get from output: requested indices, except for: // i. point indices -> ignore/exclude (don't appear in output) // ii. new axis indices -> replace with point(0) INDArrayIndex[] pointIdxsIn = new INDArrayIndex[indexes.length - numNewAxis]; //Indices for source (this array) int[] specifiedAxisIn = new int[numSpecified]; int specCount = 0; j = 0; for( int i=0; i 0) throw new IllegalArgumentException("Illegal index for row: requested row " + r + " but this.size(0)=" + this.size(0)); Preconditions.checkArgument(rank() == 2, "getRow() can be called on 2D arrays only"); Preconditions.checkArgument(r < rows(), "Row index must be smaller than total number of rows"); return tensorAlongDimension(r, 1); } @Override public INDArray getRow(long r, boolean keepDim) { INDArray row = getRow(r); if(!keepDim) return row; return row.reshape(1, row.length()); } public boolean equalsWithEps(Object o, double eps) { Nd4j.getCompressor().autoDecompress(this); if (o == null) return false; if (!(o instanceof INDArray)) return false; INDArray n = (INDArray) o; Nd4j.getCompressor().autoDecompress(n); if (n == this) return true; if (this.rank() != n.rank()) return false; if (this.length() != n.length()) return false; if (this.isEmpty() != n.isEmpty()) return false; if (this.isEmpty() && n.isEmpty()) return Shape.shapeEquals(this.shape(), n.shape()); if (this.dataType() != n.dataType()) return false; // meh if (this.dataType() == DataType.UTF8 && n.dataType() == DataType.UTF8) { for (long e = 0; e < this.length(); e++) { 
val str1 = this.getString(e); val str2 = n.getString(e); if (!str1.equals(str2)) return false; } return true; } //epsilon equals if (isScalar() && n.isScalar()) { if (isZ()) { val val = getLong(0); val val2 = n.getLong(0); return val == val2; } else if (isR()) { val val = getDouble(0); val val2 = n.getDouble(0); if (Double.isNaN(val) != Double.isNaN(val2)) return false; return Math.abs(val - val2) < eps; } else if (isB()) { val val = getInt(0); val val2 = n.getInt(0); return val == val2; } } else if (isVector() && n.isVector()) { val op = new EqualsWithEps(this, n, eps); Nd4j.exec(op); val diff = op.z().getDouble(0); return diff < 0.5; } if (!Arrays.equals(this.shape(), n.shape())) return false; if (!Shape.shapeEquals(shape(), n.shape())) { return false; } if (slices() != n.slices()) return false; if (n.ordering() == ordering()) { EqualsWithEps op = new EqualsWithEps(this, n, eps); Nd4j.getExecutioner().exec(op); double diff = op.z().getDouble(0); return diff < 0.5; } else { EqualsWithEps op = new EqualsWithEps(this, n, eps); Nd4j.getExecutioner().exec(op); double diff = op.z().getDouble(0); return diff < 0.5; } } @Override public boolean equalShapes(@NonNull INDArray other){ if(isEmpty() != other.isEmpty()) return false; if(rank() != other.rank()) return false; for( int i=0; i= rank()) throw new IllegalArgumentException("Invalid size: cannot get size of dimension " + dimension + " for rank " + rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")"); return jvmShapeInfo.shape[dimension]; } @Override public int rank() { return jvmShapeInfo.rank; } @Override public long length() { if (isEmpty()) return 0; return jvmShapeInfo.length; } @Override public INDArray broadcast(INDArray result) { Nd4j.getCompressor().autoDecompress(this); val shape = result.shape(); if (Shape.shapeEquals(shape, shape())) return this; // if we're on scalar, we can just create new array if (this.isScalar()) return Nd4j.createUninitialized(this.dataType(), 
shape).assign(this.getDouble(0)); boolean compatible = true; int count = shape.length - 1; int thisCount = jvmShapeInfo.rank - 1; for (int i = shape.length - 1; i > 0; i--) { if (count < 0 || thisCount < 0) break; if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) { compatible = false; break; } count--; thisCount--; } if (!compatible) throw new IllegalArgumentException("Incompatible broadcast from " + Arrays.toString(shape()) + " to " + Arrays.toString(shape)); long[] retShape = new long[shape.length]; List broadCastDimensions = new ArrayList<>(); List nonBroadCastDimensions = new ArrayList<>(); for (int i = 0; i < retShape.length; i++) { if (shape().length == 1) { if (i == 0) { if (i < shape().length) retShape[i] = Math.max(1, shape[i]); else retShape[i] = shape[i]; } else { if (i < shape().length) retShape[i] = Math.max(shape[i], size(i)); else retShape[i] = shape[i]; } } else { if (i < rank() && size(i) == 1) broadCastDimensions.add(i); else nonBroadCastDimensions.add(i); if (i < shape().length) retShape[i] = Math.max(shape[i], size(i)); else retShape[i] = shape[i]; } } if (isRowVector()) { //number of times to repeat each value for (int i = 0; i < result.slices(); i++) { result.putSlice(i, this); } } else if (isColumnVector()) { for (int i = 0; i < result.columns(); i++) { result.putColumn(i, this); } } else { int[] repeat = new int[shape.length]; for(int i = 0; i < shape.length; i++) { if(i < rank()) { if(size(i) == 1) repeat[i] = (int) shape[i]; else { repeat[i] = 1; } } else { repeat[i] = (int) shape[i]; } } if (this.isView()) { Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this.dup(this.ordering())},new INDArray[]{result},repeat)); } else Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this},new INDArray[]{result},repeat)); } return result; } @Override public INDArray broadcast(long... 
shape) { return broadcast(Nd4j.createUninitialized(this.dataType(), shape, this.ordering())); } @Deprecated @Override public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) { return dimShuffle(rearrange, ArrayUtil.toLongArray(newOrder), broadCastable); } /** * Dimshuffle: an extension of permute that adds the ability * to broadcast various dimensions. *

* See theano for more examples. * This will only accept integers and xs. *

* An x indicates a dimension should be broadcasted rather than permuted. * * @param rearrange the dimensions to swap to * @return the newly permuted array */ @Override public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) { Nd4j.getCompressor().autoDecompress(this); if (broadCastable.length != jvmShapeInfo.rank) throw new IllegalArgumentException( "The broadcastable dimensions must be the same length as the current shape"); boolean broadcast = false; Set set = new HashSet<>(); for (int i = 0; i < rearrange.length; i++) { set.add(rearrange[i]); if (rearrange[i] instanceof Integer) { Integer j = (Integer) rearrange[i]; if (j >= broadCastable.length) throw new IllegalArgumentException( "Illegal dimension, dimension must be < broadcastable.length (aka the real dimensions"); } else if (rearrange[i] instanceof Character) { Character c = (Character) rearrange[i]; if (c != 'x') throw new IllegalArgumentException("Illegal input: Must be x"); broadcast = true; } else throw new IllegalArgumentException("Only characters and integers allowed"); } //just do permute if (!broadcast) { int[] ret = new int[rearrange.length]; for (int i = 0; i < ret.length; i++) ret[i] = (Integer) rearrange[i]; return permute(ret); } else { List drop = new ArrayList<>(); for (int i = 0; i < broadCastable.length; i++) { if (!set.contains(i)) { if (broadCastable[i]) drop.add(i); else throw new IllegalArgumentException( "We can't drop the given dimension because its not broadcastable"); } } //list of dimensions to keep int[] shuffle = new int[broadCastable.length]; int count = 0; for (int i = 0; i < rearrange.length; i++) { if (rearrange[i] instanceof Integer) { shuffle[count++] = (Integer) rearrange[i]; } } List augment = new ArrayList<>(); for (int i = 0; i < rearrange.length; i++) { if (rearrange[i] instanceof Character) augment.add(i); } Integer[] augmentDims = augment.toArray(new Integer[1]); count = 0; int dropIdx = 0; int[] newShape = new int[shuffle.length + 
drop.size()]; for (int i = 0; i < newShape.length; i++) { if (i < shuffle.length) { newShape[count++] = shuffle[i]; } else newShape[count++] = drop.get(dropIdx++); } INDArray ret; //TODO is this correct? This was old behaviour before adding permute input check if(newShape.length == this.rank()){ ret = permute(newShape); } else { ret = dup(); } List newDims = new ArrayList<>(); long[] shape = Arrays.copyOfRange(ret.shape(), 0, shuffle.length); for (int i = 0; i < shape.length; i++) { newDims.add(shape[i]); } for (int i = 0; i < augmentDims.length; i++) { newDims.add(augmentDims[i], 1L); } long[] toReshape = ArrayUtil.toArrayLong(newDims); ret = ret.reshape(toReshape); return ret; } } @Override public INDArray permute(int... rearrange) { Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + " got arguments %s for rank %s array. Number of arguments must equal array rank", rearrange, rank()); Nd4j.getCompressor().autoDecompress(this); boolean alreadyInOrder = true; //IntBuffer shapeInfo = shapeInfo(); int rank = jvmShapeInfo.rank; for (int i = 0; i < rank; i++) { if (rearrange[i] != i) { alreadyInOrder = false; break; } } if (alreadyInOrder) return this; checkArrangeArray(rearrange); val newShape = doPermuteSwap(shape(), rearrange); val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); INDArray value = create(data(), newShape, newStride, offset(), newOrder); return value; } @Override public INDArray permutei(int... rearrange) { Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + " got arguments %s for rank %s array. 
Number of arguments must equal array rank", rearrange, rank()); boolean alreadyInOrder = true; val shapeInfo = shapeInfo(); int rank = jvmShapeInfo.rank; for (int i = 0; i < rank; i++) { if (rearrange[i] != i) { alreadyInOrder = false; break; } } if (alreadyInOrder) return this; checkArrangeArray(rearrange); val newShape = doPermuteSwap(shape(), rearrange); val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); val ews = shapeInfo.get(2 * rank + 2); val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); setShapeInformation(si); if (shapeInfo.get(2 * rank + 2) > 0) { //for the backend to work - no ews for permutei //^^ not true anymore? Not sure here. Marking this for raver setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, 0, newOrder, dataType(), isEmpty())); } //this.shape = null; //this.stride = null; return this; } @Deprecated protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { val ret = new long[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape.get(rearrange[i]); } return ret; } @Deprecated protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { int[] ret = new int[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape.get(rearrange[i]); } return ret; } @Deprecated protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { int[] ret = new int[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape.getInt(rearrange[i]); } return ret; } protected long[] doPermuteSwap(long[] shape, int[] rearrange) { val ret = new long[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape[rearrange[i]]; } return ret; } protected void checkArrangeArray(int[] arr) { Preconditions.checkArgument(arr.length == jvmShapeInfo.rank, "Invalid rearrangement: number of arrangement (%s) != rank (%s)", 
arr.length, jvmShapeInfo.rank); for (int i = 0; i < arr.length; i++) { if (arr[i] >= arr.length) throw new IllegalArgumentException("The specified dimensions can't be swapped. Given element " + i + " was >= number of dimensions"); if (arr[i] < 0) throw new IllegalArgumentException("Invalid dimension: " + i + " : negative value"); } for (int i = 0; i < arr.length; i++) { for (int j = 0; j < arr.length; j++) { if (i != j && arr[i] == arr[j]) throw new IllegalArgumentException("Permute array must have unique elements"); } } } protected void autoProcessScalarCall() { /* if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC) OpProfiler.getInstance().processScalarCall();*/ } /** * Checks whether the matrix is a vector. */ @Override public boolean isVector() { if (jvmShapeInfo.rank == 1) return true; return isRowVector() || isColumnVector(); } @Override public boolean isVectorOrScalar() { return isVector() || isScalar(); } @Override public boolean isSquare() { return isMatrix() && rows() == columns(); } @Override public boolean isRowVector() { return (rank() == 2 && rows() == 1) && length() > 1 || rank() == 1 && length() > 1; } @Override public boolean isColumnVector() { return rank() == 2 && columns() == 1 && length() > 1; } @Override public boolean isColumnVectorOrScalar() { return isColumnVector() || isScalar(); } @Override public boolean isRowVectorOrScalar() { return isRowVector() || isScalar(); } /** * Generate string representation of the matrix. * Printing will switch to scientific notation on a per element basis * - when abs value is greater than or equal to 10000 * - when abs value is less than or equal to 0.0001 and not zero * * If the number of elements in the array is greater than 1000 (by default) only the first and last three elements * in a dimension are included. 
This can be changed globally using {@link NDArrayStrings#setMaxPrintElements(long)}
     *
     * @return the string produced by a default-constructed {@link NDArrayStrings}
     */
    @Override
    public String toString() {
        return toString(new NDArrayStrings());
    }

    /**
     * Format this array with {@code options}.
     * Returns "" when wasClosed() is true. Returns a fixed placeholder message instead of the
     * data when the array is compressed and compressDebug is set, or when preventUnpack is set.
     */
    @Override
    public String toString(@NonNull NDArrayStrings options){
        if(wasClosed())
            return "";
        if (!isCompressed() && !preventUnpack)
            return options.format(this);
        else if (isCompressed() && compressDebug)
            return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered.";
        else if (preventUnpack)
            return "Array string unpacking is disabled.";
        return options.format(this);
    }

    /** Format with explicit element limit, summarization flag and precision. */
    @Override
    public String toString(long maxElements, boolean forceSummarize, int precision){
        return toString(new NDArrayStrings(maxElements, forceSummarize, precision));
    }

    /**
     * Format every element (no element limit). Precision is passed as the negated
     * dataType().precision().
     */
    @Override
    public String toStringFull(){
        return toString(Long.MAX_VALUE, false, -1 * dataType().precision());
    }

    /**
     * Return the element of a scalar array: a boxed Float when the buffer type is FLOAT, otherwise
     * a boxed Double (also for integer types).
     *
     * @throws IllegalStateException if this array is not a scalar
     */
    @Override
    public Object element() {

        if (!isScalar())
            throw new IllegalStateException("Unable to retrieve element from non scalar matrix");
        if (data.dataType() == DataType.FLOAT)
            return data.getFloat(0);
        return data.getDouble(0);
    }

    /**
     * Element-wise remainder into a new array of this array's type and the broadcast output shape.
     * For non-broadcastable shapes this delegates with ulike(), and the two-argument overload
     * then rejects the shapes.
     */
    @Override
    public INDArray remainder(INDArray denominator) {
        if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) {
            return remainder(denominator, Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(), denominator.shape())));
        } else
            return remainder(denominator, this.ulike());
    }

    /**
     * Element-wise remainder (RemainderOp) written into {@code result}.
     *
     * @throws IllegalArgumentException if the shapes are not broadcastable
     */
    @Override
    public INDArray remainder(INDArray denominator, INDArray result) {
        validateNumericalArray("remainder", false);
        Preconditions.checkArgument(Shape.areShapesBroadcastable(this.shape(), denominator.shape()),"Shapes must be broadcastable");
        val op = new RemainderOp(this, denominator, result);
        Nd4j.getExecutioner().exec(op);
        return result;
    }

    /** Scalar remainder into a new array of this array's type and shape. */
    @Override
    public INDArray remainder(Number denominator) {
        return remainder(denominator, Nd4j.createUninitialized(this.dataType(), this.shape()));
    }

    /** Scalar remainder (ScalarRemainder) written into {@code result}. */
    @Override
    public INDArray remainder(Number denominator, INDArray result) {
        validateNumericalArray("remainder", false);

        ScalarRemainder op = new ScalarRemainder(this, null, result, denominator);
        Nd4j.getExecutioner().exec(op);
        return result;
    }

    /** In-place element-wise remainder (RemainderOp with this as output). */
    @Override
    public INDArray remainderi(INDArray denominator) {
        validateNumericalArray("remainderi", false);
        RemainderOp op = new RemainderOp(this, denominator, this);
        Nd4j.getExecutioner().exec(op);
        return this;
    }

    /** In-place scalar remainder (ScalarRemainder with this as output). */
    @Override
    public INDArray remainderi(Number denominator) {
        validateNumericalArray("remainderi", false);
        ScalarRemainder op = new ScalarRemainder(this, null, this, denominator);
        Nd4j.getExecutioner().exec(op);
        return this;
    }

    /**
     * Element-wise fmod into a new array with the broadcast output shape.
     * NOTE(review): the output type here is Nd4j.defaultFloatingPointType(), whereas
     * fmod(Number) and remainder(INDArray) allocate with this.dataType(); confirm the asymmetry is
     * intended (e.g. for DOUBLE inputs when the default floating type is FLOAT).
     */
    @Override
    public INDArray fmod(INDArray denominator) {
        validateNumericalArray("fmod", false);
        if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) {
            return fmod(denominator, Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), Shape.broadcastOutputShape(this.shape(), denominator.shape())));
        } else
            return fmod(denominator, this.ulike());
    }

    /**
     * Element-wise fmod written into {@code result}.
     * For broadcastable shapes (which includes identical shapes) this executes FloorModOp, after
     * checking that {@code result} has the broadcast output shape; FModOp is used only for
     * non-broadcastable shapes.
     * NOTE(review): fmodi(INDArray) always uses FModOp. The names suggest floor-mod vs
     * C-style fmod, which differ in sign for negative operands; confirm whether fmod and fmodi
     * are meant to agree.
     *
     * @throws IllegalArgumentException if {@code result} does not have the broadcast output shape
     */
    @Override
    public INDArray fmod(INDArray denominator, INDArray result) {
        validateNumericalArray("fmod", false);
        if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) {
            val outShape = Shape.broadcastOutputShape(this.shape(), denominator.shape());
            Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape()));

            Nd4j.exec(new FloorModOp(new INDArray[]{this, denominator}, new INDArray[]{result}));

            return result;
        } else {
            FModOp op = new FModOp(this, denominator, result);
            Nd4j.getExecutioner().exec(op);
            return result;
        }
    }

    /** Scalar fmod into a new array of this array's type and shape. */
    @Override
    public INDArray fmod(Number denominator) {
        return fmod(denominator, Nd4j.createUninitialized(this.dataType(), this.shape()));
    }

    /** Scalar fmod (ScalarFMod) written into {@code result}. */
    @Override
    public INDArray fmod(Number denominator, INDArray result) {
        validateNumericalArray("fmod", false);
        ScalarFMod op = new ScalarFMod(this, null, result, denominator);
        Nd4j.getExecutioner().exec(op);
        return result;
    }

    /** In-place element-wise fmod (FModOp with this as output). */
    @Override
    public INDArray fmodi(INDArray denominator) {
        validateNumericalArray("fmodi", false);
        FModOp op = new FModOp(this, denominator, this);
        Nd4j.getExecutioner().exec(op);
        return this;
    }

    /** In-place scalar fmod (ScalarFMod with this as output). */
    @Override
    public INDArray fmodi(Number denominator) {
        validateNumericalArray("fmodi", false);
        ScalarFMod op = new ScalarFMod(this, null, this, denominator);
        Nd4j.getExecutioner().exec(op);
        return this;
    }

    /**
     * Iterate over this array with a FirstAxisIterator.
     * NOTE(review): raw Iterator return type; the type argument may have been lost — confirm
     * against the overridden declaration.
     */
    @Override
    public Iterator iterator() {
        return new FirstAxisIterator(this);
    }

    /**
     * Original offset of the underlying data buffer.
     *
     * @throws IllegalArgumentException if the offset is &gt;= Integer.MAX_VALUE
     */
    @Override
    public long originalOffset() {
        if (data().originalOffset() >= Integer.MAX_VALUE)
            throw new IllegalArgumentException("Original offset of buffer can not be >= Integer.MAX_VALUE");

        return data().originalOffset();
    }

    // Java serialization hook: default fields, then read(s); any failure is rethrown as a
    // RuntimeException that wraps the original exception as its cause
    private void readObject(ObjectInputStream s) {
        try {
            s.defaultReadObject();
            read(s);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

    }

    // Java serialization hook: default fields, then the custom write(out)
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        write(out);
    }

    //Custom serialization for Java serialization
    protected void write(ObjectOutputStream out) throws IOException {
        if (this.isView()) {
            //As per Nd4j.write, duplicate before writing to the output stream
            //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here
            //Furthermore, because we only want to save the *actual* data for a view (not the full data), the shape info
            // (mainly strides, offset, element-wise stride) may be different in the duped array vs.
the view array INDArray copy = this.dup(); copy.shapeInfoDataBuffer().write(out); copy.data().write(out); } else { shapeInformation.write(out); data().write(out); } } //Custom deserialization for Java serialization protected void read(ObjectInputStream s) { val headerShape = BaseDataBuffer.readHeader(s); shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())]); shapeInformation.read(s, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getRight()); setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); val headerData = BaseDataBuffer.readHeader(s); data = Nd4j.createBuffer(headerData.getRight(), headerData.getMiddle(), false); data().read(s, headerData.getLeft(), headerData.getMiddle(), headerData.getRight()); } @Override public INDArray argMax(int... dimension) { return Nd4j.argMax(this, dimension); } @Override public boolean isAttached() { if (isEmpty()) return false; Preconditions.checkArgument(!(data == null && !isEmpty()), "Array has no buffer!"); return data.isAttached() || (data.underlyingDataBuffer() != null && data.underlyingDataBuffer().isAttached()) || (data.originalDataBuffer() != null && data.originalDataBuffer().isAttached()); } @Override public boolean isInScope() { if (!isAttached()) return true; return data.isInScope(); } @Override public INDArray detach() { if (!isAttached()) return this; WorkspaceUtils.assertValidArray(this, "Cannot detach INDArray"); Nd4j.getExecutioner().commit(); /* two options here 1) we're within some workspace 2) we're out of any workspace */ if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { if (!isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { INDArray copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); copy.assign(this); 
Nd4j.getExecutioner().commit(); return copy; } } else { MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); Nd4j.getMemoryManager().setCurrentWorkspace(null); INDArray copy = null; if (!isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); //this.dup(this.ordering()); } else { copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); copy.assign(this); Nd4j.getExecutioner().commit(); } Nd4j.getMemoryManager().setCurrentWorkspace(workspace); return copy; } } @Override public INDArray leverage() { WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); if (!isAttached()) return this; MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); if (workspace == null) { return this.detach(); } MemoryWorkspace parentWorkspace = workspace.getParentWorkspace(); if (this.data.getParentWorkspace() == parentWorkspace) return this; // if there's no parent ws - just detach if (parentWorkspace == null) return this.detach(); else { Nd4j.getExecutioner().commit(); // temporary set parent ws as current ws Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace); INDArray copy = null; if (!this.isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { copy = this.dup(this.ordering()); Nd4j.getExecutioner().commit(); } // restore current ws Nd4j.getMemoryManager().setCurrentWorkspace(workspace); return copy; } } @Override public INDArray leverageTo(String id) { return leverageTo(id, false); } 
@Override public INDArray leverageTo(String id, boolean enforceExistence) throws Nd4jNoSuchWorkspaceException { WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); if (!isAttached()) return this; if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { if(enforceExistence){ throw new Nd4jNoSuchWorkspaceException(id); } else { return this; } } MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id); if (this.data.getParentWorkspace() == target) return this; Nd4j.getMemoryManager().setCurrentWorkspace(target); INDArray copy = null; if (!this.isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { copy = this.dup(this.ordering()); Nd4j.getExecutioner().commit(); } Nd4j.getMemoryManager().setCurrentWorkspace(current); return copy; } public INDArray leverageOrDetach(String id){ if(!isAttached()){ return this; } if(!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id)){ return detach(); } return leverageTo(id); } @Override public INDArray migrate() { return migrate(false); } @Override public INDArray migrate(boolean detachOnNoWs){ WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); if (current == null) { if(detachOnNoWs){ return detach(); } else { return this; } } INDArray copy = null; if (!this.isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { copy = this.dup(this.ordering()); Nd4j.getExecutioner().commit(); } 
return copy; } @Override public Number percentileNumber(Number quantile) { validateNumericalArray("percentileNumber", false); if (quantile.intValue() < 0 || quantile.intValue() > 100) throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); if (isScalar()) return this.getDouble(0); INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true); return getPercentile(quantile, sorted); } @Override public Number medianNumber() { validateNumericalArray("medianNumber", false); if(isScalar()) return getNumber(0); return percentileNumber(50); } @Override public INDArray median(int... dimension) { validateNumericalArray("median", false); //Check edge case: size 1 element. No dimension == full array if(dimension.length == 0){ return Nd4j.scalar(dataType(), medianNumber().doubleValue()); } long shapeProd = 1; for (int d : dimension) { shapeProd *= size(d); } if (shapeProd == 1) { long[] newShape = ArrayUtil.removeIndex(shape(), dimension); return dup('c').reshape('c', newShape); } return percentile(50, dimension); } protected double getPercentile(Number quantile, INDArray sorted) { validateNumericalArray("getPercentile", false); if (quantile.intValue() == 0) return sorted.getDouble(0); else if (quantile.intValue() == 100) return sorted.getDouble(sorted.length() - 1); double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); if (pos < 1) return sorted.getDouble(0); else if (pos >= sorted.length()) return sorted.getDouble(sorted.length() - 1); double fposition = FastMath.floor(pos); int position = (int)fposition; double diff = pos - fposition; double lower = sorted.getDouble(position-1); double upper = sorted.getDouble(position); return lower + diff * (upper - lower); } @Override public INDArray percentile(Number quantile, int... 
dimension) { validateNumericalArray("percentile", false); if (quantile.doubleValue() < 0 || quantile.doubleValue() > 100) throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); if (isScalar()) return Nd4j.scalar(this.getDouble(0)); INDArray sorted = Nd4j.getNDArrayFactory().sort(this.dup(this.ordering()), false, dimension); // there's no practical sense doing this on GPU, stride will be just size of TAD. INDArray ret = Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), sorted.tensorsAlongDimension(dimension)); for (int i = 0; i < ret.length(); i++) { ret.putScalar(i, getPercentile(quantile, sorted.tensorAlongDimension(i, dimension))); } return ret; } protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer); @Override public int toFlatArray(FlatBufferBuilder builder) { if(isView()){ return dup(this.ordering()).toFlatArray(builder); } int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong()); int buffer = this.isEmpty() ? 0 : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) : FlatArray.createBufferVector(builder, this.data().asBytes()); val type = this.isEmpty() ? 
FlatBuffersMapper.getDataTypeAsByte(this.dataType()) : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType()); int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE); return array; } protected static DataTypeEx convertType(DataType type) { if (type == DataType.HALF) { return DataTypeEx.FLOAT16; } else if (type == DataType.FLOAT) { return DataTypeEx.FLOAT; } else if (type == DataType.DOUBLE) { return DataTypeEx.DOUBLE; } else if(type == DataType.INT) { return DataTypeEx.INT8; } else if(type == DataType.LONG) { return DataTypeEx.INT16; } else throw new IllegalStateException("Unknown dataType: [" + type + "]"); } @Override public boolean isEmpty() { return Shape.isEmpty(jvmShapeInfo.javaShapeInformation); } @Override public long[] shapeInfoJava() { return jvmShapeInfo.javaShapeInformation; } @Override public DataType dataType() { if (data != null) return data.dataType(); val e = Shape.extras(jvmShapeInfo.javaShapeInformation); if (e != 0) { val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); if (t != DataType.UNKNOWN) return t; } return DataType.UNKNOWN; } @Override public boolean isR() { val dtype = dataType(); return dtype == DataType.FLOAT || dtype == DataType.DOUBLE || dtype == DataType.HALF || dtype == DataType.BFLOAT16; } @Override public boolean isZ() { return !isR() && !isB() && !isS(); } @Override public boolean isB() { return dataType() == DataType.BOOL; } @Override public boolean isS() { return dataType() == DataType.UTF8; } @Override public INDArray castTo(DataType dataType) { if(dataType == dataType()) //No-op if correct datatype return this; if(isEmpty() && rank() == 0){ return Nd4j.empty(dataType); } val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); result.assign(this); return result; } @Override public boolean all() { val r = Nd4j.getExecutioner().exec(new All(this)); return r.getDouble(0) != 0.0; } @Override public boolean any() { val r = Nd4j.getExecutioner().exec(new 
Any(this)); return r.getDouble(0) != 0.0; } @Override public boolean none() { return !any(); } /** * Validate that the operation is being applied on a numerical array (not boolean or utf8). * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays * @param opName Operation name to print in the exception */ protected void validateNumericalArray(String opName, boolean allowEmpty){ if(dataType() == DataType.BOOL || dataType() == DataType.UTF8) throw new IllegalStateException("Cannot apply operation " + opName + " to array with " + dataType() + " datatype. Array shape: " + Arrays.toString(shape())); if(!allowEmpty && isEmpty()) throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType()); } @Override public boolean closeable() { if (released || isAttached()) return false; // empty arrays have no buffer at all if (isEmpty()) return true; if (isView()) return false; return data.closeable(); } @Override public void close() { // empty arrays have no buffer at all if (released || isEmpty()) return; Nd4j.getExecutioner().commit(); if (!closeable()) throw new ND4JIllegalStateException("Can't release this INDArray"); data.close(); released = true; } @Override public INDArray like() { return Nd4j.create(this.dataType(), this.shape(), Nd4j.getStrides(this.shape(), this.ordering()), this.ordering()); } @Override public INDArray ulike() { return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); } @Override public boolean wasClosed() { // data can be null if that's empty array if (released || (data() != null && data().wasClosed())) return true; return false; } @Override public long getId(){ return arrayId; } }