All downloads are free. The search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.ndarray.BaseNDArray Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.api.ndarray;


import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.graph.ByteOrder;
import org.nd4j.graph.FlatArray;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.*;
import org.nd4j.linalg.api.iter.FirstAxisIterator;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
import org.nd4j.linalg.api.ops.impl.reduce.same.*;
import org.nd4j.linalg.api.ops.impl.reduce3.EqualsWithEps;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.broadcast.*;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.scalar.*;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.*;
import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.exception.*;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.string.NDArrayStrings;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.nd4j.linalg.util.NDArrayMath;
import org.nd4j.linalg.workspace.WorkspaceUtils;

import java.io.*;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

import static org.nd4j.linalg.factory.Nd4j.*;


@Slf4j
public abstract class BaseNDArray implements INDArray, Iterable {

    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = 3285982317165542614L;

    /** Shape information buffer of this array; declared transient, so default Java serialization skips it. */
    protected transient volatile DataBuffer shapeInformation;
    /** Buffer assigned by the constructors as this array's data; declared transient as well. */
    protected transient volatile DataBuffer data;
    //protected transient DataBuffer shape;
    //protected transient DataBuffer stride;
    /** Flag returned by {@link #isCompressed()} and written by {@link #markAsCompressed(boolean)}. */
    protected transient boolean compressed = false;

    /** Starts out {@code false}. NOTE(review): presumably set once the array's buffers are released - confirm. */
    protected transient boolean released = false;

    // this field holds jvm copy of shapeInfo
    protected transient JvmShapeInfo jvmShapeInfo;


    /** Class-wide counter that supplies the next {@link #arrayId}. */
    private static final AtomicLong arrayCounter = new AtomicLong(0);
    /** Identifier taken from {@link #arrayCounter} when the instance is created. */
    protected transient final long arrayId = arrayCounter.getAndIncrement();


    // Precalculated descending-index arrays (like [3,2,1,0], [2,1,0], [1,0]) for use in TAD, indexed by their
    // length, so the same int[]s are not created over and over. Index 0 keeps the empty array.
    private static final int[][] tadFinalPermuteDimensions;
    static {
        tadFinalPermuteDimensions = new int[32][0];
        tadFinalPermuteDimensions[1] = new int[] {1, 0}; //Edge case for 1d tensors: selectively apply to column vectors
        for (int i = 2; i < 32; i++) {
            tadFinalPermuteDimensions[i] = new int[i];
            // Fill entry i with i-1, i-2, ..., 0
            for (int k = i - 1, j = 0; k >= 0; k--, j++)
                tadFinalPermuteDimensions[i][j] = k;
        }
    }

    /**
     * Creates an instance without assigning a data buffer or shape information.
     */
    public BaseNDArray() {
    }

    /**
     * Reports whether this array is currently flagged as compressed.
     *
     * @return the current value of the {@code compressed} flag
     */
    @Override
    public boolean isCompressed() {
        return compressed;
    }

    /**
     * Sets the flag returned by {@link #isCompressed()}. Only the flag is changed here.
     *
     * @param reallyCompressed the new value of the {@code compressed} flag
     */
    @Override
    public void markAsCompressed(boolean reallyCompressed) {
        this.compressed = reallyCompressed;
    }

    /**
     * Wraps the given buffer as an array of shape {@code [1, buffer.length()]}, using the default ordering
     * and the buffer's data type.
     *
     * @param buffer the buffer used directly as this array's data
     * @throws IllegalArgumentException if the buffer length is {@code >= Integer.MAX_VALUE}
     */
    public BaseNDArray(DataBuffer buffer) {
        this.data = buffer;
        final long bufferLength = buffer.length();
        if (bufferLength >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE");
        }
        final long[] rowShape = new long[] {1, (int) bufferLength};
        final long[] rowStride = Nd4j.getStrides(rowShape);
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(rowShape, rowStride,  1, Nd4j.order(), buffer.dataType(), false);
        setShapeInformation(shapeInfo);
        init(rowShape, rowStride);
    }

    /**
     * Creates an array over {@code buffer} described by int-typed shape and stride; the data type is taken
     * from the buffer.
     *
     * @param buffer   source buffer; used as-is unless {@code offset > 0}, in which case a buffer created from it
     *                 at that offset is used
     * @param shape    shape of the array
     * @param stride   stride of the array
     * @param offset   offset into {@code buffer}
     * @param ordering ordering character, checked with {@code Shape.assertValidOrder}
     */
    public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) {
        Shape.assertValidOrder(ordering);
        if (offset > 0) {
            this.data = Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride));
        } else {
            this.data = buffer;
        }
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(
                ArrayUtil.toLongArray(shape),
                ArrayUtil.toLongArray(stride),
                Shape.elementWiseStride(shape, stride, ordering == 'f'),
                ordering,
                buffer.dataType(),
                false);
        setShapeInformation(shapeInfo);
        init(shape, stride);
    }

    /**
     * Creates an array over {@code buffer}, computing the element-wise stride with
     * {@code Shape.elementWiseStride(shape, stride, ordering == 'f')} and delegating to the overload that takes it.
     *
     * @param buffer   source buffer
     * @param shape    shape of the array
     * @param stride   stride of the array
     * @param offset   offset into {@code buffer}
     * @param ordering ordering character
     */
    public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering) {
        this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering);
    }

    /**
     * Creates an array over {@code buffer} with a caller-supplied element-wise stride; the data type is taken
     * from the buffer.
     *
     * @param buffer   source buffer; used as-is unless {@code offset > 0}, in which case a buffer created from it
     *                 at that offset is used
     * @param shape    shape of the array
     * @param stride   stride of the array
     * @param offset   offset into {@code buffer}
     * @param ews      element-wise stride stored in the shape information
     * @param ordering ordering character, checked with {@code Shape.assertValidOrder}
     */
    public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering) {
        Shape.assertValidOrder(ordering);
        if (offset > 0) {
            this.data = Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride));
        } else {
            this.data = buffer;
        }
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, buffer.dataType(), false );
        setShapeInformation(shapeInfo);
        init(shape, stride);
    }

    /**
     * Creates an array over {@code buffer} with an explicit data type, computing the element-wise stride with
     * {@code Shape.elementWiseStride(shape, stride, ordering == 'f')} and delegating to the overload that takes it.
     *
     * @param buffer   source buffer
     * @param shape    shape of the array
     * @param stride   stride of the array
     * @param offset   offset into {@code buffer}
     * @param ordering ordering character
     * @param dataType data type passed on to the delegate
     */
    public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset,  char ordering, DataType dataType) {
        this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType);
    }

    /**
     * Creates an array over {@code buffer} with a caller-supplied element-wise stride and data type.
     * The ordering character is not validated by this overload.
     *
     * @param buffer   source buffer; used as-is unless {@code offset > 0}, in which case a buffer created from it
     *                 at that offset is used
     * @param shape    shape of the array
     * @param stride   stride of the array
     * @param offset   offset into {@code buffer}
     * @param ews      element-wise stride stored in the shape information
     * @param ordering ordering character
     * @param dataType data type stored in the shape information
     */
    public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering, DataType dataType) {
        if (offset > 0) {
            this.data = Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride));
        } else {
            this.data = buffer;
        }
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, dataType, false);
        setShapeInformation(shapeInfo);
        init(shape, stride);
    }

    public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type) {
        this.data = buffer;
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride,
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false));
        init(shape, stride);
        // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f'));
    }

    public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type, MemoryWorkspace workspace) {
        this.data = buffer;
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride,
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false));
        init(shape, stride);
        // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f'));
    }

    /**
     * Creates an array over {@code buffer} with an explicit data type and a computed element-wise stride.
     * The ordering character is not validated by this overload.
     *
     * @param buffer   source buffer; used as-is unless {@code offset > 0}, in which case a buffer created from it
     *                 at that offset is used
     * @param dataType data type stored in the shape information
     * @param shape    shape of the array
     * @param stride   stride of the array
     * @param offset   offset into {@code buffer}
     * @param ordering ordering character; {@code 'f'} selects the f-order element-wise stride computation
     */
    public BaseNDArray(DataBuffer buffer,  DataType dataType, long[] shape, long[] stride, long offset, char ordering) {
        if (offset > 0) {
            this.data = Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride));
        } else {
            this.data = buffer;
        }
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(
                shape,
                stride,
                Shape.elementWiseStride(shape, stride, ordering == 'f'),
                ordering,
                dataType,
                false);
        setShapeInformation(shapeInfo);
        init(shape, stride);
    }

    /**
     * Initializes the ndarray as a matrix with the given data (indices preserved), using the default
     * ordering from {@code Nd4j.order()}.
     *
     * @param data matrix values, indexed as {@code data[row][column]}
     */
    public BaseNDArray(double[][] data) {
        this(data, Nd4j.order());
    }

    /**
     * Initializes the ndarray as a {@code data.length x data[0].length} matrix. The values are flattened with
     * {@code ArrayUtil.flatten} for {@code 'c'} ordering and {@code ArrayUtil.flattenF} otherwise. After
     * construction every row is checked to have the same length as the first one.
     *
     * @param data     matrix values, indexed as {@code data[row][column]}
     * @param ordering ordering character
     */
    public BaseNDArray(double[][] data, char ordering) {
        this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)),
                new int[] {data.length, data[0].length},
                Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering);

        // Reject ragged input: each row must match the column count of the constructed matrix
        final int expectedColumns = columns();
        int row = 0;
        while (row < rows()) {
            Preconditions.checkState(data[row].length == expectedColumns, "data[%s].length=%s must be equal to number of columns %s", row, data[row].length, expectedColumns );
            row++;
        }
    }


    /**
     * Create with the specified shape and buffer
     *
     * @param shape  the shape
     * @param buffer the buffer
     */
    public BaseNDArray(int[] shape, DataBuffer buffer) {
        this.data = buffer;
        init(shape, Nd4j.getStrides(shape));
    }

    /**
     * Creates this ndarray with the given data, shape and ordering, using an offset of 0.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param ordering the ordering passed on to the offset-taking overload
     */
    public BaseNDArray(float[] data, int[] shape, char ordering) {
        this(data, shape, 0, ordering);
    }

    /**
     * Creates an ndarray from float data with strides computed for {@code ordering}.
     * The ordering is also forwarded to the delegate; the four-argument overload would use
     * {@code Nd4j.order()} for the array's ordering instead, leaving it out of step with the strides.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(float[] data, int[] shape, long offset, char ordering) {
        this(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }

    /**
     * Creates an ndarray from double data with strides computed for {@code ordering}; the ordering is
     * forwarded to the delegate so it matches the strides.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(double[] data, long[] shape, long offset, char ordering) {
        this(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }

    /**
     * Creates an ndarray from float data with strides computed for {@code ordering}; the ordering is
     * forwarded to the delegate so it matches the strides.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(float[] data, long[] shape, long offset, char ordering) {
        this(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }


    /**
     * Constructs an ndarray of the specified shape over a newly created buffer. The buffer holds
     * {@code ArrayUtil.prodLong(shape)} elements, or one element when {@code shape} has length 0.
     *
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) {
        this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
    }

    /**
     * Long-typed variant: constructs an ndarray of the specified shape over a newly created buffer of
     * {@code ArrayUtil.prodLong(shape)} elements, or one element when {@code shape} has length 0.
     *
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) {
        this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
    }

    /**
     * Constructs an ndarray of the specified shape over a newly created buffer of
     * {@code ArrayUtil.prodLong(shape)} elements (one element when {@code shape} has length 0).
     *
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't.
     */
    public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) {
        this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
    }

    /**
     * Long-typed variant: constructs an ndarray of the specified shape over a newly created buffer of
     * {@code ArrayUtil.prodLong(shape)} elements (one element when {@code shape} has length 0).
     *
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     * @param initialize passed to {@code Nd4j.createBuffer} as its initialize flag
     */
    public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) {
        this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
    }

    /**
     * Constructs an ndarray of the given data type over a newly created buffer of
     * {@code ArrayUtil.prodLong(shape)} elements (one element when {@code shape} has length 0).
     *
     * @param type       data type of the buffer and the array
     * @param shape      the shape of the ndarray
     * @param stride     the stride of the ndarray
     * @param offset     the desired offset
     * @param ordering   the ordering of the ndarray
     * @param initialize passed to {@code Nd4j.createBuffer} as its initialize flag
     */
    public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize) {
        this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering);
    }

    /**
     * Same as the overload without a workspace, except that the workspace is also handed to
     * {@code Nd4j.createBuffer} when the buffer is created.
     *
     * @param type       data type of the buffer and the array
     * @param shape      the shape of the ndarray
     * @param stride     the stride of the ndarray
     * @param offset     the desired offset
     * @param ordering   the ordering of the ndarray
     * @param initialize passed to {@code Nd4j.createBuffer} as its initialize flag
     * @param workspace  workspace passed to {@code Nd4j.createBuffer}
     */
    public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize, MemoryWorkspace workspace) {
        this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering);
    }

    /**
     * Creates an array whose buffer is allocated for a padded shape ({@code shape[i] + paddings[i]} per
     * dimension), while the shape information keeps the unpadded {@code shape} together with the padded strides.
     * The buffer is created with the initialize flag set to {@code false}.
     *
     * <p>NOTE(review): the per-dimension loop was unreadable in the reviewed text (everything between
     * {@code i} and {@code paddings[i]} was lost) and has been reconstructed from the surrounding
     * declarations. Confirm it against the upstream implementation before merging.
     *
     * @param type           data type of the array
     * @param shape          unpadded shape
     * @param paddings       padding per dimension; must have the same length as {@code shape}
     * @param paddingOffsets offset inside the padding per dimension; {@code null} or empty means no offset,
     *                       otherwise it must have the same length as {@code shape} and no element may exceed
     *                       the matching padding
     * @param ordering       {@code 'c'} selects c-order strides, anything else Fortran-order strides
     * @param workspace      workspace passed to {@code Nd4j.createBuffer}
     * @throws IllegalArgumentException if {@code paddings} or {@code paddingOffsets} violate the rules above
     */
    public BaseNDArray(DataType type, long[] shape, long[] paddings, long[] paddingOffsets, char ordering, MemoryWorkspace workspace) {

        //calculate strides with paddings
        int rank = shape.length;
        if(paddings == null || paddings.length != rank ) throw new IllegalArgumentException("The length of Padding should be equal to the length of Shape");
        long [] paddedShape = new long[rank];
        boolean empty = false;
        boolean zeroOffset = paddingOffsets == null || paddingOffsets.length == 0;
        boolean paddingOffsetsInvalid = !zeroOffset && paddingOffsets.length != rank ;
        long ews = 1;
        if(!paddingOffsetsInvalid){
            for(int i=0; i<rank; i++){
                paddedShape[i] = shape[i] + paddings[i];
                // Any padding breaks contiguity, so the element-wise stride can no longer be 1
                if(paddings[i] != 0) ews = 0;
                if(shape[i] == 0) empty = true;
                // Offsets are only inspected when they were supplied; a null/empty array means "no offset"
                if(!zeroOffset && paddingOffsets[i]>paddings[i]){
                    paddingOffsetsInvalid = true;
                    break;
                }
            }
        }
        if(!zeroOffset && paddingOffsetsInvalid) throw new IllegalArgumentException("If PaddingOffsets is not empty or zero length then its length should match the length of Paddings and also its elements should not be greater");

        long[] paddedStride = ordering == 'c' ? ArrayUtil.calcStrides(paddedShape,1): ArrayUtil.calcStridesFortran(paddedShape,1);
        long paddedAllocSize = ordering == 'c' ? paddedShape[0] * paddedStride[0] : paddedShape[rank-1] * paddedStride[rank-1];

        long offset = (empty || ews == 1 || zeroOffset) ? 0 :  ArrayUtil.calcOffset(paddedShape, paddingOffsets, paddedStride);
        DataBuffer buffer = Nd4j.createBuffer(type, paddedAllocSize, false, workspace);
        this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, paddedAllocSize - offset) : buffer ;
        long extras  = ArrayOptionsHelper.setOptionBit(0, type);
        if(empty) extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.ATYPE_EMPTY_BIT);
        else if(ews!=1) extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.HAS_PADDED_BUFFER);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, paddedStride, ews, ordering, extras));
    }

    /**
     * Creates the ndarray with the specified shape, stride and ordering, using an offset of 0.
     *
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(int[] shape, int[] stride, char ordering) {
        this(shape, stride, 0, ordering);
    }


    /**
     * Creates an ndarray of the given shape with strides from {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param shape    the shape of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(int[] shape, long offset, char ordering) {
        this(shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }

    /**
     * Long-typed variant: creates an ndarray of the given shape with strides from
     * {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param shape    the shape of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(long[] shape, long offset, char ordering) {
        this(shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }


    /**
     * Creates an ndarray with the given shape, offset 0 and the default ordering from {@code Nd4j.order()}.
     *
     * @param shape the shape of the ndarray
     */
    public BaseNDArray(int[] shape) {
        this(shape, 0, Nd4j.order());
    }

    /**
     * Long-typed variant: creates an ndarray with the given shape, offset 0 and the default ordering from
     * {@code Nd4j.order()}.
     *
     * @param shape the shape of the ndarray
     */
    public BaseNDArray(long[] shape) {
        this(shape, 0, Nd4j.order());
    }


    /**
     * Creates a new {@code newRows x newColumns} matrix over a freshly created buffer of
     * {@code newRows * newColumns} elements, typed with {@code Nd4j.dataType()}.
     *
     * @param newRows    the number of rows (n) of the new matrix.
     * @param newColumns the number of columns (m) of the new matrix.
     * @param ordering   ordering character, checked with {@code Shape.assertValidOrder}
     */
    public BaseNDArray(int newRows, int newColumns, char ordering) {
        Shape.assertValidOrder(ordering);
        // Widen before multiplying so the element count cannot overflow int
        final long elementCount = (long) newRows * newColumns;
        this.data = Nd4j.createBuffer(elementCount);
        final long[] matrixShape = {newRows, newColumns};
        final long[] matrixStride = Nd4j.getStrides(matrixShape, ordering);
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(
                matrixShape,
                matrixStride,
                Shape.elementWiseStride(matrixShape, matrixStride, ordering == 'f'),
                ordering,
                Nd4j.dataType(),
                false);
        setShapeInformation(shapeInfo);
        init(matrixShape, matrixStride);
    }

    /**
     * Long-typed variant: creates a new {@code newRows x newColumns} matrix over a freshly created buffer of
     * {@code newRows * newColumns} elements, typed with {@code Nd4j.dataType()}.
     *
     * @param newRows    the number of rows of the new matrix
     * @param newColumns the number of columns of the new matrix
     * @param ordering   ordering character, checked with {@code Shape.assertValidOrder}
     */
    public BaseNDArray(long newRows, long newColumns, char ordering) {
        Shape.assertValidOrder(ordering);
        final long elementCount = newRows * newColumns;
        this.data = Nd4j.createBuffer(elementCount);
        final long[] matrixShape = {newRows, newColumns};
        final long[] matrixStride = Nd4j.getStrides(matrixShape, ordering);
        val shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(
                matrixShape,
                matrixStride,
                Shape.elementWiseStride(matrixShape, matrixStride, ordering == 'f'),
                ordering,
                Nd4j.dataType(),
                false);
        setShapeInformation(shapeInfo);
        init(matrixShape, matrixStride);
    }


    /**
     * Creates an ndarray from the specified slices. The data of every slice is merged into one ndarray
     * of the specified shape, with strides from {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param slices   the slices to merge
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(List<INDArray> slices, int[] shape, char ordering) {
        this(slices, shape, Nd4j.getStrides(shape, ordering), ordering);
    }

    /**
     * Long-typed variant: creates an ndarray from the specified slices with strides from
     * {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param slices   the slices to merge
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(List<INDArray> slices, long[] shape, char ordering) {
        this(slices, shape, Nd4j.getStrides(shape, ordering), ordering);
    }


    /**
     * Creates an ndarray from the specified slices. A zero-initialized buffer of {@code prod(shape)} elements
     * is allocated with the data type of the first slice, so that the buffer and the shape information agree
     * on the type. If the first slice is a scalar, element {@code i} of the result is the first value of
     * slice {@code i}; otherwise slice {@code i} is stored with {@code putSlice(i, ...)}.
     *
     * @param slices   the slices to merge; must contain at least one element
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param ordering ordering character, checked with {@code Shape.assertValidOrder}
     */
    public BaseNDArray(List<INDArray> slices, int[] shape, int[] stride, char ordering) {
        Shape.assertValidOrder(ordering);
        final INDArray first = slices.get(0);
        // Previously only FLOAT got a matching buffer; every other type fell back to a double buffer
        // while the shape information still recorded the slices' own data type.
        this.data = Nd4j.createBuffer(first.dataType(), ArrayUtil.prodLong(shape), true);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride),
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, first.dataType(), false));
        init(shape, stride);

        if (first.isScalar()) {
            for (int i = 0; i < length(); i++) {
                putScalar(i, slices.get(i).getDouble(0));
            }
        } else {
            for (int i = 0; i < slices(); i++) {
                putSlice(i, slices.get(i));
            }
        }
    }


    public BaseNDArray(List slices, long[] shape, long[] stride, char ordering) {
        DataBuffer ret = Nd4j.createBuffer(slices.get(0).dataType(), Shape.lengthOf(shape), false); /*slices.get(0).data().dataType() == (DataType.FLOAT)
                ? Nd4j.createBuffer(new float[ArrayUtil.prod(shape)])
                : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]);
                */
        this.data = ret;
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride,
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false));
        init(shape, stride);
        //    Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f'));

        if (slices.get(0).isScalar()) {
            for (int i = 0; i < length(); i++) {
                putScalar(i, slices.get(i).getDouble(0));
            }
        } else {
            for (int i = 0; i < slices(); i++) {
                putSlice(i, slices.get(i));
            }
        }
    }

    /**
     * Creates an ndarray from float data with the given shape, stride and ordering, using an offset of 0.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) {
        this(data, shape, stride, 0, ordering);
    }

    /**
     *
     * @param data
     * @param shape
     * @param stride
     * @param offset
     * @param ordering
     */
    public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) {
        Shape.assertValidOrder(ordering);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride),
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data != null && data.length > 0 ? false : true));
        if (data != null && data.length > 0) {

            val perfD = PerformanceTracker.getInstance().helperStartTransaction();

            this.data = internalCreateBuffer(data, offset);

            PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST);

            if (offset >= data.length)
                throw new IllegalArgumentException("invalid offset: must be < data.length");
        }

        init(shape, stride);
    }

    public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) {
        Shape.assertValidOrder(ordering);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride,
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data != null && data.length > 0 ? false : true));
        if (data != null && data.length > 0) {
            this.data = Nd4j.createTypedBuffer(data, DataType.FLOAT);
            if (offset >= data.length)
                throw new IllegalArgumentException("invalid offset: must be < data.length");
        }

        init(shape, stride);
    }

    public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) {
        Shape.assertValidOrder(ordering);
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride,
                Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, data != null && data.length > 0 ? false : true));
        if (data != null && data.length > 0) {
            this.data = Nd4j.createBuffer(data, offset);
            if (offset >= data.length)
                throw new IllegalArgumentException("invalid offset: must be < data.length");
        }

        init(shape, stride);
    }

    /**
     *
     * @param data
     * @param shape
     * @param stride
     * @param offset
     */
    public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) {
        this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape));
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride),
                Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order(), data.dataType(), false));
        init(shape, stride);
        //  Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'));


    }

    /**
     * Creates an ndarray from int data: the values are wrapped with {@code internalCreateBuffer(int[])} and
     * passed to the buffer-based overload.
     *
     * @param data    the data to use
     * @param shape   the shape of the ndarray
     * @param strides the strides of the ndarray
     */
    public BaseNDArray(int[] data, int[] shape, int[] strides) {
        this(internalCreateBuffer(data), shape, strides);
    }

    /**
     * Creates an ndarray over {@code data} with the given shape, offset 0, the default ordering from
     * {@code Nd4j.order()} and the strides computed for that ordering.
     *
     * @param data  the buffer to use
     * @param shape the shape of the ndarray
     */
    public BaseNDArray(DataBuffer data, int[] shape) {
        this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order());
    }

    /**
     * Long-typed variant: creates an ndarray over {@code data} with the given shape, offset 0, the default
     * ordering from {@code Nd4j.order()} and the strides computed for that ordering.
     *
     * @param data  the buffer to use
     * @param shape the shape of the ndarray
     */
    public BaseNDArray(DataBuffer data, long[] shape) {
        this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order());
    }


    /**
     * Creates an ndarray over the part of {@code buffer} that starts at {@code offset} and holds
     * {@code ArrayUtil.prodLong(shape)} elements, with default strides and the default ordering.
     *
     * @param buffer the source buffer
     * @param shape  the shape of the ndarray
     * @param offset offset into {@code buffer}
     */
    public BaseNDArray(DataBuffer buffer, int[] shape, long offset) {
        // The buffer passed on already starts at the offset, so the delegate receives offset 0.
        // Forwarding the offset again would apply it a second time to the already shifted buffer.
        this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, Nd4j.getStrides(shape), 0,
                Nd4j.order());
    }

    /**
     * Creates an ndarray over {@code buffer} with offset 0 and strides from
     * {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param buffer   the buffer to use
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) {
        this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering);
    }

    /**
     * Long-typed variant: creates an ndarray over {@code buffer} with offset 0 and strides from
     * {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param buffer   the buffer to use
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(DataBuffer buffer, long[] shape, char ordering) {
        this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering);
    }

    /**
     * Creates an ndarray from double data: the values are wrapped with {@code Nd4j.createBuffer(data)} and
     * passed to the buffer-based overload.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(double[] data, int[] shape, char ordering) {
        this(Nd4j.createBuffer(data), shape, ordering);
    }

    /**
     * Long-typed variant for double data: the values are wrapped with {@code Nd4j.createBuffer(data)} and
     * passed to the buffer-based overload.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(double[] data, long[] shape, char ordering) {
        this(Nd4j.createBuffer(data), shape, ordering);
    }

    /**
     * Long-typed variant for float data: the values are wrapped with {@code Nd4j.createBuffer(data)} and
     * passed to the buffer-based overload.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(float[] data, long[] shape, char ordering) {
        this(Nd4j.createBuffer(data), shape, ordering);
    }

    /**
     * Creates an ndarray from double data: the values are wrapped with
     * {@code internalCreateBuffer(data, offset)} and passed to the buffer-based overload.
     *
     * <p>NOTE(review): {@code offset} is given both to the buffer creation and to the delegate, which creates
     * a further buffer at that offset when {@code offset > 0}. Confirm the offset is not applied twice.
     *
     * @param data     the data to use
     * @param shape    the shape of the ndarray
     * @param stride   the stride of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) {
        this(internalCreateBuffer(data, offset), shape, stride, offset, ordering);
    }

    /**
     * Creates an ndarray from float data: the values are wrapped with {@code internalCreateBuffer(float[])}
     * and passed to the {@code (DataBuffer, char)} overload.
     *
     * @param data  the data to use
     * @param order the ordering of the ndarray
     */
    public BaseNDArray(float[] data, char order) {
        this(internalCreateBuffer(data), order);
    }

    /**
     * Creates a buffer from float values and registers the copy with the {@link PerformanceTracker}
     * as a host-to-host transaction.
     *
     * @param data values to wrap
     * @return the buffer returned by {@code Nd4j.createBuffer(data)}
     */
    protected static DataBuffer internalCreateBuffer(float[] data) {
        val perfX = PerformanceTracker.getInstance().helperStartTransaction();

        val buffer = Nd4j.createBuffer(data);
        // Widen before multiplying: length * element size can exceed the int range for large arrays
        val numberOfBytes = (long) data.length * Nd4j.sizeOfDataType(buffer.dataType());
        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, numberOfBytes, MemcpyDirection.HOST_TO_HOST);

        return buffer;
    }

    /**
     * Creates a buffer from double values and registers the copy with the {@link PerformanceTracker}
     * as a host-to-host transaction.
     *
     * @param data values to wrap
     * @return the buffer returned by {@code Nd4j.createBuffer(data)}
     */
    protected static DataBuffer internalCreateBuffer(double[] data) {
        val perfX = PerformanceTracker.getInstance().helperStartTransaction();

        val buffer = Nd4j.createBuffer(data);
        // Widen before multiplying: length * element size can exceed the int range for large arrays
        val numberOfBytes = (long) data.length * Nd4j.sizeOfDataType(buffer.dataType());
        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, numberOfBytes, MemcpyDirection.HOST_TO_HOST);

        return buffer;
    }

    /**
     * Creates a buffer from int values and registers the copy with the {@link PerformanceTracker}
     * as a host-to-host transaction.
     *
     * @param data values to wrap
     * @return the buffer returned by {@code Nd4j.createBuffer(data)}
     */
    protected static DataBuffer internalCreateBuffer(int[] data) {
        val perfX = PerformanceTracker.getInstance().helperStartTransaction();

        val buffer = Nd4j.createBuffer(data);
        // Widen before multiplying: length * element size can exceed the int range for large arrays
        val numberOfBytes = (long) data.length * Nd4j.sizeOfDataType(buffer.dataType());
        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, numberOfBytes, MemcpyDirection.HOST_TO_HOST);

        return buffer;
    }

    /**
     * Creates a buffer from float values with an offset and registers the copy with the
     * {@link PerformanceTracker} as a host-to-host transaction.
     *
     * @param data   values to wrap
     * @param offset offset passed to {@code Nd4j.createBuffer(data, offset)}
     * @return the buffer returned by {@code Nd4j.createBuffer(data, offset)}
     */
    protected static DataBuffer internalCreateBuffer(float[] data, long offset) {
        val perfX = PerformanceTracker.getInstance().helperStartTransaction();

        val buffer = Nd4j.createBuffer(data, offset);
        // Widen before multiplying: length * element size can exceed the int range for large arrays
        val numberOfBytes = (long) data.length * Nd4j.sizeOfDataType(buffer.dataType());
        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, numberOfBytes, MemcpyDirection.HOST_TO_HOST);

        return buffer;
    }

    /**
     * Creates a buffer from double values with an offset and registers the copy with the
     * {@link PerformanceTracker} as a host-to-host transaction.
     *
     * @param data   values to wrap
     * @param offset offset passed to {@code Nd4j.createBuffer(data, offset)}
     * @return the buffer returned by {@code Nd4j.createBuffer(data, offset)}
     */
    protected static DataBuffer internalCreateBuffer(double[] data, long offset) {
        val perfX = PerformanceTracker.getInstance().helperStartTransaction();

        val buffer = Nd4j.createBuffer(data, offset);
        // Widen before multiplying: length * element size can exceed the int range for large arrays
        val numberOfBytes = (long) data.length * Nd4j.sizeOfDataType(buffer.dataType());
        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, numberOfBytes, MemcpyDirection.HOST_TO_HOST);

        return buffer;
    }

    /**
     * Wraps a buffer as a rank-1 array of {@code floatBuffer.length()} elements with offset 0.
     *
     * @param floatBuffer the buffer to wrap
     * @param order       ordering character, checked with {@code Shape.assertValidOrder}
     * @throws IllegalArgumentException if the buffer length is {@code >= Integer.MAX_VALUE}
     */
    public BaseNDArray(DataBuffer floatBuffer, char order) {
        this(floatBuffer, new int[] {intSizedLengthOf(floatBuffer)},
                Nd4j.getStrides(new int[] {intSizedLengthOf(floatBuffer)}, order), 0, order);
        Shape.assertValidOrder(order);
    }

    /**
     * Returns the buffer length as an int. The range check runs before the narrowing cast, so an oversized
     * length is rejected instead of being truncated into the shape.
     */
    private static int intSizedLengthOf(DataBuffer buffer) {
        if (buffer.length() >= Integer.MAX_VALUE)
            throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE");
        return (int) buffer.length();
    }

    /**
     * Creates an ndarray over {@code buffer} with the given shape and strides, offset 0 and the default
     * ordering from {@code Nd4j.order()}.
     *
     * @param buffer  the buffer to use
     * @param shape   the shape of the ndarray
     * @param strides the strides of the ndarray
     */
    public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) {
        this(buffer, shape, strides, 0, Nd4j.order());
    }


    /**
     * Creates this ndarray with the given data and shape and an offset of 0.
     *
     * @param data  the data to use
     * @param shape the shape of the ndarray
     */
    public BaseNDArray(float[] data, int[] shape) {
        this(data, shape, 0);
    }


    /**
     * Creates this ndarray with the given data, shape and offset, using the default ordering from
     * {@code Nd4j.order()}.
     *
     * @param data   the data to use
     * @param shape  the shape of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(float[] data, int[] shape, long offset) {
        this(data, shape, offset, Nd4j.order());

    }

    /**
     * Constructs an ndarray of the specified shape over a new zero-filled {@code float[]} of
     * {@code ArrayUtil.prod(shape)} elements, using the default ordering from {@code Nd4j.order()}.
     *
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(int[] shape, int[] stride, long offset) {
        this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order());
    }

    /**
     * Long-typed variant: constructs an ndarray of the specified shape over a new zero-filled
     * {@code float[]} of {@code ArrayUtil.prod(shape)} elements, using the default ordering from
     * {@code Nd4j.order()}.
     *
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(long[] shape, long[] stride, long offset) {
        this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order());
    }

    /**
     * Creates the ndarray with the specified shape and stride and an offset of 0.
     *
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     */
    public BaseNDArray(int[] shape, int[] stride) {
        this(shape, stride, 0);
    }

    /**
     * Creates the ndarray with the specified shape and offset, using the strides returned by
     * {@code Nd4j.getStrides(shape)}.
     *
     * @param shape  the shape of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(int[] shape, long offset) {
        this(shape, Nd4j.getStrides(shape), offset);
    }

    /**
     * Creates the ndarray with the specified shape and ordering and an offset of 0.
     *
     * @param shape    the shape of the ndarray
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(int[] shape, char ordering) {
        this(shape, 0, ordering);
    }


    /**
     * Creates a new ndarray from a row count and a column count, using the
     * default ordering returned by {@link Nd4j#order()}.
     *
     * @param newRows    the number of rows (n) of the new matrix.
     * @param newColumns the number of columns (m) of the new matrix.
     */
    public BaseNDArray(int newRows, int newColumns) {
        this(newRows, newColumns, Nd4j.order());
    }

    /**
     * Long-argument variant: creates a new ndarray from a row count and a column
     * count, using the default ordering returned by {@link Nd4j#order()}.
     *
     * @param newRows    the number of rows of the new matrix
     * @param newColumns the number of columns of the new matrix
     */
    public BaseNDArray(long newRows, long newColumns) {
        this(newRows, newColumns, Nd4j.order());
    }


    /**
     * Create an ndarray from the specified slices.
     * This will go through and merge all of the
     * data from each slice in to one ndarray
     * which will then take the specified shape.
     * Delegates to the (slices, shape, ordering) constructor with the default
     * ordering returned by {@link Nd4j#order()}.
     *
     * @param slices the slices to merge
     * @param shape  the shape of the ndarray
     */
    public BaseNDArray(List slices, int[] shape) {
        this(slices, shape, Nd4j.order());
    }

    /**
     * Long-shape variant: create an ndarray from the specified slices, delegating
     * to the (slices, shape, ordering) constructor with the default ordering
     * returned by {@link Nd4j#order()}.
     *
     * @param slices the slices to merge
     * @param shape  the shape of the ndarray
     */
    public BaseNDArray(List slices, long[] shape) {
        this(slices, shape, Nd4j.order());
    }

    /**
     * Create an ndarray from the specified slices.
     * This will go through and merge all of the
     * data from each slice in to one ndarray
     * which will then take the specified shape.
     * Delegates to the (slices, shape, stride, ordering) constructor with the
     * default ordering returned by {@link Nd4j#order()}.
     *
     * @param slices the slices to merge
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     */
    public BaseNDArray(List slices, int[] shape, int[] stride) {
        this(slices, shape, stride, Nd4j.order());
    }

    /**
     * Long-shape variant: create an ndarray from the specified slices, shape and
     * stride, delegating with the default ordering returned by {@link Nd4j#order()}.
     *
     * @param slices the slices to merge
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     */
    public BaseNDArray(List slices, long[] shape, long[] stride) {
        this(slices, shape, stride, Nd4j.order());
    }

    /**
     * Create an ndarray from the given data, shape and stride, using the default
     * ordering returned by {@link Nd4j#order()}.
     *
     * @param data   the data to use
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     */
    public BaseNDArray(float[] data, int[] shape, int[] stride) {
        this(data, shape, stride, Nd4j.order());
    }


    /**
     * Create an ndarray from the given data, shape, stride and offset, using the
     * default ordering returned by {@link Nd4j#order()}.
     *
     * @param data   the data to use
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) {
        this(data, shape, stride, offset, Nd4j.order());
    }

    /**
     * Create an ndarray from the given double data, long shape, long stride and
     * offset, using the default ordering returned by {@link Nd4j#order()}.
     *
     * @param data   the data to use
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(double[] data, long[] shape, long[] stride, long offset) {
        this(data, shape, stride, offset, Nd4j.order());
    }

    /**
     * Create an ndarray from the given float data, long shape, long stride and
     * offset, using the default ordering returned by {@link Nd4j#order()}.
     *
     * @param data   the data to use
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(float[] data, long[] shape, long[] stride, long offset) {
        this(data, shape, stride, offset, Nd4j.order());
    }

    /**
     * Create an ndarray from a flat float array by wrapping it in a buffer
     * obtained from {@code Nd4j.createBuffer(data)} and delegating to the
     * single-buffer constructor.
     *
     * @param data the data to use
     */
    public BaseNDArray(float[] data) {
        this(Nd4j.createBuffer(data));
    }


    /**
     * Initialize the ndarray with the given 2d data, using the default ordering
     * returned by {@link Nd4j#order()}.
     *
     * @param data the 2d data to use
     */
    public BaseNDArray(float[][] data) {
        this(data, Nd4j.order());
    }

    /**
     * Initialize a 2d ndarray from a float matrix.
     * The data is flattened with {@code ArrayUtil.flatten} when the ordering is
     * {@code 'c'} and with {@code ArrayUtil.flattenF} otherwise; the shape is
     * {@code {data.length, data[0].length}}. After construction every row is
     * checked against the column count of the new array.
     *
     * @param data     the 2d data to use
     * @param ordering the ordering character
     */
    public BaseNDArray(float[][] data, char ordering) {
        this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)),
                new int[] {data.length, data[0].length},
                Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering);

        // Reject ragged input: each row length must equal the array's column count.
        final int expectedColumns = columns();
        int rowIdx = 0;
        while (rowIdx < rows()) {
            Preconditions.checkState(data[rowIdx].length == expectedColumns,
                    "data[%s].length=%s must be equal to number of columns %s", rowIdx, data[rowIdx].length, expectedColumns);
            rowIdx++;
        }
    }



    /**
     * Constructor for buffer, shape, offset and ordering; the strides are
     * obtained from {@code Nd4j.getStrides(shape, ordering)}.
     *
     * @param buffer   the data buffer backing the ndarray
     * @param shape    the shape of the ndarray
     * @param offset   the desired offset
     * @param ordering the ordering character
     */
    public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) {
        this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering);
    }

    /**
     * Create an ndarray from the given double data, shape, stride and offset,
     * using the default ordering returned by {@link Nd4j#order()}.
     *
     * @param data   the data to use
     * @param shape  the shape of the ndarray
     * @param stride the stride of the ndarray
     * @param offset the desired offset
     */
    public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) {
        this(data, shape, stride, offset, Nd4j.order());
    }


    /**
     * Reports whether the linear index of the last element
     * ({@code length() - 1}) can be resolved without an exception.
     *
     * @return true if {@code linearIndex(length() - 1)} completes normally,
     * false if it throws any {@link Exception}
     */
    @Deprecated
    public boolean isValid() {
        boolean valid = true;
        try {
            linearIndex(length() - 1);
        } catch (Exception e) {
            // Any failure while resolving the last index marks the array invalid.
            valid = false;
        }
        return valid;
    }

    /**
     * Factory hook that builds an ndarray via {@code Nd4j.create(data, shape, offset)}.
     *
     * @param data   the buffer to wrap
     * @param shape  the shape of the new array
     * @param offset the offset into the buffer
     * @return the array returned by {@code Nd4j.create}
     */
    protected INDArray create(DataBuffer data, int[] shape, long offset) {
        final INDArray created = Nd4j.create(data, shape, offset);
        return created;
    }

    /**
     * Reads the element-wise stride from this array's shape-information buffer.
     *
     * @return the value of {@code Shape.elementWiseStride(shapeInfoDataBuffer())}
     */
    @Override
    public int elementWiseStride() {
        final int ews = Shape.elementWiseStride(shapeInfoDataBuffer());
        return ews;
    }

    /**
     * Counts how many tensors exist along the given dimensions.
     * Returns 1 when at least {@code rank()} dimensions are given or when the single
     * dimension {@code Integer.MAX_VALUE} is given. Negative dimensions are
     * normalised in place by adding {@code rank()} (the passed array is modified).
     *
     * @param dimension the dimensions spanned by each tensor
     * @return {@code length()} divided by the product of the kept dimension sizes
     * @throws IllegalArgumentException if no dimensions are given, or the count
     *                                  is {@code >= Integer.MAX_VALUE}
     * @throws IllegalStateException    if the product of the kept sizes is 0
     */
    @Override
    public long tensorsAlongDimension(int... dimension) {
        if (dimension == null || dimension.length == 0) {
            throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");
        }
        if (dimension.length >= rank() || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) {
            return 1;
        }

        // Map negative axes onto their positive equivalents.
        int axis = 0;
        while (axis < dimension.length) {
            if (dimension[axis] < 0) {
                dimension[axis] += rank();
            }
            axis++;
        }

        final long[] keptSizes = ArrayUtil.keep(shape(), dimension);
        final long tensorLength = ArrayUtil.prodLong(keptSizes);
        if (tensorLength == 0) {
            throw new IllegalStateException("Illegal length found after removing index");
        }

        final long tensorCount = length() / tensorLength;
        if (tensorCount >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensors along dimension can not be >= Integer.MAX_VALUE");
        }
        return tensorCount;
    }

    /**
     * Returns the {@code index}-th tensor along the given dimensions as an array
     * created over this array's data buffer.
     * <p>
     * Returns {@code this} when at least {@code rank()} dimensions are given or when
     * the single dimension {@code Integer.MAX_VALUE} is given. Negative dimensions
     * are normalised in place by adding {@code rank()} (the passed array is
     * modified); duplicates are then removed and the dimensions sorted ascending.
     *
     * @param index     the index of the tensor, {@code 0 <= index < tensorsAlongDimension(dimension)}
     * @param dimension the dimensions spanned by the tensor
     * @return the requested tensor
     * @throws IllegalArgumentException if no dimensions are given, this array is
     *                                  empty, or {@code index} is out of range
     */
    @Override
    public INDArray tensorAlongDimension(long index, int... dimension) {
        if (dimension == null || dimension.length == 0)
            throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");

        Preconditions.checkArgument(!this.isEmpty(), "tensorAlongDimension(...) can't be used on empty tensors");

        if (dimension.length >= rank()  || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)
            return this;
        for (int i = 0; i < dimension.length; i++)
            if (dimension[i] < 0)
                dimension[i] += rank();

        // A TreeSet both removes duplicates and yields ascending order,
        // so no separate sort is needed afterwards.
        if (dimension.length > 1)
            dimension = Ints.toArray(new TreeSet<Integer>(Ints.asList(dimension)));

        long tads = tensorsAlongDimension(dimension);
        // Reject negative indices as well as indices past the last tensor.
        if (index < 0 || index >= tads)
            throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);


        // Vector shortcuts: no TAD lookup is needed for these cases.
        if (dimension.length == 1) {
            if (dimension[0] == 0 && isColumnVector()) {
                return this.transpose();
            } else if (dimension[0] == 1 && isRowVector()) {
                return this;
            }
        }

        // First element: shape info of one tensor; second element: per-tensor offsets.
        Pair<DataBuffer, DataBuffer> tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
        DataBuffer shapeInfo = tadInfo.getFirst();
        val jShapeInfo = shapeInfo.asLong();
        val shape = Shape.shape(jShapeInfo);
        val stride = Shape.stride(jShapeInfo);
        long offset = offset() + tadInfo.getSecond().getLong(index);
        // jShapeInfo[0] is used as the rank here: the element-wise stride and the
        // order character are read at positions rank * 2 + 2 and rank * 2 + 3.
        val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2);
        char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3);
        val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder);
        return toTad;
    }

    /**
     * Replaces this array's shape information from a pair: the first element is
     * stored as {@code shapeInformation} and the second element is wrapped in a
     * new {@code JvmShapeInfo}.
     *
     * @param shapeInfo the pair holding the shape buffer and its JVM-side counterpart
     */
    private void setShapeInformation(Pair shapeInfo) {
        this.shapeInformation = shapeInfo.getFirst();
        this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond());
    }


    /**
     * Computes the {@code index}-th tensor along the given dimensions on the Java
     * side by permuting this array and slicing it repeatedly.
     * Negative dimensions are normalised in place by adding {@code rank()} and the
     * dimensions are sorted (the passed array is modified).
     *
     * @param index     the index of the tensor to return
     * @param dimension the dimensions spanned by the tensor
     * @return the requested tensor, or {@code this} when at least {@code rank()}
     * dimensions are given
     * @throws IllegalArgumentException if no dimensions are given or
     *                                  {@code index >= tensorsAlongDimension(dimension)}
     */
    private INDArray doTad(int index, int... dimension) {
        if (dimension == null || dimension.length == 0)
            throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)");

        if (dimension.length >= rank())
            return this;
        // Map negative axes onto their positive equivalents.
        for (int i = 0; i < dimension.length; i++)
            if (dimension[i] < 0)
                dimension[i] += rank();

        if (dimension.length > 1)
            Arrays.sort(dimension);

        long tads = tensorsAlongDimension(dimension);
        if (index >= tads)
            throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads);


        // Vector shortcuts: a column vector along dimension 0 is returned transposed,
        // a row vector along dimension 1 is returned as-is.
        if (dimension.length == 1) {
            if (dimension[0] == 0 && isColumnVector()) {
                return this.transpose();
            } else if (dimension[0] == 1 && isRowVector()) {
                return this;
            }
        }


        // Sizes of the requested dimensions, i.e. the shape of one tensor.
        long[] tensorShape = ArrayUtil.keep(shape(), dimension);
        int[] reverseDimensions = ArrayUtil.reverseCopy(dimension);
        // Axes that are not part of the requested dimensions.
        int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension);
        // Permutation: remaining axes first, then the requested axes in reverse order.
        int[] newPermuteDims = Ints.concat(remove, reverseDimensions);
        // Looked up by number of dimensions; the table is defined outside this block.
        int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length];

        INDArray permuted = permute(newPermuteDims);
        long sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape);

        INDArray ret2 = permuted.slice(sliceIdx);
        // The first slice already has the length of one tensor: permute it back and return.
        if (dimension.length == tensorShape.length && ArrayUtil.prodLong(tensorShape) == ret2.length()) {
            if (dimension.length == 1 && ret2.isRowVector())
                return ret2;
            if (finalPermuteDims.length != ret2.rank()) {
                // Fall back to a full axis reversal matching the slice's rank.
                finalPermuteDims = new int[ret2.rank()];
                int count = 0;
                for (int i = finalPermuteDims.length - 1; i >= 0; i--)
                    finalPermuteDims[count++] = i;
            }
            return ret2.permutei(finalPermuteDims);
        }


        // NOTE(review): length and tensorLength are computed from the same expression.
        int length = ArrayUtil.prod(tensorShape);
        int tensorLength = ArrayUtil.prod(tensorShape);
        long offset = index * tensorLength / NDArrayMath.lengthPerSlice(ret2);

        if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) {
            if (offset > Integer.MAX_VALUE)
                throw new ND4JArraySizeException();
            ret2 = ret2.slice((int) offset);
            if (dimension.length == 1 && ret2.isRowVectorOrScalar())
                return ret2;
            return ret2.permutei(finalPermuteDims);
        }

        else if (length == NDArrayMath.lengthPerSlice(ret2)) {
            // Reduce the offset by whole multiples of the slice count (offset modulo slices).
            offset -= ret2.slices() * (offset / ret2.slices());

            if (offset > Integer.MAX_VALUE)
                throw new ND4JArraySizeException();
            ret2 = ret2.slice((int) offset);
            if (dimension.length == 1 && ret2.isRowVectorOrScalar())
                return ret2;
            return ret2.permutei(finalPermuteDims);
        }

        // Keep slicing until the remaining array is no longer than one tensor.
        while (ret2.length() > length) {
            sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape);
            sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices());
            ret2 = ret2.slice(sliceIdx);
        }

        if (dimension.length == 1 && ret2.isRowVectorOrScalar())
            return ret2;

        return ret2.permutei(finalPermuteDims);
    }

    /**
     * Counts the vectors along the given dimension.
     *
     * @param dimension the dimension to count vectors along
     * @return 1 for a vector along dimension 0 or for a row vector/scalar;
     * otherwise {@code length()} divided by the size of the relevant dimension
     * @throws IllegalArgumentException if the computed count is {@code >= Integer.MAX_VALUE}
     */
    @Override
    public long vectorsAlongDimension(int dimension) {
        if (dimension == 0 && isVector() || isRowVectorOrScalar()) {
            return 1;
        }

        if (size(dimension) == 1 && !isVector()) {
            // Size-1 dimension: recurse on the first later dimension whose size is not 1.
            int axis = dimension;
            while (axis < rank()) {
                if (size(axis) != 1) {
                    return vectorsAlongDimension(axis);
                }
                axis++;
            }
            return length();
        }

        if (size(0) == 1 && !isVectorOrScalar()) {
            final int realDimension = rank() - getLeadingOnes();
            final long total = length();
            final long count = total / size(realDimension);
            if (count >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
            }
            return count;
        }

        final long total = length();

        if (dimension >= jvmShapeInfo.rank) {
            // Out-of-range dimension: fall back to the last dimension.
            final long perLastAxis = total / size(jvmShapeInfo.rank - 1);
            if (perLastAxis >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
            }
            return (int) perLastAxis;
        }

        final long perAxis = total / size(dimension);
        if (perAxis >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
        }
        return perAxis;
    }

    /**
     * Returns the {@code index}-th vector along the given dimension.
     * A negative dimension is offset by the rank. For arrays of rank greater than 2
     * whose last or first dimension is the requested one and has size 1, the whole
     * array is returned; otherwise this delegates to
     * {@link #tensorAlongDimension(long, int...)}.
     *
     * @param index     the index of the vector
     * @param dimension the dimension to take the vector along
     * @return the requested vector, or this array in the size-1 cases above
     */
    @Override
    public INDArray vectorAlongDimension(int index, int dimension) {
        final int axis = dimension < 0 ? jvmShapeInfo.getRank() + dimension : dimension;

        // Trailing size-1 dimension of a rank > 2 array: return the whole thing.
        if (axis == jvmShapeInfo.getRank() - 1 && size(axis) == 1 && rank() > 2) {
            return this;
        }
        // Leading size-1 dimension of a rank > 2 array: return the whole thing.
        if (rank() > 2 && axis == 0 && size(axis) == 1) {
            return this;
        }

        return tensorAlongDimension(index, axis);
    }

    /**
     * Rebuilds this array's shape information with the given order character,
     * keeping the current shape, stride, element-wise stride, data type and
     * empty flag.
     *
     * @param order the new order character
     */
    @Override
    public void setOrder(char order) {
        final long[] currentShape = shape();
        final long[] currentStride = stride();
        final int ews = elementWiseStride();
        setShapeInformation(Nd4j.getShapeInfoProvider()
                .createShapeInformation(currentShape, currentStride, ews, order, this.dataType(), isEmpty()));
    }

    /**
     * Rebuilds this array's shape information from the given shape and stride,
     * keeping the current ordering and data type. The element-wise stride is
     * passed as 0 and the empty flag as {@code false}.
     *
     * @param shape  the new shape
     * @param stride the new stride
     */
    @Override
    public void setShapeAndStride(int[] shape, int[] stride) {
        final long[] longShape = ArrayUtil.toLongArray(shape);
        final long[] longStride = ArrayUtil.toLongArray(stride);
        setShapeInformation(Nd4j.getShapeInfoProvider()
                .createShapeInformation(longShape, longStride, 0, ordering(), this.dataType(), false));
    }

    /**
     * Cumulative sum along a dimension.
     * Scalars and empty arrays are returned unchanged. Vectors are accumulated
     * element by element in this array. For {@code Integer.MAX_VALUE} the sum is
     * computed on (and returned as) the result of {@code ravel()}. Otherwise
     * every vector along the dimension is accumulated via {@code cumsumi(0)}.
     *
     * @param dimension the dimension to accumulate along
     * @return this array, or the raveled array for {@code Integer.MAX_VALUE}
     */
    @Override
    public INDArray cumsumi(int dimension) {
        validateNumericalArray("cumsumi", true);

        if (isScalar() || isEmpty()) {
            return this;
        }

        if (isVector()) {
            double running = 0.0;
            int idx = 0;
            while (idx < length()) {
                running += getDouble(idx);
                putScalar(idx, running);
                idx++;
            }
            return this;
        }

        if (dimension == Integer.MAX_VALUE) {
            final INDArray flat = ravel();
            double running = flat.getDouble(0);
            int idx = 1;
            while (idx < flat.length()) {
                running = running + flat.getDouble(idx);
                flat.putScalar(idx, running);
                idx++;
            }
            return flat;
        }

        int vecIdx = 0;
        while (vecIdx < vectorsAlongDimension(dimension)) {
            vectorAlongDimension(vecIdx, dimension).cumsumi(0);
            vecIdx++;
        }
        return this;
    }

    /**
     * Max-norm reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code normmax(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number normmaxNumber() {
        final INDArray reduced = normmax(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Norm-2 reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code norm2(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number norm2Number() {
        final INDArray reduced = norm2(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Norm-1 reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code norm1(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number norm1Number() {
        final INDArray reduced = norm1(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Standard deviation reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code std(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number stdNumber() {
        final INDArray reduced = std(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Product reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return {@code getNumber(0)} for a scalar, otherwise element 0 of
     * {@code prod(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number prodNumber() {
        if (!isScalar()) {
            final INDArray reduced = prod(Integer.MAX_VALUE);
            return reduced.getDouble(0);
        }
        return getNumber(0);
    }

    /**
     * Mean reduced with {@code Integer.MAX_VALUE} as the dimension, after
     * validating that this is a numerical array.
     *
     * @return {@code getNumber(0)} for a scalar, otherwise element 0 of
     * {@code mean(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number meanNumber() {
        validateNumericalArray("meanNumber", false);
        if (!isScalar()) {
            final INDArray reduced = mean(Integer.MAX_VALUE);
            return reduced.getDouble(0);
        }
        return getNumber(0);
    }

    /**
     * {@code amean} reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code amean(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number ameanNumber() {
        final INDArray reduced = amean(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Variance reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code var(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number varNumber() {
        final INDArray reduced = var(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Maximum reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return {@code getNumber(0)} for a scalar, otherwise element 0 of
     * {@code max(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number maxNumber() {
        if (!isScalar()) {
            final INDArray reduced = max(Integer.MAX_VALUE);
            return reduced.getDouble(0);
        }
        return getNumber(0);
    }

    /**
     * {@code amax} reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code amax(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number amaxNumber() {
        final INDArray reduced = amax(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Minimum reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return {@code getNumber(0)} for a scalar, otherwise element 0 of
     * {@code min(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number minNumber() {
        if (!isScalar()) {
            final INDArray reduced = min(Integer.MAX_VALUE);
            return reduced.getDouble(0);
        }
        return getNumber(0);
    }

    /**
     * {@code amin} reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code amin(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number aminNumber() {
        final INDArray reduced = amin(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Executes a {@code MatchCondition} op over this array with the given condition.
     *
     * @param condition the condition to evaluate
     * @return element 0 of the op result as a double
     */
    @Override
    public Number scan(Condition condition) {
        return Nd4j.getExecutioner()
                .exec(new MatchCondition(this, condition))
                .getDouble(0);
    }

    /**
     * Sum reduced with {@code Integer.MAX_VALUE} as the dimension, after validating
     * that this is a numerical array. The executioner is committed before
     * the value is read.
     *
     * @return {@code getNumber(0)} for a scalar, otherwise element 0 of
     * {@code sum(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number sumNumber() {
        validateNumericalArray("sum", false);
        if (isScalar()) {
            return getNumber(0);
        }
        final INDArray total = sum(Integer.MAX_VALUE);
        Nd4j.getExecutioner().commit();
        return total.getDouble(0);
    }

    /**
     * Entropy reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code entropy(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number entropyNumber() {
        final INDArray reduced = entropy(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Shannon entropy reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code shannonEntropy(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number shannonEntropyNumber() {
        final INDArray reduced = shannonEntropy(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Log entropy reduced with {@code Integer.MAX_VALUE} as the dimension.
     *
     * @return element 0 of {@code logEntropy(Integer.MAX_VALUE)} as a double
     */
    @Override
    public Number logEntropyNumber() {
        final INDArray reduced = logEntropy(Integer.MAX_VALUE);
        return reduced.getDouble(0);
    }

    /**
     * Cumulative sum along a dimension, computed on the result of {@code dup()}.
     *
     * @param dimension the dimension to accumulate along
     * @return the result of {@code cumsumi(dimension)} on the duplicate
     */
    @Override
    public INDArray cumsum(int dimension) {
        validateNumericalArray("cumsum", true);
        final INDArray copy = dup();
        return copy.cumsumi(dimension);
    }

    /**
     * Assigns the contents of {@code arr} to this array by executing an
     * {@code Assign} op.
     * Both arrays must be scalars, both vectors, or have equal shapes ignoring
     * size-1 dimensions, and their lengths must match.
     *
     * @param arr the array to copy values from
     * @return this array
     */
    @Override
    public INDArray assign(final INDArray arr) {
        final boolean compatible = (this.isScalar() && arr.isScalar())
                || (this.isVector() && arr.isVector())
                || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape());
        Preconditions.checkState(compatible,
                "Cannot assign arrays: arrays must both be scalars, both vectors, or shapes must be equal other than size 1 dimensions. Attempting to do x.assign(y)" +
                        " with x.shape=%ndShape and y.shape=%ndShape", this, arr );

        final boolean sameLength = this.length() == arr.length();
        Preconditions.checkArgument(sameLength, "Length of both arrays must be equal");

        Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this));
        return this;
    }

    /**
     * Puts a value at the given linear index.
     * A negative index is interpreted relative to the end of the array, i.e.
     * {@code length()} is added to it (so {@code -1} addresses the last element).
     *
     * @param i     the linear index; negative values count back from {@code length()}
     * @param value the value to store; boolean arrays only accept 0 or 1
     * @return this array
     */
    @Override
    public INDArray putScalar(long i, double value) {
        Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" +
                " - only putScalar with values 0 or 1 is allowed on boolean arrays", value);
        // A linear index ranges over [0, length()), so a negative index wraps
        // by the number of elements, not by the rank.
        if (i < 0)
            i += length();

        // TODO: i'm not sure that rank == 1 has fair shortcut here
        if (isScalar()) {
            autoProcessScalarCall();
            data.put(i, value);
            return this;
        } else if (rank() == 1) {
            // Rank-1 shortcut: the buffer position is index * stride of dimension 0.
            data.put(i * stride(0), value);
            return this;
        }

        // we cant raise rank here, if original rank is 1
        if (isRowVector() && rank() == 2) {
            return putScalar(0, i, value);
        } else if (isColumnVector() && rank() == 2) {
            return putScalar(i, 0, value);
        }
        // General case: convert the linear index to per-dimension indexes
        // according to this array's ordering.
        long[] indexes = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
        return putScalar(indexes, value);
    }

    /**
     * Float overload: widens the value to double and delegates to the
     * (long, double) variant.
     *
     * @param i     the linear index
     * @param value the value to store
     * @return the result of the double overload
     */
    @Override
    public INDArray putScalar(long i, float value) {
        final double widened = value;
        return putScalar(i, widened);
    }

    /**
     * Int overload: widens the value to double and delegates to the
     * (long, double) variant.
     *
     * @param i     the linear index
     * @param value the value to store
     * @return the result of the double overload
     */
    @Override
    public INDArray putScalar(long i, int value) {
        final double widened = value;
        return putScalar(i, widened);
    }

    /**
     * Puts a value at the given per-dimension indexes.
     * Negative indexes are normalised in place by adding the size of their
     * dimension (the passed array is modified). One to four indexes are
     * dispatched to the fixed-arity overloads; more are resolved through
     * {@code Shape.getOffset}.
     *
     * @param indexes the per-dimension indexes
     * @param value   the value to store; boolean arrays only accept 0 or 1
     * @return this array
     */
    @Override
    public INDArray putScalar(int[] indexes, double value) {
        Nd4j.getCompressor().autoDecompress(this);

        Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" +
                " - only putScalar with values 0 or 1 is allowed on boolean arrays", value);

        int dim = 0;
        while (dim < indexes.length) {
            if (indexes[dim] < 0) {
                indexes[dim] += this.size(dim);
            }
            dim++;
        }

        switch (indexes.length) {
            case 1:
                return putScalar(indexes[0], value);
            case 2:
                return putScalar(indexes[0], indexes[1], value);
            case 3:
                return putScalar(indexes[0], indexes[1], indexes[2], value);
            case 4:
                return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value);
            default:
                autoProcessScalarCall();
                long bufferPos = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes);
                data.put(bufferPos, value);
                return this;
        }
    }

    /**
     * Puts a value at the given per-dimension indexes (long variant).
     * Negative indexes are normalised in place by adding the size of their
     * dimension (the passed array is modified). One to four indexes are
     * dispatched to the fixed-arity overloads; more are resolved through
     * {@code Shape.getOffset}.
     *
     * @param indexes the per-dimension indexes
     * @param value   the value to store; boolean arrays only accept 0 or 1
     * @return this array
     */
    @Override
    public INDArray putScalar(long[] indexes, double value) {
        Nd4j.getCompressor().autoDecompress(this);

        Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" +
                " - only putScalar with values 0 or 1 is allowed on boolean arrays", value);

        int dim = 0;
        while (dim < indexes.length) {
            if (indexes[dim] < 0) {
                indexes[dim] += size(dim);
            }
            dim++;
        }

        switch (indexes.length) {
            case 1:
                return putScalar(indexes[0], value);
            case 2:
                return putScalar(indexes[0], indexes[1], value);
            case 3:
                return putScalar(indexes[0], indexes[1], indexes[2], value);
            case 4:
                return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value);
            default:
                autoProcessScalarCall();
                long bufferPos = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes);
                data.put(bufferPos, value);
                return this;
        }
    }

    /**
     * Float overload: widens the value to double and delegates to the
     * (long[], double) variant.
     *
     * @param indexes the per-dimension indexes
     * @param value   the value to store
     * @return the result of the double overload
     */
    @Override
    public INDArray putScalar(long[] indexes, float value) {
        final double widened = value;
        return putScalar(indexes, widened);
    }

    /**
     * Puts a value at (row, col); rejected for arrays of rank greater than 2.
     *
     * @param row   the row index
     * @param col   the column index
     * @param value the value to store; boolean arrays only accept 0 or 1
     * @return this array
     * @throws IllegalStateException if {@code rank() > 2}
     */
    @Override
    public INDArray putScalar(long row, long col, double value) {
        Nd4j.getCompressor().autoDecompress(this);
        autoProcessScalarCall();

        Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" +
                " - only putScalar with values 0 or 1 is allowed on boolean arrays", value);

        if (rank() > 2) {
            throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray");
        }

        final long bufferPos = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col);
        data.put(bufferPos, value);
        return this;
    }

    /**
     * Puts a value at (dim0, dim1, dim2) of a rank-3 array.
     * The buffer position is the sum of index * stride over every dimension
     * whose size is not 1, read from the JVM-side shape information.
     *
     * @param dim0  the index along dimension 0
     * @param dim1  the index along dimension 1
     * @param dim2  the index along dimension 2
     * @param value the value to store; boolean arrays only accept 0 or 1
     * @return this array
     * @throws IllegalStateException if {@code rank() != 3}
     */
    @Override
    public INDArray putScalar(long dim0, long dim1, long dim2, double value) {
        Nd4j.getCompressor().autoDecompress(this);
        autoProcessScalarCall();

        Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" +
                " - only putScalar with values 0 or 1 is allowed on boolean arrays", value);

        if (rank() != 3) {
            throw new IllegalStateException(
                    "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray");
        }

        // Layout used here: sizes at positions 1..3, strides at positions 4..6.
        final long[] info = jvmShapeInfo.javaShapeInformation;
        final long[] position = {dim0, dim1, dim2};
        long bufferPos = 0;
        for (int d = 0; d < 3; d++) {
            if (info[1 + d] != 1) {
                bufferPos += position[d] * info[1 + d + 3];
            }
        }

        data.put(bufferPos, value);
        return this;
    }

    /**
     * Puts a value at (dim0, dim1, dim2, dim3) of a rank-4 array.
     *
     * @param dim0  the index along dimension 0
     * @param dim1  the index along dimension 1
     * @param dim2  the index along dimension 2
     * @param dim3  the index along dimension 3
     * @param value the value to store; boolean arrays only accept 0 or 1
     * @return this array
     * @throws IllegalStateException if {@code rank() != 4}
     */
    @Override
    public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) {
        Nd4j.getCompressor().autoDecompress(this);
        autoProcessScalarCall();
        Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" +
                " - only putScalar with values 0 or 1 is allowed on boolean arrays", value);

        if (rank() != 4) {
            throw new IllegalStateException(
                    "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray");
        }

        final long bufferPos = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3);
        data.put(bufferPos, value);
        return this;
    }

    /**
     * Float overload: widens the value to double and delegates to the
     * (int[], double) variant.
     *
     * @param indexes the per-dimension indexes
     * @param value   the value to store
     * @return the result of the double overload
     */
    @Override
    public INDArray putScalar(int[] indexes, float value) {
        final double widened = value;
        return putScalar(indexes, widened);
    }

    /**
     * Int overload: widens the value to double and delegates to the
     * (int[], double) variant.
     *
     * @param indexes the per-dimension indexes
     * @param value   the value to store
     * @return the result of the double overload
     */
    @Override
    public INDArray putScalar(int[] indexes, int value) {
        final double widened = value;
        return putScalar(indexes, widened);
    }

    /**
     * Int overload: widens the value to double and delegates to the
     * (long[], double) variant.
     *
     * @param indexes the per-dimension indexes
     * @param value   the value to store
     * @return the result of the double overload
     */
    @Override
    public INDArray putScalar(long[] indexes, int value) {
        final double widened = value;
        return putScalar(indexes, widened);
    }

    /**
     * Executes a {@code ScalarEps} op against the given scalar, writing into a
     * new uninitialized BOOL array of this array's shape and ordering.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray eps(Number other) {
        validateNumericalArray("eps", true);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarEps(this, out, other));
    }

    /**
     * Executes an {@code Eps} op against the given array, writing into a new
     * uninitialized BOOL array of this array's shape and ordering.
     *
     * @param other the array to compare with
     * @return the op result
     */
    @Override
    public INDArray eps(INDArray other) {
        validateNumericalArray("eps", true);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new Eps(this, other, out));
    }

    /**
     * Executes a {@code ScalarLessThan} op against the given scalar, writing
     * into a new uninitialized BOOL array of this array's shape and ordering.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray lt(Number other) {
        validateNumericalArray("less than (lt)", false);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarLessThan(this, out, other));
    }

    /**
     * Executes a {@code ScalarLessThanOrEqual} op against the given scalar, writing
     * into a new uninitialized BOOL array of this array's shape and ordering.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray lte(Number other) {
        validateNumericalArray("less than or equals (lte)", false);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, out, other));
    }

    /**
     * Executes a {@code ScalarEquals} op against the given scalar, writing into
     * a new uninitialized BOOL array of this array's shape and ordering.
     * On boolean arrays only the values 0 and 1 are accepted.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray eq(Number other) {
        Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar equality on boolean arrays can only be applied with values 0 or 1: got value %s",other);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarEquals(this, out, other));
    }

    /**
     * Executes a {@code ScalarGreaterThan} op against the given scalar, writing
     * into a new uninitialized BOOL array of this array's shape and ordering.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray gt(Number other) {
        validateNumericalArray("greater than (gt)", false);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, out, other));
    }

    /**
     * Executes a {@code ScalarGreaterThanOrEqual} op against the given scalar, writing
     * into a new uninitialized BOOL array of this array's shape and ordering.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray gte(Number other) {
        validateNumericalArray("greater than or equals (gte)", false);
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, out, other));
    }

    /**
     * Executes a {@code LessThan} op against another array.
     * Equal shapes write into a BOOL array of this array's shape and ordering;
     * broadcastable shapes write into a BOOL array of the broadcast output shape.
     *
     * @param other the array to compare with
     * @return the first output of the op
     * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
     */
    @Override
    public INDArray lt(INDArray other) {
        validateNumericalArray("less than (lt)", false);
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
            return Nd4j.getExecutioner().exec(new LessThan(this, other, out))[0];
        }
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            throw new IllegalArgumentException("Shapes must be broadcastable");
        }
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()));
        return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{out}))[0];
    }

    /**
     * Executes a {@code ScalarNotEquals} op against the given scalar, writing into
     * a new uninitialized BOOL array of this array's shape and ordering.
     * On boolean arrays only the values 0 and 1 are accepted; empty arrays are rejected.
     *
     * @param other the scalar to compare with
     * @return the op result
     */
    @Override
    public INDArray neq(Number other) {
        Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar non-equality on boolean arrays can only be applied with values 0 or 1: got value %s",other);
        Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array");
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new ScalarNotEquals(this, out, other));
    }

    /**
     * Executes a {@code NotEqualTo} op against another array, writing into a new
     * uninitialized BOOL array of this array's shape and ordering. Empty arrays
     * are rejected.
     *
     * @param other the array to compare with
     * @return the first output of the op
     */
    @Override
    public INDArray neq(INDArray other) {
        Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array");
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, out))[0];
    }

    /**
     * Executes an {@code EqualTo} op against another array.
     * Equal shapes write into a BOOL array of this array's shape and ordering;
     * broadcastable shapes write into a BOOL array of the broadcast output shape.
     *
     * @param other the array to compare with
     * @return the first output of the op
     * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
     */
    @Override
    public INDArray eq(INDArray other) {
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
            return Nd4j.getExecutioner().exec(new EqualTo(this, other, out))[0];
        }
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            throw new IllegalArgumentException("Shapes must be broadcastable");
        }
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()));
        return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{out}))[0];
    }

    /**
     * Executes a {@code GreaterThan} op against another array.
     * Equal shapes write into a BOOL array of this array's shape and ordering;
     * broadcastable shapes write into a BOOL array of the broadcast output shape.
     *
     * @param other the array to compare with
     * @return the first output of the op
     * @throws IllegalArgumentException if the shapes are neither equal nor broadcastable
     */
    @Override
    public INDArray gt(INDArray other) {
        validateNumericalArray("greater than (gt)", false);
        if (Shape.shapeEquals(this.shape(), other.shape())) {
            final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
            return Nd4j.getExecutioner().exec(new GreaterThan(this, other, out))[0];
        }
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            throw new IllegalArgumentException("Shapes must be broadcastable");
        }
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()));
        return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{out}))[0];
    }

    /**
     * Executes a {@code MatchConditionTransform} with {@code Conditions.isInfinite()},
     * writing into a new uninitialized BOOL array of this array's shape and ordering.
     *
     * @return {@code Nd4j.empty(DataType.BOOL)} for an empty array, otherwise the op result
     */
    @Override
    public INDArray isInfinite(){
        validateNumericalArray("isInfinite", true);
        if (isEmpty()) {
            return Nd4j.empty(DataType.BOOL);
        }
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, out, Conditions.isInfinite()));
    }

    /**
     * Executes a {@code MatchConditionTransform} with {@code Conditions.isNan()},
     * writing into a new uninitialized BOOL array of this array's shape and ordering.
     *
     * @return {@code Nd4j.empty(DataType.BOOL)} for an empty array, otherwise the op result
     */
    @Override
    public INDArray isNaN(){
        validateNumericalArray("isNaN", true);
        if (isEmpty()) {
            return Nd4j.empty(DataType.BOOL);
        }
        final INDArray out = Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, out, Conditions.isNan()));
    }

    /**
     * Executes a {@code Negative} op, writing into a new uninitialized array of
     * this array's data type, shape and ordering.
     *
     * @return this array if it is empty, otherwise the op result
     */
    @Override
    public INDArray neg() {
        validateNumericalArray("negative (neg)", true);
        if (isEmpty()) {
            return this;
        }
        final INDArray out = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
        return Nd4j.getExecutioner().exec(new Negative(this, out));
    }

    /**
     * Executes a single-argument {@code Negative} op on this array; the op is
     * skipped for an empty array.
     *
     * @return this array
     */
    @Override
    public INDArray negi() {
        validateNumericalArray("negative (negi)", true);
        if (!isEmpty()) {
            Nd4j.getExecutioner().exec(new Negative(this));
        }
        return this;
    }

    /**
     * Reverse division by a scalar; delegates to {@link #rdivi(Number, INDArray)}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return the value returned by {@code rdivi}
     */
    @Override
    public INDArray rdiv(Number n, INDArray result) {
        final INDArray out = rdivi(n, result);
        return out;
    }

    /**
     * Executes a {@code ScalarReverseDivision} op with the given scalar, writing
     * into {@code result}. A NaN scalar is replaced by {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return {@code result}
     */
    @Override
    public INDArray rdivi(Number n, INDArray result) {
        validateNumericalArray("rdivi", false);
        final Number scalar;
        if (Double.isNaN(n.doubleValue())) {
            scalar = Nd4j.EPS_THRESHOLD;
        } else {
            scalar = n;
        }
        Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, scalar));
        return result;
    }

    /**
     * Reverse subtraction by a scalar; delegates to {@link #rsubi(Number, INDArray)}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return the value returned by {@code rsubi}
     */
    @Override
    public INDArray rsub(Number n, INDArray result) {
        final INDArray out = rsubi(n, result);
        return out;
    }

    /**
     * Executes a {@code ScalarReverseSubtraction} op with the given scalar, writing
     * into {@code result}. A NaN scalar is replaced by {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return {@code result}
     */
    @Override
    public INDArray rsubi(Number n, INDArray result) {
        validateNumericalArray("rsubi", false);

        final Number scalar;
        if (Double.isNaN(n.doubleValue())) {
            scalar = Nd4j.EPS_THRESHOLD;
        } else {
            scalar = n;
        }

        Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, result, scalar));
        return result;
    }

    /**
     * Division by a scalar; delegates to {@link #divi(Number, INDArray)}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return the value returned by {@code divi}
     */
    @Override
    public INDArray div(Number n, INDArray result) {
        final INDArray out = divi(n, result);
        return out;
    }

    /**
     * Executes a {@code ScalarDivision} op with the given scalar, writing into
     * {@code result}. A NaN scalar is replaced by {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return {@code result}
     */
    @Override
    public INDArray divi(Number n, INDArray result) {
        validateNumericalArray("divi", false);

        final Number scalar;
        if (Double.isNaN(n.doubleValue())) {
            scalar = Nd4j.EPS_THRESHOLD;
        } else {
            scalar = n;
        }
        Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, scalar));
        return result;
    }

    /**
     * Multiplication by a scalar; delegates to {@link #muli(Number, INDArray)}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return the value returned by {@code muli}
     */
    @Override
    public INDArray mul(Number n, INDArray result) {
        final INDArray out = muli(n, result);
        return out;
    }

    /**
     * Executes a {@code ScalarMultiplication} op with the given scalar, writing
     * into {@code result}. A NaN scalar is replaced by {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return {@code result}
     */
    @Override
    public INDArray muli(Number n, INDArray result) {
        validateNumericalArray("muli", false);
        final Number scalar;
        if (Double.isNaN(n.doubleValue())) {
            scalar = Nd4j.EPS_THRESHOLD;
        } else {
            scalar = n;
        }
        Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, scalar));
        return result;
    }

    /**
     * Subtraction of a scalar; delegates to {@link #subi(Number, INDArray)}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return the value returned by {@code subi}
     */
    @Override
    public INDArray sub(Number n, INDArray result) {
        final INDArray out = subi(n, result);
        return out;
    }

    /**
     * Executes a {@code ScalarSubtraction} op with the given scalar, writing into
     * {@code result}. A NaN scalar is replaced by {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return {@code result}
     */
    @Override
    public INDArray subi(Number n, INDArray result) {
        validateNumericalArray("subi", false);

        final Number scalar;
        if (Double.isNaN(n.doubleValue())) {
            scalar = Nd4j.EPS_THRESHOLD;
        } else {
            scalar = n;
        }

        Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, scalar));
        return result;
    }

    /**
     * Addition of a scalar; delegates to {@link #addi(Number, INDArray)}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return the value returned by {@code addi}
     */
    @Override
    public INDArray add(Number n, INDArray result) {
        final INDArray out = addi(n, result);
        return out;
    }

    /**
     * Executes a {@code ScalarAdd} op with the given scalar, writing into
     * {@code result}. A NaN scalar is replaced by {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param n      the scalar
     * @param result the array receiving the result
     * @return {@code result}
     */
    @Override
    public INDArray addi(Number n, INDArray result) {
        validateNumericalArray("addi", false);
        final Number scalar;
        if (Double.isNaN(n.doubleValue())) {
            scalar = Nd4j.EPS_THRESHOLD;
        } else {
            scalar = n;
        }

        Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, scalar));
        return result;
    }

    /**
     * Fetches the scalar at a (row, column) position by packing both coordinates
     * into an index array and delegating to the index-array overload.
     *
     * @param row    first coordinate
     * @param column second coordinate
     * @return result of the index-array {@code getScalar} overload
     */
    @Override
    public INDArray getScalar(long row, long column) {
        final long[] position = {row, column};
        return getScalar(position);
    }

    /**
     * Duplicates this array using the ordering reported by {@code Nd4j.order()}.
     *
     * @return result of {@code dup(char)} for that ordering
     */
    @Override
    public INDArray dup() {
        final char defaultOrder = Nd4j.order();
        return dup(defaultOrder);
    }

    /**
     * Duplicates this array with the requested ordering.
     * <ul>
     *   <li>A compressed array whose ordering already equals {@code order} is copied
     *       buffer-for-buffer and the copy is marked as compressed.</li>
     *   <li>An empty array is returned as-is (no copy is made).</li>
     *   <li>A string array ({@code isS()}) is rebuilt element by element; note that this
     *       path uses the array's current ordering, not {@code order}.</li>
     *   <li>Otherwise an uninitialized array of the same data type and shape is created
     *       with {@code order} and this array is assigned into it.</li>
     * </ul>
     *
     * @param order ordering for the duplicate
     * @return the duplicate (or {@code this} when the array is empty)
     */
    @Override
    public INDArray dup(char order) {
        WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray");
        if (this.isCompressed() && this.ordering() == order) {
            INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer());
            ret.markAsCompressed(true);
            return ret;
        }
        if (isEmpty())
            return this;

        Nd4j.getCompressor().autoDecompress(this);

        // fixme: eventually it would be nice to have this in native code
        if (isS()) {
            // Typed list (was a raw ArrayList) so the element type is checked at compile time
            List<String> strings = new ArrayList<>();
            for (int e = 0; e < this.length(); e++)
                strings.add(this.getString(e));

            return Nd4j.create(strings, this.shape(), this.ordering());
        }

        val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order);
        z.assign(this);
        return z;
    }

    /**
     * Reads the element at {@code indices} through {@code getDouble} and narrows it to {@code int}.
     *
     * @param indices element coordinates
     * @return the double value cast to {@code int}
     */
    @Override
    public int getInt(int... indices) {
        final double value = getDouble(indices);
        return (int) value;
    }

    /**
     * Returns the element at a linear index as a {@code long}.
     * Index 0 is read straight from the data buffer; any other index is converted to
     * per-dimension subscripts (C or F variant, depending on {@code ordering()}) and
     * resolved through {@code getLong(long...)}.
     *
     * @param index linear index
     * @return the element value
     * @throws IllegalArgumentException if {@code index >= length()}
     */
    @Override
    public long getLong(long index) {
        Nd4j.getCompressor().autoDecompress(this);
        Preconditions.checkState(!isEmpty(), "Unable to get value from empty array");

        if (index >= length()) {
            final String message = "Unable to get linear index " + index + ": values is greater than length (" + length() + ")";
            throw new IllegalArgumentException(message);
        }

        autoProcessScalarCall();

        if (index == 0) {
            return data().getLong(index);
        }

        final long[] subscripts;
        if (ordering() == 'c') {
            subscripts = Shape.ind2subC(this, index);
        } else {
            subscripts = Shape.ind2sub(this, index);
        }
        Shape.assertShapeLessThan(subscripts, shape());
        return getLong(subscripts);
    }

    /**
     * Returns the element at the given coordinates as a {@code long}.
     * For a scalar array the first buffer element is returned and {@code indices} is ignored.
     *
     * @param indices element coordinates
     * @return the element value
     */
    @Override
    public long getLong(long... indices) {
        if (!isScalar()) {
            return Shape.getLong(this, indices);
        }
        return data().getLong(0);
    }

    /**
     * Returns the element at the given coordinates as a {@code double}.
     * Negative entries of {@code indices} are shifted in place by {@code rank()}
     * (the caller's array is modified). A single index is interpreted according to
     * the array kind: rank-1, row vector, column vector, or a one-element array at
     * index 0. Everything else is resolved by {@code Shape.getDouble(this, indices)}.
     *
     * @param indices element coordinates
     * @return the element value
     */
    @Override
    public double getDouble(int... indices) {
        autoProcessScalarCall();
        Nd4j.getCompressor().autoDecompress(this);
        Preconditions.checkState(!isEmpty(), "Unable to get value from empty array");

        for (int axis = 0; axis < indices.length; axis++) {
            if (indices[axis] >= 0) {
                continue;
            }
            indices[axis] += rank();
        }

        if (indices.length != 1) {
            return Shape.getDouble(this, indices);
        }

        // single-index shortcuts
        if (rank() == 1) {
            return Shape.getDouble(this, indices[0]);
        }
        if (isRowVector()) {
            return Shape.getDouble(this, 0, indices[0]);
        }
        if (isColumnVector()) {
            return Shape.getDouble(this, indices[0], 0);
        }
        if ((isScalar() || length() == 1) && indices[0] == 0) {
            return data().getDouble(0);
        }
        return Shape.getDouble(this, indices);
    }

    /**
     * Returns the element at the given coordinates as a {@code double}.
     * Negative entries of {@code indices} are shifted in place by {@code rank()}
     * (the caller's array is modified). A single index is only accepted for rank-1
     * arrays, row vectors, column vectors, or a scalar at index 0.
     *
     * @param indices element coordinates
     * @return the element value
     * @throws IllegalStateException if a single index is given for an array that is
     *                               none of the kinds listed above
     */
    @Override
    public double getDouble(long... indices) {
        autoProcessScalarCall();
        Nd4j.getCompressor().autoDecompress(this);
        Preconditions.checkState(!isEmpty(), "Unable to get value from empty array");

        for (int axis = 0; axis < indices.length; axis++) {
            if (indices[axis] >= 0) {
                continue;
            }
            indices[axis] += rank();
        }

        if (indices.length != 1) {
            return Shape.getDouble(this, indices);
        }

        // single-index shortcuts
        if (rank() == 1) {
            return Shape.getDouble(this, indices[0]);
        }
        if (isRowVector()) {
            return Shape.getDouble(this, 0, indices[0]);
        }
        if (isColumnVector()) {
            return Shape.getDouble(this, indices[0], 0);
        }
        if (isScalar() && indices[0] == 0) {
            return data().getDouble(0);
        }
        throw new IllegalStateException("Indexes length must be > 1 for non vectors and scalars");
    }

    /**
     * Reads the element at {@code indices} through {@code getDouble} and narrows it to {@code float}.
     *
     * @param indices element coordinates
     * @return the double value cast to {@code float}
     */
    @Override
    public float getFloat(int... indices) {
        final double value = getDouble(indices);
        return (float) value;
    }

    /**
     * Reads the element at {@code indices} through {@code getDouble} and narrows it to {@code float}.
     *
     * @param indices element coordinates
     * @return the double value cast to {@code float}
     */
    @Override
    public float getFloat(long... indices) {
        final double value = getDouble(indices);
        return (float) value;
    }

    /**
     * Reports whether this array is treated as a scalar.
     * Empty arrays are never scalars. Rank 0 always is; rank 1 requires a single
     * element; rank 2 requires shape [1, 1] or a total length of 1; any other rank is not.
     *
     * @return {@code true} if the array counts as a scalar
     */
    @Override
    public boolean isScalar() {
        if (isEmpty()) {
            return false;
        }

        switch (jvmShapeInfo.rank) {
            case 0:
                return true;
            case 1:
                return shape()[0] == 1;
            case 2:
                return shape()[0] == 1 && shape()[1] == 1 || length() == 1;
            default:
                return false;
        }
    }

    /**
     * Writes a scalar element into the data buffer at the position given by {@code indices}.
     * The buffer offset is the sum of {@code index * stride} terms: for a row vector
     * addressed as {@code [0, j]} only the dimensions after the first contribute;
     * otherwise every dimension whose size is not 1 contributes.
     *
     * @param indices element coordinates
     * @param element scalar array holding the value to write
     * @return this array
     * @throws IllegalArgumentException if {@code element} is not a scalar or the computed
     *                                  offset is not below the buffer length
     */
    @Override
    public INDArray put(int[] indices, INDArray element) {
        Nd4j.getCompressor().autoDecompress(this);
        if (!element.isScalar())
            throw new IllegalArgumentException("Unable to insert anything but a scalar");

        final boolean rowVectorAddress = isRowVector() && indices[0] == 0 && indices.length == 2;

        int offset = 0;
        if (rowVectorAddress) {
            for (int dim = 1; dim < indices.length; dim++) {
                offset += indices[dim] * stride(dim);
            }
        } else {
            for (int dim = 0; dim < indices.length; dim++) {
                if (size(dim) != 1) {
                    offset += indices[dim] * stride(dim);
                }
            }
        }

        if (offset >= data.length())
            throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices));
        data.put(offset, element.getDouble(0));
        return this;
    }

    /**
     * Evaluates {@code condition} between this array and {@code comp} with a
     * {@code MatchConditionTransform}, writing into a freshly created BOOL array
     * of this array's shape.
     *
     * @param comp      array to compare against; must have the same shape and data type as this array
     * @param condition condition to evaluate
     * @return the array returned by the executioner for the transform
     * @throws IllegalArgumentException if shapes or data types differ
     */
    @Override
    public INDArray match(INDArray comp, Condition condition) {
        // TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition
        Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal");
        // message typo fixed ("bmust" -> "must")
        Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types must be equal");
        return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition));
    }

    /**
     * Evaluates {@code condition} between this array and a scalar with a
     * {@code MatchConditionTransform}.
     *
     * @param comp      scalar to compare against (its {@code doubleValue()} is used)
     * @param condition condition to evaluate
     * @return the array returned by the executioner for the transform
     */
    @Override
    public INDArray match(Number comp, Condition condition) {
        final double threshold = comp.doubleValue();
        MatchConditionTransform transform = new MatchConditionTransform(this, threshold, condition);
        return Nd4j.getExecutioner().exec(transform);
    }

    /**
     * Delegates to {@code BooleanIndexing.chooseFrom} with this array and {@code comp}
     * as the two inputs.
     *
     * @param comp      second input array
     * @param condition condition passed to {@code chooseFrom}
     * @return result of {@code BooleanIndexing.chooseFrom}
     */
    @Override
    public INDArray getWhere(INDArray comp, Condition condition) {
        final INDArray[] inputs = {this, comp};
        return BooleanIndexing.chooseFrom(inputs, condition);
    }

    /**
     * Delegates to {@code BooleanIndexing.chooseFrom} with this array as the only input,
     * the scalar's {@code doubleValue()} as the single element of the first argument list,
     * and an empty second argument list.
     *
     * @param comp      scalar operand
     * @param condition condition passed to {@code chooseFrom}
     * @return result of {@code BooleanIndexing.chooseFrom}
     */
    @Override
    public INDArray getWhere(Number comp, Condition condition) {
        final INDArray[] inputs = {this};
        return BooleanIndexing.chooseFrom(inputs, Arrays.asList(comp.doubleValue()), Collections.emptyList(), condition);
    }

    /**
     * Runs a {@code MatchConditionTransform} of this array against {@code comp} and
     * passes the transform's {@code z()} array as the mask to
     * {@code putWhereWithMask(mask, put)}.
     *
     * @param comp      array to compare against
     * @param put       values forwarded to {@code putWhereWithMask}
     * @param condition condition to evaluate
     * @return result of {@code putWhereWithMask}
     */
    @Override
    public INDArray putWhere(INDArray comp, INDArray put, Condition condition) {
        Nd4j.getCompressor().autoDecompress(this);

        final MatchConditionTransform maskOp = new MatchConditionTransform(this, comp, condition);
        Nd4j.getExecutioner().exec(maskOp);

        final INDArray mask = maskOp.z();
        return putWhereWithMask(mask, put);
    }

    /**
     * Wraps {@code comp} with {@code Nd4j.scalar} and delegates to
     * {@code putWhere(INDArray, INDArray, Condition)}.
     *
     * @param comp      scalar to compare against
     * @param put       values forwarded to the array overload
     * @param condition condition to evaluate
     * @return result of the array overload
     */
    @Override
    public INDArray putWhere(Number comp, INDArray put, Condition condition) {
        final INDArray compScalar = Nd4j.scalar(comp);
        return putWhere(compScalar, put, condition);
    }

    /**
     * Wraps both {@code comp} and {@code put} with {@code Nd4j.scalar} and delegates to
     * {@code putWhere(INDArray, INDArray, Condition)}.
     *
     * @param comp      scalar to compare against
     * @param put       scalar value forwarded to the array overload
     * @param condition condition to evaluate
     * @return result of the array overload
     */
    @Override
    public INDArray putWhere(Number comp, Number put, Condition condition) {
        final INDArray compScalar = Nd4j.scalar(comp);
        final INDArray putScalar = Nd4j.scalar(put);
        return putWhere(compScalar, putScalar, condition);
    }


    /**
     * Executes a {@code Where} op with inputs {@code (mask, this, put)} and a duplicate
     * of this array as the op's output array.
     *
     * @param mask mask array (first op input)
     * @param put  replacement values (third op input)
     * @return the duplicate that was passed as the op's output
     */
    @Override
    public INDArray putWhereWithMask(INDArray mask, INDArray put) {
        final INDArray target = dup();
        final INDArray[] opInputs = {mask, this, put};
        final INDArray[] opOutputs = {target};
        Nd4j.getExecutioner().execAndReturn(new Where(opInputs, opOutputs));
        return target;
    }

    /**
     * Wraps {@code put} with {@code Nd4j.scalar} and delegates to
     * {@code putWhereWithMask(INDArray, INDArray)}.
     *
     * @param mask mask array
     * @param put  scalar replacement value
     * @return result of the array overload
     */
    @Override
    public INDArray putWhereWithMask(INDArray mask, Number put) {
        final INDArray putScalar = Nd4j.scalar(put);
        return putWhereWithMask(mask, putScalar);
    }

    /**
     * Packs {@code (i, j)} into an index array and delegates to {@code put(int[], INDArray)}.
     *
     * @param i       first coordinate
     * @param j       second coordinate
     * @param element element to insert
     * @return result of the index-array overload
     */
    @Override
    public INDArray put(int i, int j, INDArray element) {
        final int[] coords = {i, j};
        return put(coords, element);
    }

    /**
     * Packs {@code (i, j)} into an index array and stores the number's
     * {@code doubleValue()} there via {@code putScalar}.
     *
     * @param i       first coordinate
     * @param j       second coordinate
     * @param element value to store
     * @return result of {@code putScalar}
     */
    @Override
    public INDArray put(int i, int j, Number element) {
        final int[] coords = {i, j};
        return putScalar(coords, element.doubleValue());
    }

    /**
     * Inserts {@code put} at slice index {@code slice}.
     * <ul>
     *   <li>Scalar receiver: {@code put} must be a scalar and is written at index 0.</li>
     *   <li>Vector receiver: {@code put} must be a vector or scalar of the same length;
     *       a scalar is written at {@code slice}, a vector is copied element by element.</li>
     *   <li>Otherwise the slice is validated with {@code assertSlice}; a one-element
     *       {@code put} is written via {@code putScalar(slice, ...)}, anything else is
     *       assigned into the slice view after a shape check.</li>
     * </ul>
     *
     * @param slice slice index
     * @param put   array to insert
     * @return this array
     * @throws IllegalStateException if the shapes of {@code put} and the slice view are incompatible
     */
    @Override
    public INDArray putSlice(int slice, INDArray put) {
        Nd4j.getCompressor().autoDecompress(this);

        if (isScalar()) {
            Preconditions.checkState(put.isScalar(), "Invalid dimension. Can only insert a scalar in to another scalar");
            put(0, put.getScalar(0));
            return this;
        }

        if (isVector()) {
            Preconditions.checkState(put.isVectorOrScalar() && put.length() == length(),
                    "Invalid dimension on insertion. Can only insert scalars/vectors into other scalar/vectors");
            if (put.isScalar()) {
                putScalar(slice, put.getDouble(0));
                return this;
            }
            for (int pos = 0; pos < length(); pos++) {
                putScalar(pos, put.getDouble(pos));
            }
            return this;
        }

        assertSlice(put, slice);

        final INDArray target = slice(slice);

        if (put.length() == 1) {
            putScalar(slice, put.getDouble(0));
            return this;
        }

        // accept either two vectors of equal length or two arrays of equal shape
        final boolean sameLengthVectors = target.isVector() && put.isVector() && target.length() == put.length();
        if (!(sameLengthVectors || target.equalShapes(put))) {
            throw new IllegalStateException("Cannot put slice: array to be put (" + Arrays.toString(put.shape()) +
                    ") and slice array (" + Arrays.toString(target.shape()) + ") have different shapes");
        }
        target.assign(put);
        return this;
    }

    /**
     * Validates that {@code put} may be inserted as slice number {@code slice}.
     * The slice index must be below {@code slices()}. Row-vector shapes, scalars,
     * shorter vectors going into a vector, and column-vector shapes are accepted
     * without further checks; otherwise the shape of {@code put} must equal this
     * array's shape with its first dimension removed (unless that required shape
     * is itself a row-vector shape).
     *
     * @param put   array about to be inserted
     * @param slice slice index
     * @throws IllegalArgumentException if {@code slice} is not below {@code slices()}
     * @throws IllegalStateException    if the shape of {@code put} is not acceptable
     */
    protected void assertSlice(INDArray put, long slice) {
        Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices());

        final long[] putShape = put.shape();
        if (Shape.isRowVectorShape(putShape)) {
            return;
        }

        final long[] expectedShape = ArrayUtil.removeIndex(shape(), 0);

        //no need to compare for scalar; primarily due to shapes either being [1] or length 0
        if (put.isScalar()) {
            return;
        }

        if (isVector() && put.isVector() && put.length() < length()) {
            return;
        }

        //edge case for column vectors
        if (Shape.isColumnVectorShape(putShape)) {
            return;
        }

        final boolean acceptable = Shape.shapeEquals(putShape, expectedShape)
                || Shape.isRowVectorShape(expectedShape)
                || Shape.isRowVectorShape(putShape);
        if (!acceptable) {
            throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ",
                    Arrays.toString(putShape), Arrays.toString(expectedShape)));
        }
    }

    /**
     * Reports whether this array has rank 2.
     *
     * @return {@code true} if {@code rank() == 2}
     */
    public boolean isMatrix() {
        final int currentRank = rank();
        return currentRank == 2;
    }

    /**
     * Creates an array over this array's data buffer with a new shape, reusing the
     * current strides and offset 0.
     *
     * @param newShape shape for the new array
     * @param ordering ordering for the new array
     * @return result of {@code Nd4j.create}
     */
    protected INDArray newShape(long[] newShape, char ordering) {
        final DataBuffer buffer = data();
        return Nd4j.create(buffer, newShape, stride(), 0, ordering);
    }

    /**
     * Factory hook that forwards all arguments unchanged to {@code Nd4j.create}.
     *
     * @param data       backing buffer
     * @param newShape   shape
     * @param newStrides strides
     * @param offset     buffer offset
     * @param ordering   ordering
     * @return result of {@code Nd4j.create}
     */
    protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset, char ordering) {
        final INDArray created = Nd4j.create(data, newShape, newStrides, offset, ordering);
        return created;
    }

    /**
     * Factory hook that forwards all arguments unchanged to {@code Nd4j.create}
     * ({@code long[]} shape/stride variant).
     *
     * @param data       backing buffer
     * @param newShape   shape
     * @param newStrides strides
     * @param offset     buffer offset
     * @param ordering   ordering
     * @return result of {@code Nd4j.create}
     */
    protected INDArray create(DataBuffer data, long[] newShape, long[] newStrides, long offset, char ordering) {
        final INDArray created = Nd4j.create(data, newShape, newStrides, offset, ordering);
        return created;
    }

    /**
     * Factory hook that forwards all arguments unchanged to {@code Nd4j.create}
     * (variant without an explicit ordering).
     *
     * @param data       backing buffer
     * @param newShape   shape
     * @param newStrides strides
     * @param offset     buffer offset
     * @return result of {@code Nd4j.create}
     */
    protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset) {
        final INDArray created = Nd4j.create(data, newShape, newStrides, offset);
        return created;
    }

    /**
     * Creates an array of the given shape at offset 0, with strides computed by
     * {@code getStrides} for the ordering reported by {@code Nd4j.order()}.
     *
     * @param shape shape of the new array
     * @return result of {@code Nd4j.create}
     */
    protected INDArray create(int[] shape) {
        final int[] strides = getStrides(shape, Nd4j.order());
        return Nd4j.create(shape, strides, 0);
    }

    /**
     * Factory hook that forwards shape, strides and offset unchanged to {@code Nd4j.create}.
     *
     * @param shape   shape
     * @param strides strides
     * @param offset  buffer offset
     * @return result of {@code Nd4j.create}
     */
    protected INDArray create(int[] shape, int[] strides, long offset) {
        final INDArray created = Nd4j.create(shape, strides, offset);
        return created;
    }

    /**
     * Computes strides for a shape and ordering by delegating to {@code Nd4j.getStrides}.
     *
     * @param shape    shape to compute strides for
     * @param ordering ordering
     * @return result of {@code Nd4j.getStrides}
     */
    protected int[] getStrides(int[] shape, char ordering) {
        final int[] strides = Nd4j.getStrides(shape, ordering);
        return strides;
    }

    /**
     * Returns the square of {@code distance2(other)}.
     *
     * @param other array to measure against
     * @return {@code distance2(other)} multiplied by itself
     */
    @Override
    public double squaredDistance(INDArray other) {
        validateNumericalArray("squaredDistance", false);
        final double euclidean = distance2(other);
        return euclidean * euclidean;
    }

    /**
     * Executes an {@code EuclideanDistance} op between this array and {@code other}
     * and returns its final result as a {@code double}.
     *
     * @param other array to measure against
     * @return the op's final result
     */
    @Override
    public double distance2(INDArray other) {
        validateNumericalArray("distance2", false);
        Nd4j.getCompressor().autoDecompress(this);

        final EuclideanDistance op = new EuclideanDistance(this, other);
        return Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue();
    }

    /**
     * Executes a {@code ManhattanDistance} op between this array and {@code other}
     * and returns its final result as a {@code double}.
     *
     * @param other array to measure against
     * @return the op's final result
     */
    @Override
    public double distance1(INDArray other) {
        validateNumericalArray("distance1", false);
        Nd4j.getCompressor().autoDecompress(this);

        final ManhattanDistance op = new ManhattanDistance(this, other);
        return Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue();
    }

    /**
     * Gathers values or slices of this array selected by an index array.
     * <ul>
     *   <li>Rank-1 receiver: {@code indices} must have rank &lt;= 1; each entry is a
     *       linear index and the result holds the corresponding elements.</li>
     *   <li>{@code indices.rows() == rank()}: every column of {@code indices} is a full
     *       coordinate and the result holds one element per column.</li>
     *   <li>Otherwise slices are selected: for a matrix/column-vector index the first row
     *       picks slices of this array and each further row picks a sub-slice of the
     *       previous selection; for a row-vector index each entry picks one slice.
     *       The selected arrays are concatenated along dimension 0.</li>
     * </ul>
     *
     * @param indices index array of rank at most 2
     * @return the gathered array
     * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
     */
    @Override
    public INDArray get(INDArray indices) {
        if(indices.rank() > 2) {
            throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
        }

        if (rank() == 1) {
            Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well");
            val ret = Nd4j.createUninitialized(this.dataType(), indices.length());
            for (int e = 0; e < indices.length(); e++) {
                val idx = indices.getLong(e);
                val value =  getDouble(idx);
                ret.putScalar(e, value);
            }

            return ret;
        } else if(indices.rows() == rank()) {
            INDArray ret = Nd4j.create(this.dataType(), indices.columns());

            for(int i = 0; i < indices.columns(); i++) {
                int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
                val v = getDouble(specifiedIndex);
                ret.putScalar(i, v);
            }

            return ret;
        }
        else {
            // Typed list (was a raw List): required for arrList.get(j).slice(...) and the
            // toArray(...) call passed to Nd4j.concat to type-check.
            List<INDArray> arrList = new ArrayList<>();

            if(indices.isMatrix() || indices.isColumnVector()
                    || (indices.isScalar() && indices.rank() == 2)) { // we need this for compatibility with legacy code
                for(int i = 0; i < indices.rows(); i++) {
                    if(i == 0)  {
                        // first row selects the top-level slices
                        INDArray row = indices.getRow(i);
                        for(int j = 0; j < row.length(); j++) {
                            arrList.add(slice(row.getInt(j)));
                        }
                    }
                    else {
                        // later rows narrow each previously selected slice further
                        INDArray row = indices.slice(i);
                        for(int j = 0; j < row.length(); j++) {
                            INDArray put = arrList.get(j).slice(row.getInt(j));
                            put = put.reshape(Longs.concat(new long[]{1},put.shape()));
                            arrList.set(j,put);
                        }
                    }

                }
            }
            else if(indices.isRowVector()) {
                for(int i = 0; i < indices.length(); i++) {
                    INDArray add = slice(indices.getInt(i));
                    add = add.reshape(Longs.concat(new long[] {1,},add.shape()));
                    arrList.add(add);
                }
            }

            return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()]));

        }


    }

    /**
     * Writes {@code element} into the positions of this array selected by an index array.
     * <ul>
     *   <li>{@code indices.rows() == rank()}: every column of {@code indices} is a full
     *       coordinate; successive values of {@code element} (iterated with an
     *       {@code NdIndexIterator} over its shape) are stored there.</li>
     *   <li>Matrix or column-vector indices: every entry selects a slice of this array and
     *       {@code element} is assigned into that slice with an {@code Assign} op.</li>
     *   <li>Row-vector indices: every entry selects a slice and {@code element} is assigned
     *       into it in the same way. (Previously this branch only collected the slices in
     *       an unused list, so the call silently changed nothing.)</li>
     * </ul>
     *
     * @param indices index array of rank at most 2
     * @param element values to write
     * @return this array
     * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
     */
    @Override
    public INDArray put(INDArray indices, INDArray element) {
        if (indices.rank() > 2) {
            throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
        }

        if (indices.rows() == rank()) {
            NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
            for (int i = 0; i < indices.columns(); i++) {
                int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
                putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next()));
            }
        } else if (indices.isMatrix() || indices.isColumnVector()) {
            for (int i = 0; i < indices.rows(); i++) {
                INDArray row = indices.getRow(i);
                for (int j = 0; j < row.length(); j++) {
                    INDArray slice = slice(row.getInt(j));
                    Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
                }
            }
        } else if (indices.isRowVector()) {
            for (int i = 0; i < indices.length(); i++) {
                INDArray slice = slice(indices.getInt(i));
                Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
            }
        }
        return this;
    }

    @Override
    public INDArray put(INDArrayIndex[] indices, INDArray element) {
        Nd4j.getCompressor().autoDecompress(this);
        boolean isSpecifiedIndex = false;
        for(INDArrayIndex idx : indices){
            if(idx instanceof SpecifiedIndex){
                isSpecifiedIndex = true;
                break;
            }
        }

        if(!isSpecifiedIndex){
            return get(indices).assign(element);
        } else {
            //Can't get a view, so we'll do it in subsets instead
            // This is inefficient, but it is correct...
            int numSpecified = 0;
            List specifiedIdxs = new ArrayList<>();
            List specifiedIdxDims = new ArrayList<>();

            INDArrayIndex[] destinationIndices = indices.clone();  //Shallow clone
            INDArrayIndex[] sourceIndices = indices.clone();
            for( int i=0; i can't use point(1) on [1,x,y]
                    sourceIndices[i] = NDArrayIndex.point(0);
                }
            }
            int[] counts = new int[specifiedIdxs.size()];
            int[] dims = new int[specifiedIdxDims.size()];
            for( int i=0; i= this.length())
            throw new ND4JIllegalStateException("Index can't be greater then array length");

        if (i < 0)
            i += this.length();

        long idx = this.isScalar() ? 0 : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i));
        val buffer = Nd4j.createBuffer( this.data(), this.data().originalOffset() + idx, 1);
        val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', dataType(), false);
        return Nd4j.createArrayFromShapeBuffer(buffer, shape);
    }

    /**
     * Do a column wise op, in place on this array. Operation codes:
     * a : add
     * p : assign (copy the vector's values)
     * s : subtract
     * m : multiply
     * d : divide
     * h : reverse subtraction
     * t : reverse division
     * <p>
     * Handling order: a scalar {@code columnVector} is applied as a scalar op on this array;
     * if this array is a scalar the op is applied in place to {@code columnVector} and
     * {@code columnVector} (not this array) is returned; otherwise shapes are validated and
     * the op runs element-wise (equal shapes), as a scalar op along dimension 1
     * (rank 2, 'c' order, element-wise stride 1 on both arrays), or through
     * {@code applyBroadcastOp}.
     *
     * @param columnVector the column  vector
     * @param operation    the operation
     * @return this array, or {@code columnVector} when this array is a scalar
     * @throws IllegalStateException if the shapes do not match
     */
    protected INDArray doColumnWise(INDArray columnVector, char operation) {
        Nd4j.getCompressor().autoDecompress(this);
       if(columnVector.isScalar()) {
           // scalar operand: plain scalar op on this array
           switch (operation) {
               case 'a':
                   addi(columnVector.getDouble(0));
                   break;
               case 'p':
                   assign(columnVector.getDouble(0));
                   break;
               case 's':
                   subi(columnVector.getDouble(0));
                   break;
               case 'm':
                   muli(columnVector.getDouble(0));
                   break;
               case 'd':
                   divi(columnVector.getDouble(0));
                   break;
               case 'h':
                   rsubi(columnVector.getDouble(0));
                   break;
               case 't':
                   rdivi(columnVector.getDouble(0));
                   break;

           }

           return this;
       }

       else if(isScalar()) {
           // this array is the scalar: the op mutates and returns columnVector instead.
           // An unrecognised operation code falls through to the validation below.
           switch (operation) {
               case 'a':
                   return columnVector.addi(getDouble(0));
               case 'p':
                   return columnVector.assign(getDouble(0));
               case 's':
                   return columnVector.subi(getDouble(0));
               case 'm':
                   return columnVector.muli(getDouble(0));
               case 'd':
                   return columnVector.divi(getDouble(0));
               case 'h':
                   return columnVector.rsubi(getDouble(0));
               case 't':
                   return columnVector.rdivi(getDouble(0));

           }
       }

        //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0)
        //Or, simply require it to be a rank 1 vector
        if ((!columnVector.isColumnVector() && columnVector.rank() > 1) || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) {
            throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape())
                    + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")");
        }

        // operand shares this array's buffer: re-run against a copy of the operand
        if (columnVector.data().sameUnderlyingData(data()))
            return doColumnWise(columnVector.dup(), operation);
        if (equalShapes(columnVector)) {
            // identical shapes: ordinary element-wise in-place op
            switch (operation) {
                case 'a':
                    addi(columnVector);
                    break;
                case 'p':
                    assign(columnVector);
                    break;
                case 's':
                    subi(columnVector);
                    break;
                case 'm':
                    muli(columnVector);
                    break;
                case 'd':
                    divi(columnVector);
                    break;
                case 'h':
                    rsubi(columnVector);
                    break;
                case 't':
                    rdivi(columnVector);
                    break;
            }

            return this;
        }
        // NOTE(review): a scalar columnVector already returned at the top of this method,
        // so this condition looks unreachable - confirm before relying on it.
        if (rows() == 1 && columnVector.isScalar()) {
            applyScalarOp(columnVector, operation);
        } else {
            // special optimization case, broadcast turns into ScalarOp Along Dimension
            if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' && columnVector.elementWiseStride() == 1) {
                switch (operation) {
                    case 'a': {
                        ScalarAdd op = new ScalarAdd(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'p': {
                        ScalarSet op = new ScalarSet(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 's': {
                        ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'm': {
                        ScalarMultiplication op =
                                new ScalarMultiplication(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'd': {
                        ScalarDivision op = new ScalarDivision(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'h': {
                        ScalarReverseSubtraction op =
                                new ScalarReverseSubtraction(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 't': {
                        ScalarReverseDivision op =
                                new ScalarReverseDivision(this, columnVector, this, 0.0);
                        op.setDimension(1);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                }
            } else {
                // general case: broadcast op
                applyBroadcastOp(columnVector, operation);
            }

        }

        return this;

    }

    /**
     * Do a row wise op, in place on this array. Operation codes:
     * a : add
     * p : assign (copy the vector's values)
     * s : subtract
     * m : multiply
     * d : divide
     * h : reverse subtraction
     * t : reverse division
     * <p>
     * Handling order: a scalar {@code rowVector} is applied as a scalar op on this array;
     * if this array is a scalar the op is applied in place to {@code rowVector} and
     * {@code rowVector} (not this array) is returned; otherwise shapes are validated and
     * the op runs element-wise (this array is a vector), as a scalar op along dimension 0
     * (rank 2, 'f' order, element-wise stride 1 on both arrays), or through
     * {@code applyBroadcastOp}.
     *
     * @param rowVector the row vector
     * @param operation the operation
     * @return this array, or {@code rowVector} when this array is a scalar
     * @throws IllegalStateException if the shapes do not match
     */
    protected INDArray doRowWise(INDArray rowVector, final char operation) {
        Nd4j.getCompressor().autoDecompress(this);


        if(rowVector.isScalar()) {
            // scalar operand: plain scalar op on this array
            switch (operation) {
                case 'a':
                    addi(rowVector.getDouble(0));
                    break;
                case 'p':
                    assign(rowVector.getDouble(0));
                    break;
                case 's':
                    subi(rowVector.getDouble(0));
                    break;
                case 'm':
                    muli(rowVector.getDouble(0));
                    break;
                case 'd':
                    divi(rowVector.getDouble(0));
                    break;
                case 'h':
                    rsubi(rowVector.getDouble(0));
                    break;
                case 't':
                    rdivi(rowVector.getDouble(0));
                    break;

            }

            return this;
        }
        else if(isScalar()) {
            // this array is the scalar: the op mutates and returns rowVector instead.
            // An unrecognised operation code falls through to the validation below.
            switch (operation) {
                case 'a':
                    return rowVector.addi(getDouble(0));
                case 'p':
                    return rowVector.assign(getDouble(0));
                case 's':
                    return rowVector.subi(getDouble(0));
                case 'm':
                    return rowVector.muli(getDouble(0));
                case 'd':
                    return rowVector.divi(getDouble(0));
                case 'h':
                    return rowVector.rsubi(getDouble(0));
                case 't':
                    return rowVector.rdivi(getDouble(0));

            }
        }

        //Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1)
        if (!rowVector.isRowVector() || this.rank() > 1 && rowVector.rank() > 1 &&  this.size(1) != rowVector.size(1) || rowVector.length() <= 1) {
            throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape())
                    + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")");
        }

        // operand shares this array's buffer: re-run against a copy of the operand
        if (rowVector.data().sameUnderlyingData(data()))
            return doRowWise(rowVector.dup(), operation);

        if (isVector()) {
            // this array is itself a vector: ordinary element-wise in-place op
            switch (operation) {
                case 'a':
                    addi(rowVector);
                    break;
                case 'p':
                    assign(rowVector);
                    break;
                case 's':
                    subi(rowVector);
                    break;
                case 'm':
                    muli(rowVector);
                    break;
                case 'd':
                    divi(rowVector);
                    break;
                case 'h':
                    rsubi(rowVector);
                    break;
                case 't':
                    rdivi(rowVector);
                    break;
            }

            return this;
        }

        // NOTE(review): a scalar rowVector already returned at the top of this method,
        // so this condition looks unreachable - confirm before relying on it.
        if (rank() == 2 && columns() == 1 && rowVector.isScalar()) {
            applyScalarOp(rowVector, operation);
        } else {
            // special optimization case, broadcast turns into ScalarOp Along Dimension
            if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'f' && rowVector.elementWiseStride() == 1) {
                switch (operation) {
                    case 'a': {
                        ScalarAdd op = new ScalarAdd(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'p': {
                        ScalarSet op = new ScalarSet(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 's': {
                        ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'm': {
                        ScalarMultiplication op = new ScalarMultiplication(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'd': {
                        ScalarDivision op = new ScalarDivision(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 'h': {
                        ScalarReverseSubtraction op =
                                new ScalarReverseSubtraction(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }
                    case 't': {
                        ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, 0.0);
                        op.setDimension(0);
                        Nd4j.getExecutioner().exec(op);
                        break;
                    }

                }
            } else {
                // general case: broadcast op
                applyBroadcastOp(rowVector, operation);
            }
        }

        return this;
    }


    /**
     * Executes the broadcast op selected by {@code operation} with this array as both
     * input and output and {@code vector} as the second operand.
     * The broadcast dimension is 1 when {@code vector} has a row-vector shape, otherwise 0.
     * Codes: a add, s subtract, m multiply, d divide, h reverse subtract,
     * t reverse divide, p copy.
     *
     * @param vector    vector operand; copied first if it shares this array's underlying data
     * @param operation operation code
     * @throws UnsupportedOperationException for an unknown operation code
     */
    private void applyBroadcastOp(INDArray vector, final char operation) {
        Nd4j.getCompressor().autoDecompress(this);
        int alongDimension = Shape.isRowVectorShape(vector.shape()) ? 1 : 0;

        // Reference identity on the buffers never detected aliasing (see the former FIXME);
        // use the same underlying-data check as doRowWise/doColumnWise so the op never
        // reads from memory it is writing to.
        if (vector.data().sameUnderlyingData(this.data()))
            vector = vector.dup();
        switch (operation) {
            case 'a':
                Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension));
                return;
            case 's':
                Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension));
                return;
            case 'm':
                Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension));
                return;
            case 'd':
                Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension));
                return;
            case 'h':
                Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension));
                return;
            case 't':
                Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension));
                return;
            case 'p':
                Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension));
                return;
            default:
                throw new UnsupportedOperationException("Unknown operation: " + operation);
        }
    }

    /**
     * Applies the in-place scalar op selected by {@code operation} to this array, using
     * the first element of {@code vector} as the scalar.
     * Codes: a add, p assign, s subtract, m multiply, d divide, h reverse subtract,
     * t reverse divide. The 'p' case and the rejection of unknown codes mirror
     * {@code applyBroadcastOp}; previously both were silently ignored here.
     *
     * @param vector    array whose first element supplies the scalar
     * @param operation operation code
     * @throws UnsupportedOperationException for an unknown operation code
     */
    private void applyScalarOp(INDArray vector, char operation) {
        Nd4j.getCompressor().autoDecompress(this);
        final double scalar = vector.getDouble(0);
        switch (operation) {
            case 'a':
                addi(scalar);
                break;
            case 'p':
                assign(scalar);
                break;
            case 's':
                subi(scalar);
                break;
            case 'm':
                muli(scalar);
                break;
            case 'd':
                divi(scalar);
                break;
            case 'h':
                rsubi(scalar);
                break;
            case 't':
                rdivi(scalar);
                break;
            default:
                throw new UnsupportedOperationException("Unknown operation: " + operation);
        }
    }

    /**
     * Extracts the shape portion of this array's shape-information buffer.
     *
     * @return the result of {@code Shape.shapeOf} on the shape-info buffer (computed on every call)
     */
    protected DataBuffer shapeOf() {
        final DataBuffer shapeOnly = Shape.shapeOf(shapeInfoDataBuffer());
        return shapeOnly;
    }

    /**
     * Extracts the stride portion of this array's shape-information buffer.
     *
     * @return the result of {@code Shape.stride} on the shape-info buffer (computed on every call)
     */
    protected DataBuffer strideOf() {
        final DataBuffer strideOnly = Shape.stride(shapeInfoDataBuffer());
        return strideOnly;
    }

    /**
     * Returns the stride of a single dimension.
     *
     * @param dimension dimension index; negative values count from the end
     *                  ({@code -1} is the last dimension)
     * @return the stride of that dimension, narrowed to {@code int}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [-rank, rank)}
     */
    @Override
    public int stride(int dimension) {
        int rank = jvmShapeInfo.rank;
        // Check both bounds: a value below -rank would otherwise index the stride array negatively.
        Preconditions.checkArgument(dimension >= -rank && dimension < rank, "Cannot get stride for dimension %s from rank %s array: " +
                "dimension indices must be in range -rank <= dimension < rank", dimension, rank);
        int axis = dimension < 0 ? dimension + rank : dimension;
        return (int) stride()[axis];
    }

    /**
     * In-place column-wise reverse division: validates this array, then delegates to
     * {@code doColumnWise} with op code 't'.
     *
     * @param columnVector the column vector operand
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray rdiviColumnVector(INDArray columnVector) {
        validateNumericalArray("rdiviColumnVector", false);
        final char reverseDivide = 't';
        return doColumnWise(columnVector, reverseDivide);
    }

    /**
     * Copying variant of {@link #rdiviColumnVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param columnVector the column vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray rdivColumnVector(INDArray columnVector) {
        validateNumericalArray("rdivColumnVector", false);
        INDArray copy = dup();
        return copy.rdiviColumnVector(columnVector);
    }

    /**
     * In-place row-wise reverse division: validates this array, then delegates to
     * {@code doRowWise} with op code 't'.
     *
     * @param rowVector the row vector operand
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray rdiviRowVector(INDArray rowVector) {
        validateNumericalArray("rdiviRowVector", false);
        final char reverseDivide = 't';
        return doRowWise(rowVector, reverseDivide);
    }

    /**
     * Copying variant of {@link #rdiviRowVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param rowVector the row vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray rdivRowVector(INDArray rowVector) {
        validateNumericalArray("rdivRowVector", false);
        INDArray copy = dup();
        return copy.rdiviRowVector(rowVector);
    }

    /**
     * In-place column-wise reverse subtraction: validates this array, then delegates to
     * {@code doColumnWise} with op code 'h'.
     *
     * @param columnVector the column vector operand
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray rsubiColumnVector(INDArray columnVector) {
        validateNumericalArray("rsubiColumnVector", false);
        final char reverseSubtract = 'h';
        return doColumnWise(columnVector, reverseSubtract);
    }

    /**
     * Copying variant of {@link #rsubiColumnVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param columnVector the column vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray rsubColumnVector(INDArray columnVector) {
        validateNumericalArray("rsubColumnVector", false);
        INDArray copy = dup();
        return copy.rsubiColumnVector(columnVector);
    }

    /**
     * In-place row-wise reverse subtraction: validates this array, then delegates to
     * {@code doRowWise} with op code 'h'.
     *
     * @param rowVector the row vector operand
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray rsubiRowVector(INDArray rowVector) {
        validateNumericalArray("rsubiRowVector", false);
        final char reverseSubtract = 'h';
        return doRowWise(rowVector, reverseSubtract);
    }

    /**
     * Copying variant of {@link #rsubiRowVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param rowVector the row vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray rsubRowVector(INDArray rowVector) {
        validateNumericalArray("rsubRowVector", false);
        INDArray copy = dup();
        return copy.rsubiRowVector(rowVector);
    }

    /**
     * Stores a scalar array's value at linear index {@code i}.
     *
     * @param i       index passed through to {@code putScalar}
     * @param element must be a scalar; its element 0 is read as a double
     * @return the result of {@code putScalar}
     * @throws IllegalArgumentException if {@code element} is not a scalar
     */
    @Override
    public INDArray put(int i, INDArray element) {
        Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element);
        final double value = element.getDouble(0);
        return putScalar(i, value);
    }

    /**
     * In-place column-wise division: validates this array, then delegates to
     * {@code doColumnWise} with op code 'd'.
     *
     * @param columnVector the column vector operand
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray diviColumnVector(INDArray columnVector) {
        validateNumericalArray("diviColumnVector", false);
        final char divide = 'd';
        return doColumnWise(columnVector, divide);
    }

    /**
     * Copying variant of {@link #diviColumnVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param columnVector the column vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray divColumnVector(INDArray columnVector) {
        validateNumericalArray("divColumnVector", false);
        INDArray copy = dup();
        return copy.diviColumnVector(columnVector);
    }

    /**
     * In-place row-wise division: validates this array, then delegates to
     * {@code doRowWise} with op code 'd'.
     *
     * @param rowVector the row vector operand
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray diviRowVector(INDArray rowVector) {
        validateNumericalArray("diviRowVector", false);
        final char divide = 'd';
        return doRowWise(rowVector, divide);
    }

    /**
     * Copying variant of {@link #diviRowVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param rowVector the row vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray divRowVector(INDArray rowVector) {
        validateNumericalArray("divRowVector", false);
        INDArray copy = dup();
        return copy.diviRowVector(rowVector);
    }

    /**
     * In-place column-wise multiplication: validates this array, then delegates to
     * {@code doColumnWise} with op code 'm'.
     *
     * @param columnVector the column vector operand
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray muliColumnVector(INDArray columnVector) {
        validateNumericalArray("muliColumnVector", false);
        final char multiply = 'm';
        return doColumnWise(columnVector, multiply);
    }

    /**
     * Copying variant of {@link #muliColumnVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param columnVector the column vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray mulColumnVector(INDArray columnVector) {
        validateNumericalArray("mulColumnVector", false);
        INDArray copy = dup();
        return copy.muliColumnVector(columnVector);
    }

    /**
     * In-place row-wise multiplication: validates this array, then delegates to
     * {@code doRowWise} with op code 'm'.
     *
     * @param rowVector the row vector operand
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray muliRowVector(INDArray rowVector) {
        validateNumericalArray("muliRowVector", false);
        final char multiply = 'm';
        return doRowWise(rowVector, multiply);
    }

    /**
     * Copying variant of {@link #muliRowVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param rowVector the row vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray mulRowVector(INDArray rowVector) {
        validateNumericalArray("mulRowVector", false);
        INDArray copy = dup();
        return copy.muliRowVector(rowVector);
    }

    /**
     * In-place column-wise subtraction: validates this array, then delegates to
     * {@code doColumnWise} with op code 's'.
     *
     * @param columnVector the column vector operand
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray subiColumnVector(INDArray columnVector) {
        validateNumericalArray("subiColumnVector", false);
        final char subtract = 's';
        return doColumnWise(columnVector, subtract);
    }

    /**
     * Copying variant of {@link #subiColumnVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param columnVector the column vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray subColumnVector(INDArray columnVector) {
        validateNumericalArray("subColumnVector", false);
        INDArray copy = dup();
        return copy.subiColumnVector(columnVector);
    }

    /**
     * In-place row-wise subtraction: validates this array, then delegates to
     * {@code doRowWise} with op code 's'.
     *
     * @param rowVector the row vector operand
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray subiRowVector(INDArray rowVector) {
        validateNumericalArray("subiRowVector", false);
        final char subtract = 's';
        return doRowWise(rowVector, subtract);
    }

    /**
     * Copying variant of {@link #subiRowVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param rowVector the row vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray subRowVector(INDArray rowVector) {
        validateNumericalArray("subRowVector", false);
        INDArray copy = dup();
        return copy.subiRowVector(rowVector);
    }

    /**
     * In-place column-wise addition: validates this array, then delegates to
     * {@code doColumnWise} with op code 'a'.
     *
     * @param columnVector the column vector operand
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray addiColumnVector(INDArray columnVector) {
        validateNumericalArray("addiColumnVector", false);
        final char add = 'a';
        return doColumnWise(columnVector, add);
    }

    /**
     * In-place column-wise copy: delegates to {@code doColumnWise} with op code 'p'.
     * Unlike the arithmetic column-vector methods, no numerical-array validation is done here.
     *
     * @param columnVector the column vector to copy from
     * @return the result of {@code doColumnWise}
     */
    @Override
    public INDArray putiColumnVector(INDArray columnVector) {
        final char copy = 'p';
        return doColumnWise(columnVector, copy);
    }

    /**
     * Copying variant of {@link #addiColumnVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param columnVector the column vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray addColumnVector(INDArray columnVector) {
        validateNumericalArray("addColumnVector", false);
        INDArray copy = dup();
        return copy.addiColumnVector(columnVector);
    }

    /**
     * In-place row-wise addition: validates this array, then delegates to
     * {@code doRowWise} with op code 'a'.
     *
     * @param rowVector the row vector operand
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray addiRowVector(INDArray rowVector) {
        validateNumericalArray("addiRowVector", false);
        final char add = 'a';
        return doRowWise(rowVector, add);
    }

    /**
     * In-place row-wise copy: validates this array, then delegates to {@code doRowWise}
     * with op code 'p'.
     *
     * @param rowVector the row vector to copy from
     * @return the result of {@code doRowWise}
     */
    @Override
    public INDArray putiRowVector(INDArray rowVector) {
        validateNumericalArray("putiRowVector", false);
        final char copy = 'p';
        return doRowWise(rowVector, copy);
    }

    /**
     * Copying variant of {@link #addiRowVector(INDArray)}: the op runs on a duplicate of this array.
     *
     * @param rowVector the row vector operand
     * @return the array returned by the in-place op on the duplicate
     */
    @Override
    public INDArray addRowVector(INDArray rowVector) {
        validateNumericalArray("addRowVector", false);
        INDArray copy = dup();
        return copy.addiRowVector(rowVector);
    }

    /**
     * Matrix multiplication driven by a transpose descriptor, writing into {@code result}.
     *
     * @param other         right-hand operand
     * @param result        array handed to {@code mMulTranspose.exec} as the result
     * @param mMulTranspose executes the multiplication
     * @return whatever {@code mMulTranspose.exec} returns
     */
    @Override
    public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
        INDArray product = mMulTranspose.exec(this, other, result);
        return product;
    }

    /**
     * Matrix multiplication driven by a transpose descriptor, with no result array supplied
     * ({@code null} is passed as the result).
     *
     * @param other         right-hand operand
     * @param mMulTranspose executes the multiplication
     * @return whatever {@code mMulTranspose.exec} returns
     */
    @Override
    public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
        INDArray product = mMulTranspose.exec(this, other, null);
        return product;
    }

    /**
     * Matrix multiplication into a newly allocated result of the requested ordering.
     *
     * @param other       right-hand operand; must have the same data type as this array
     * @param resultOrder 'c' or 'f'
     * @return the product; when the allocated result is a scalar, a {@code [1, 1]} array holding
     *         the BLAS dot product is returned instead
     * @throws IllegalArgumentException if {@code resultOrder} is neither 'c' nor 'f'
     * @throws IllegalStateException    if the data types differ
     */
    @Override
    public INDArray mmul(INDArray other, char resultOrder) {
        Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given");
        Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());

        // FIXME: add support for 3D+ here?
        final long[] shape;
        if (other.rank() == 1) {
            shape = new long[]{rows()};
        } else {
            shape = new long[]{rows(), other.columns()};
        }

        INDArray result = createUninitialized(this.dataType(), shape, resultOrder);
        if (!result.isScalar()) {
            return mmuli(other, result);
        }
        return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1);
    }

    /**
     * Matrix multiplication with an automatically chosen result ordering: 'f' only when both
     * operands are 'f'-ordered and {@code other} is not rank 1, otherwise 'c'.
     *
     * @param other right-hand operand
     * @return the product
     */
    @Override
    public INDArray mmul(INDArray other) {
        final boolean fortranResult = this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1;
        return mmul(other, fortranResult ? 'f' : 'c');
    }

    /**
     * Creates an array through {@code Nd4j.create} with the given shape and ordering.
     *
     * @param shape    array shape
     * @param ordering ordering character passed to {@code Nd4j.create}
     * @return the created array
     */
    protected INDArray create(int[] shape, char ordering) {
        INDArray created = Nd4j.create(shape, ordering);
        return created;
    }

    /**
     * Copies this matrix into a {@code double[][]}, one inner array per row.
     *
     * @return a new 2d array; each row comes from a duplicate of {@code getRow(i)}
     * @throws ND4JIllegalStateException if this array is not a matrix
     * @throws ND4JArraySizeException    if either dimension exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public double[][] toDoubleMatrix() {
        if(!isMatrix()) {
            throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this));
        }

        if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();

        // Allocate only the outer array: every row is replaced below, so pre-allocating
        // rows x columns doubles would be discarded immediately.
        double[][] ret = new double[rows()][];
        for(int i = 0; i < ret.length; i++) {
            ret[i] = getRow(i).dup().data().asDouble();
        }

        return ret;
    }

    /**
     * Copies this vector or scalar into a {@code double[]}, read from a duplicate of this array.
     *
     * @return the values as doubles
     * @throws ND4JIllegalStateException if this array is neither a vector nor a scalar
     */
    @Override
    public double[] toDoubleVector() {
        if (isVectorOrScalar()) {
            return dup().data().asDouble();
        }
        throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
    }

    /**
     * Copies this vector or scalar into a {@code float[]}, read from a duplicate of this array.
     *
     * @return the values as floats
     * @throws ND4JIllegalStateException if this array is neither a vector nor a scalar
     */
    @Override
    public float[] toFloatVector() {
        if (isVectorOrScalar()) {
            return dup().data().asFloat();
        }
        throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
    }

    /**
     * Copies this matrix into a {@code float[][]}, one inner array per row.
     *
     * @return a new 2d array; each row comes from a duplicate of {@code getRow(i)}
     * @throws ND4JIllegalStateException if this array is not a matrix
     * @throws ND4JArraySizeException    if either dimension exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public float[][] toFloatMatrix() {
        if(!isMatrix()) {
            throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this));
        }

        // Compare the long-valued sizes (as toDoubleMatrix does): a check on the
        // int-valued rows()/columns() can never exceed Integer.MAX_VALUE.
        if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();

        // Only the outer array is allocated; every row is replaced below.
        float[][] ret = new float[(int) rows()][];
        for(int i = 0; i < ret.length; i++) {
            ret[i] = getRow(i).dup().data().asFloat();
        }

        return ret;
    }

    /**
     * Converts this vector or scalar into an {@code int[]}.
     *
     * @return an empty array for an empty INDArray; otherwise the values as ints, read from a
     *         duplicate when this array is a view or has an element-wise stride other than 1
     * @throws ND4JIllegalStateException if this array is neither a vector nor a scalar
     */
    @Override
    public int[] toIntVector() {
        if (isEmpty()) {
            return new int[0];
        }

        if (!isVectorOrScalar()) {
            throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
        }

        final boolean directlyReadable = !isView() && elementWiseStride() == 1;
        return directlyReadable ? data().asInt() : dup().data().asInt();
    }

    /**
     * Converts this vector or scalar into a {@code long[]}.
     *
     * @return an empty array for an empty INDArray (same as {@code toIntVector()}); otherwise the
     *         values as longs, read from a duplicate when this array is a view or has an
     *         element-wise stride other than 1
     * @throws ND4JIllegalStateException if this array is neither a vector nor a scalar
     */
    @Override
    public long[] toLongVector() {
        // Same empty-array guard as toIntVector(): there is nothing to read from an empty array.
        if (isEmpty())
            return new long[0];

        if(!isVectorOrScalar()) {
            throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
        }
        if(isView() || elementWiseStride() != 1){
            return dup().data().asLong();
        }
        return data().asLong();
    }

    /**
     * Copies this matrix into a {@code long[][]}, one inner array per row.
     *
     * @return a new 2d array; each row comes from a duplicate of {@code getRow(i)}
     * @throws ND4JIllegalStateException if this array is not a matrix
     * @throws ND4JArraySizeException    if either dimension exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public long[][] toLongMatrix() {
        if(!isMatrix()) {
            throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this));
        }

        // Compare the long-valued sizes (as toDoubleMatrix does): a check on the
        // int-valued rows()/columns() can never exceed Integer.MAX_VALUE.
        if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();

        // Only the outer array is allocated; every row is replaced below.
        long[][] ret = new long[(int) rows()][];
        for(int i = 0; i < ret.length; i++) {
            ret[i] = getRow(i).dup().data().asLong();
        }

        return ret;
    }

    /**
     * Copies this matrix into an {@code int[][]}, one inner array per row.
     *
     * @return a new 2d array; each row comes from a duplicate of {@code getRow(i)}
     * @throws ND4JIllegalStateException if this array is not a matrix
     * @throws ND4JArraySizeException    if either dimension exceeds {@code Integer.MAX_VALUE}
     */
    @Override
    public int[][] toIntMatrix() {
        if(!isMatrix()) {
            throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this));
        }

        // Compare the long-valued sizes (as toDoubleMatrix does): a check on the
        // int-valued rows()/columns() can never exceed Integer.MAX_VALUE.
        if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();

        // Only the outer array is allocated; every row is replaced below.
        int[][] ret = new int[(int) rows()][];
        for(int i = 0; i < ret.length; i++) {
            ret[i] = getRow(i).dup().data().asInt();
        }

        return ret;
    }

    /**
     * Matrix multiplication that writes into a caller-supplied result array.
     * Delegates directly to {@link #mmuli(INDArray, INDArray)}.
     *
     * @param other  the right-hand matrix
     * @param result the array that receives the product
     * @return the result of the matrix multiplication
     */
    @Override
    public INDArray mmul(INDArray other, INDArray result) {
        INDArray product = mmuli(other, result);
        return product;
    }

    /**
     * Element-wise division into a newly allocated result.
     * <p>
     * The result's data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
     * broadcast output shape when the two shapes are broadcastable, otherwise this array's shape.
     *
     * @param other the divisor
     * @return the newly allocated array holding the quotient
     */
    @Override
    public INDArray div(INDArray other) {
        // Validate up front, as mul/sub/add do, so a bad array fails before the result is allocated.
        validateNumericalArray("div", false);
        if (Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering()));
        } else {
            return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering()));
        }
    }

    /**
     * Element-wise division written into {@code result}; validates this array and then
     * delegates to {@link #divi(INDArray, INDArray)}.
     *
     * @param other  the divisor
     * @param result the array that receives the quotient
     * @return {@code result}, as returned by {@code divi}
     */
    @Override
    public INDArray div(INDArray other, INDArray result) {
        validateNumericalArray("div", true);
        INDArray quotient = divi(other, result);
        return quotient;
    }

    /**
     * Element-wise multiplication into a newly allocated result.
     * <p>
     * The result's data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
     * broadcast output shape when the two shapes are broadcastable, otherwise this array's shape.
     *
     * @param other the other factor
     * @return the newly allocated array holding the product
     */
    @Override
    public INDArray mul(INDArray other) {
        validateNumericalArray("mul", false);
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            INDArray sameShapeOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering());
            return muli(other, sameShapeOut);
        }
        INDArray broadcastOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering());
        return muli(other, broadcastOut);
    }

    /**
     * Element-wise multiplication written into {@code result}; delegates to
     * {@link #muli(INDArray, INDArray)}.
     *
     * @param other  the other factor
     * @param result the array that receives the product
     * @return {@code result}, as returned by {@code muli}
     */
    @Override
    public INDArray mul(INDArray other, INDArray result) {
        INDArray product = muli(other, result);
        return product;
    }

    /**
     * Element-wise subtraction into a newly allocated result.
     * <p>
     * The result's data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
     * broadcast output shape when the two shapes are broadcastable, otherwise this array's shape.
     *
     * @param other the array to subtract
     * @return the newly allocated array holding the difference
     */
    @Override
    public INDArray sub(INDArray other) {
        validateNumericalArray("sub", false);
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            INDArray sameShapeOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering());
            return subi(other, sameShapeOut);
        }
        INDArray broadcastOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering());
        return subi(other, broadcastOut);
    }

    /**
     * Element-wise subtraction written into {@code result}; delegates to
     * {@link #subi(INDArray, INDArray)}.
     *
     * @param other  the array to subtract
     * @param result the array that receives the difference
     * @return {@code result}, as returned by {@code subi}
     */
    @Override
    public INDArray sub(INDArray other, INDArray result) {
        INDArray difference = subi(other, result);
        return difference;
    }

    /**
     * Element-wise addition into a newly allocated result.
     * <p>
     * The result's data type is chosen by {@code Shape.pickPairwiseDataType}. Its shape is the
     * broadcast output shape when the two shapes are broadcastable, otherwise this array's shape.
     *
     * @param other the array to add
     * @return the newly allocated array holding the sum
     */
    @Override
    public INDArray add(INDArray other) {
        validateNumericalArray("add", false);
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            INDArray sameShapeOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering());
            return addi(other, sameShapeOut);
        }
        INDArray broadcastOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering());
        return addi(other, broadcastOut);
    }

    /**
     * Element-wise addition written into {@code result}; validates this array and then
     * delegates to {@link #addi(INDArray, INDArray)}.
     *
     * @param other  the array to add
     * @param result the array that receives the sum
     * @return {@code result}, as returned by {@code addi}
     */
    @Override
    public INDArray add(INDArray other, INDArray result) {
        validateNumericalArray("add", false);
        INDArray sum = addi(other, result);
        return sum;
    }

    /**
     * "In-place" transposed matrix multiplication: a duplicate of this array is used as the
     * left operand and this array is passed as the result target.
     *
     * @param other     right-hand operand
     * @param transpose transpose descriptor forwarded to the three-argument overload
     * @return the array returned by the three-argument {@code mmuli}
     */
    @Override
    public INDArray mmuli(INDArray other, MMulTranspose transpose) {
        validateNumericalArray("mmuli", false);
        INDArray leftCopy = dup();
        return leftCopy.mmuli(other, this,transpose);
    }

    /**
     * "In-place" matrix multiplication: a duplicate of this array is used as the left operand
     * and this array is passed as the result target.
     *
     * @param other right-hand operand
     * @return the array returned by {@link #mmuli(INDArray, INDArray)}
     */
    @Override
    public INDArray mmuli(INDArray other) {
        validateNumericalArray("mmuli", false);
        INDArray leftCopy = dup();
        return leftCopy.mmuli(other, this);
    }

    /**
     * Matrix multiplication driven by a transpose descriptor, writing into {@code result}.
     *
     * @param other     right-hand operand
     * @param result    array handed to {@code transpose.exec} as the result
     * @param transpose executes the multiplication
     * @return whatever {@code transpose.exec} returns
     */
    @Override
    public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
        INDArray product = transpose.exec(this, other, result);
        return product;
    }

    /**
     * Matrix multiplication of this array by {@code other}, written into {@code result}.
     * <p>
     * After validating operand and result shapes, scalar operands are handled by scalar
     * multiplication; otherwise the product is computed through the BLAS wrapper (gemv when
     * {@code other} has one column or is rank 1, gemm otherwise). A temporary 'f'-ordered array is
     * used when {@code result} aliases an operand, or when {@code result} is not a plain 'f'-ordered
     * array with default strides; the temporary is then assigned into {@code result}.
     *
     * @param other  right-hand operand
     * @param result array that receives the product
     * @return {@code result}, or a reshaped 1d view of it when {@code other} is rank 1; for scalar
     *         operands, the value returned by the scalar {@code muli}
     * @throws IllegalStateException if the result shape does not match the operands
     */
    @Override
    public INDArray mmuli(INDArray other, INDArray result) {
        validateNumericalArray("mmuli", false);
        LinAlgExceptions.assertMultiplies(this, other);
        if(other.rank() == 1){
            //GEMV edge case
            Preconditions.checkState(result.length() == this.size(0) && this.size(1) == other.size(0),
                    "Invalid matrix multiplication: %ndShape x %ndShape with result shape %ndShape", this, other, result);
        } else {
            //Standard case
            Preconditions.checkState(
                    result.rank() == 2 && result.size(0) == this.size(0) && result.size(1) == other.size(1),
                    "Invalid result array shape: expected shape [%s,%s], got shape %ndShape result array for %ndShape x %ndShape", this.size(0), other.size(1), result,
                    this, other);
        }

        // Scalar operands: fall back to scalar multiplication instead of BLAS.
        if (other.isScalar()) {
            return muli(other.getDouble(0), result);
        }
        if (isScalar()) {
            return other.muli(getDouble(0), result);
        }

        /* check sizes and resize if necessary */


        // Reference comparison: the result array is literally one of the operands.
        if (result == this || result == other) {
            /* actually, blas cannot do multiplications in-place. Therefore, we will fake by
             * allocating a temporary object on the side and copy the result later.
             */
            INDArray temp = Nd4j.create(result.dataType(), result.shape(), Nd4j.getStrides(result.shape(), 'f'), 'f');

            // NOTE(review): the leading gemv/gemm arguments here differ from the non-aliased
            // branch below (which passes ordering() and the transpose char of 'other') - confirm
            // this difference is intended.
            if (other.columns() == 1 || other.rank() == 1) {
                Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result),
                        BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp);
            }

            else {
                Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result),
                        BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0,
                        this, other, 0.0, temp);
            }

            // Copy the product from the temporary into the aliased result.
            result.assign(temp);


        } else {

            //We require that the result array is 'f' (fortran) order
            // However, user might have called mmuli with a c order array for the result
            // In which case, we need to allocate a temporary f order array, and later do an assign to the real result array

            boolean requiresTemp = result.ordering() != 'f' || result.isView() || !Shape.hasDefaultStridesForShape(result);
            INDArray gemmResultArr;
            if (requiresTemp) {
                //Can use createUninitialized due to beta==0.0 parameter in gemm
                gemmResultArr = Nd4j.createUninitialized(result.dataType(), result.shape(), 'f');
            } else {
                gemmResultArr = result;
            }

            if (other.columns() == 1 || other.rank() == 1) {
                Nd4j.getBlasWrapper().level2().gemv(
                        ordering(),
                        BlasBufferUtil.getCharForTranspose(other),
                        1.0,
                        this,
                        other,
                        0.0,
                        gemmResultArr);
            } else {
                //gemm doesn't support strides so vectors and views
                //don't work
                Nd4j.getBlasWrapper().level3().gemm(ordering(),
                        BlasBufferUtil.getCharForTranspose(other),
                        BlasBufferUtil.getCharForTranspose(gemmResultArr),
                        1.0,
                        this,
                        other,
                        0.0,
                        gemmResultArr);
            }

            // Only needed when BLAS wrote into the temporary rather than into result itself.
            if (requiresTemp) {
                result.assign(gemmResultArr);
            }
        }

        // 1D edge case: reshape back to vector
        if (other.rank() == 1)
            result = result.reshape(result.length());
        return result;
    }

    /**
     * Creates an array through {@code Nd4j.create} with the given shape and stride.
     *
     * @param shape  array shape
     * @param stride array stride
     * @return the created array
     */
    private INDArray create(int[] shape, int[] stride) {
        INDArray created = Nd4j.create(shape, stride);
        return created;
    }

    /**
     * In-place element-wise division: this array is used as the result target.
     *
     * @param other the divisor
     * @return the result of {@link #divi(INDArray, INDArray)} with this array as the result
     */
    @Override
    public INDArray divi(INDArray other) {
        INDArray quotient = divi(other, this);
        return quotient;
    }

    /**
     * Element-wise division of this array by {@code other}, written into {@code result}.
     * Validates this array and the broadcast compatibility of the three arrays, then
     * executes a {@code DivOp}.
     *
     * @param other  the divisor
     * @param result the array that receives the quotient
     * @return {@code result}
     */
    @Override
    public INDArray divi(INDArray other, INDArray result) {
        validateNumericalArray("divi", false);
        Shape.assertBroadcastable("divi", this, other, result);
        DivOp op = new DivOp(this, other, result);
        Nd4j.exec(op);
        return result;
    }

    /**
     * In-place element-wise multiplication: this array is used as the result target.
     *
     * @param other the other factor
     * @return the result of {@link #muli(INDArray, INDArray)} with this array as the result
     */
    @Override
    public INDArray muli(INDArray other) {
        INDArray product = muli(other, this);
        return product;
    }

    /**
     * Element-wise multiplication of this array by {@code other}, written into {@code result}.
     * Validates this array and the broadcast compatibility of the three arrays, then
     * executes a {@code MulOp}.
     *
     * @param other  the other factor
     * @param result the array that receives the product
     * @return {@code result}
     */
    @Override
    public INDArray muli(INDArray other, INDArray result) {
        validateNumericalArray("muli", false);
        Shape.assertBroadcastable("muli", this, other, result);
        MulOp op = new MulOp(this, other, result);
        Nd4j.exec(op);
        return result;
    }

    /**
     * In-place element-wise subtraction: this array is used as the result target.
     *
     * @param other the array to subtract
     * @return the result of {@link #subi(INDArray, INDArray)} with this array as the result
     */
    @Override
    public INDArray subi(INDArray other) {
        INDArray difference = subi(other, this);
        return difference;
    }

    /**
     * Element-wise subtraction of {@code other} from this array, written into {@code result}.
     * Validates this array and the broadcast compatibility of the three arrays, then
     * executes a {@code SubOp}.
     *
     * @param other  the array to subtract
     * @param result the array that receives the difference
     * @return {@code result}
     */
    @Override
    public INDArray subi(INDArray other, INDArray result) {
        validateNumericalArray("subi", false);
        Shape.assertBroadcastable("subi", this, other, result);
        SubOp op = new SubOp(this, other, result);
        Nd4j.exec(op);
        return result;
    }

    /**
     * In-place element-wise addition: this array is used as the result target.
     *
     * @param other the array to add
     * @return the result of {@link #addi(INDArray, INDArray)} with this array as the result
     */
    @Override
    public INDArray addi(INDArray other) {
        INDArray sum = addi(other, this);
        return sum;
    }

    /**
     * Element-wise addition of {@code other} to this array, written into {@code result}.
     * Validates this array and the broadcast compatibility of the three arrays, then
     * executes an {@code AddOp}.
     *
     * @param other  the array to add
     * @param result the array that receives the sum
     * @return {@code result}
     */
    @Override
    public INDArray addi(INDArray other, INDArray result) {
        validateNumericalArray("addi", false);
        Shape.assertBroadcastable("addi", this, other, result);
        AddOp op = new AddOp(this, other, result);
        Nd4j.exec(op);
        return result;
    }

    /**
     * Executes a {@code NormMax} reduction over the given dimensions.
     *
     * @param keepDims  forwarded to the {@code NormMax} op
     * @param dimension dimensions to reduce along
     * @return the array produced by the executioner
     */
    @Override
    public INDArray normmax(boolean keepDims, int... dimension) {
        validateNumericalArray("normmax", false);
        NormMax op = new NormMax(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(op);
    }

    /**
     * Executes a {@code NormMax} reduction over the given dimensions with {@code keepDims = false}.
     *
     * @param dimension dimensions to reduce along
     * @return the result of {@link #normmax(boolean, int...)}
     */
    @Override
    public INDArray normmax(int... dimension) {
        final boolean keepDims = false;
        return normmax(keepDims, dimension);
    }

    /**
     * Reverse element-wise division into a newly allocated result.
     * <p>
     * For broadcastable shapes the result has the broadcast output shape and the data type chosen
     * by {@code Shape.pickPairwiseDataType}; otherwise the result comes from {@code ulike()}.
     *
     * @param other the other operand
     * @return the newly allocated result array
     */
    @Override
    public INDArray rdiv(INDArray other) {
        validateNumericalArray("rdiv", false);
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return rdivi(other, this.ulike());
        }
        INDArray broadcastOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering());
        return rdivi(other, broadcastOut);
    }

    /**
     * In-place reverse element-wise division: this array is used as the result target.
     *
     * @param other the other operand
     * @return the result of {@link #rdivi(INDArray, INDArray)} with this array as the result
     */
    @Override
    public INDArray rdivi(INDArray other) {
        INDArray quotient = rdivi(other, this);
        return quotient;
    }

    /**
     * Reverse element-wise division written into {@code result}. The op is invoked on a
     * duplicate of this array.
     *
     * @param other  the other operand
     * @param result the array that receives the output
     * @return {@code result}, as returned by {@code rdivi}
     */
    @Override
    public INDArray rdiv(INDArray other, INDArray result) {
        validateNumericalArray("rdiv", false);
        INDArray copy = dup();
        return copy.rdivi(other, result);
    }

    /**
     * Reverse element-wise division, written into {@code result}.
     * Validates this array and the broadcast compatibility of the three arrays, then
     * executes an {@code RDivOp}.
     *
     * @param other  the other operand
     * @param result the array that receives the output
     * @return {@code result}
     */
    @Override
    public INDArray rdivi(INDArray other, INDArray result) {
        validateNumericalArray("rdivi", false);
        Shape.assertBroadcastable("rdivi", this, other, result);
        RDivOp op = new RDivOp(this, other, result);
        Nd4j.exec(op);
        return result;
    }

    /**
     * Reverse element-wise subtraction written into {@code result}; validates this array and
     * then delegates to {@link #rsubi(INDArray, INDArray)}.
     *
     * @param other  the other operand
     * @param result the array that receives the output
     * @return {@code result}, as returned by {@code rsubi}
     */
    @Override
    public INDArray rsub(INDArray other, INDArray result) {
        validateNumericalArray("rsub", false);
        INDArray difference = rsubi(other, result);
        return difference;
    }

    /**
     * Reverse element-wise subtraction into a newly allocated result.
     * <p>
     * For broadcastable shapes the result has the broadcast output shape and the data type chosen
     * by {@code Shape.pickPairwiseDataType}; otherwise the result comes from {@code ulike()}.
     *
     * @param other the other operand
     * @return the newly allocated result array
     */
    @Override
    public INDArray rsub(INDArray other) {
        validateNumericalArray("rsub", false);
        if (!Shape.areShapesBroadcastable(this.shape(), other.shape())) {
            return rsubi(other, this.ulike());
        }
        INDArray broadcastOut = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering());
        return rsubi(other, broadcastOut);
    }

    /**
     * In-place reverse element-wise subtraction: this array is used as the result target.
     *
     * @param other the other operand
     * @return the result of {@link #rsubi(INDArray, INDArray)} with this array as the result
     */
    @Override
    public INDArray rsubi(INDArray other) {
        INDArray difference = rsubi(other, this);
        return difference;
    }

    /**
     * Reverse element-wise subtraction, written into {@code result}.
     * Validates this array and the broadcast compatibility of the three arrays, then
     * executes an {@code RSubOp}.
     *
     * @param other  the other operand
     * @param result the array that receives the output
     * @return {@code result}
     */
    @Override
    public INDArray rsubi(INDArray other, INDArray result) {
        validateNumericalArray("rsubi", false);
        Shape.assertBroadcastable("rsubi", this, other, result);
        RSubOp op = new RSubOp(this, other, result);
        Nd4j.exec(op);
        return result;
    }

    /**
     * Sets every element of this array to {@code value} by executing a {@code ScalarSet} op.
     *
     * @param value the value to assign; for BOOL arrays only 0 or 1 is accepted
     * @return this array
     * @throws IllegalStateException if this is a BOOL array and {@code value} is neither 0 nor 1
     */
    @Override
    public INDArray assign(Number value) {
        final boolean allowed = dataType() != DataType.BOOL
                || value.doubleValue() == 0.0
                || value.doubleValue() == 1.0;
        Preconditions.checkState(allowed, "Only values 0 or 1 are allowed for scalar " +
                "assign on boolean arrays: got value %s on to assign to boolean array with shape %ndShape", value, this);
        ScalarSet op = new ScalarSet(this, value);
        Nd4j.getExecutioner().exec(op);
        return this;
    }

    /**
     * Assigns a boolean to every element, encoded as 1 for {@code true} and 0 for {@code false}.
     *
     * @param value the boolean to assign
     * @return the result of the numeric {@code assign}
     */
    @Override
    public INDArray assign(boolean value) {
        final int encoded = value ? 1 : 0;
        return assign(encoded);
    }

    /**
     * Conditional assignment: delegates to {@code BooleanIndexing.assignIf} with this array
     * as the target.
     *
     * @param arr       source array
     * @param condition condition forwarded to {@code BooleanIndexing.assignIf}
     * @return this array
     */
    @Override
    public INDArray assignIf(INDArray arr, Condition condition) {
        final INDArray self = this;
        BooleanIndexing.assignIf(self, arr, condition);
        return self;
    }

    /**
     * Conditional replacement: decompresses this array if needed, then delegates to
     * {@code BooleanIndexing.replaceWhere} with this array as the target.
     *
     * @param arr       source array
     * @param condition condition forwarded to {@code BooleanIndexing.replaceWhere}
     * @return this array
     */
    @Override
    public INDArray replaceWhere(INDArray arr, Condition condition) {
        final INDArray self = this;
        Nd4j.getCompressor().autoDecompress(self);
        BooleanIndexing.replaceWhere(self, arr, condition);
        return self;
    }

    /**
     * Computes a buffer position for index {@code i}: the shape-info offset plus an
     * accumulated index value.
     *
     * @param i the input index
     * @return {@code Shape.offset(...)} of the shape information plus the accumulated index
     */
    @Override
    @Deprecated  //TODO: Investigate. Not deprecated in the base interface.
    public long linearIndex(long i) {
        // Start from i itself; strides of all but the last dimension are added on top.
        long idx = i;
        for (int j = 0; j < jvmShapeInfo.rank - 1; j++) {
            // NOTE(review): this tests size(i) - the linear index cast to a dimension - not
            // size(j). Possibly a typo for the loop variable; confirm before relying on it.
            if (size((int) i) == 1)
                continue;
            idx += i * stride(j);
        }
        return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx);
    }

    /**
     * Returns the slice at position {@code slice} along dimension 0, obtained through
     * {@code get} with a point index on dimension 0 and "all" on every other dimension.
     *
     * @param slice slice index; a negative value counts back from the number of slices
     *              ({@code -1} is the last slice)
     * @return the requested slice
     * @throws IllegalArgumentException if {@code slice} is outside {@code [-slices, slices)} or
     *                                  this array is rank 0
     */
    @Override
    public INDArray slice(long slice) {
        Nd4j.getCompressor().autoDecompress(this);

        long slices = slices();
        if (slice >= slices)
            throw new IllegalArgumentException("Illegal slice " + slice);

        if (jvmShapeInfo.rank == 0 ) {
            throw new IllegalArgumentException("Can't slice a 0-d NDArray");
        }

        if (slice < 0) {
            // Wrap by the number of slices (not the rank), matching how getScalar wraps
            // negative indices by the size of the indexed dimension.
            long wrapped = slice + slices;
            if (wrapped < 0)
                throw new IllegalArgumentException("Illegal slice " + slice);
            slice = wrapped;
        }

        INDArrayIndex[] indexes = new INDArrayIndex[rank()];
        indexes[0] = NDArrayIndex.point(slice);
        for (int i = 1; i < rank(); i++) {
            indexes[i] = NDArrayIndex.all();
        }
        return get(indexes);
    }



    /**
     * Creates a scalar for linear index {@code i}. For vectors this is {@code getScalar(i)};
     * otherwise a {@code [1, 1]} array is created over this array's data buffer, passing
     * {@code i} as the final argument of {@code Nd4j.create}.
     *
     * @param i           the linear index
     * @param applyOffset not used by this implementation
     * @return the scalar array
     */
    protected INDArray createScalarForIndex(long i, boolean applyOffset) {
        if (!isVector()) {
            return Nd4j.create(data(), new long[] {1, 1}, new long[] {1, 1}, i);
        }
        return getScalar(i);
    }

    /**
     * Wraps a double in a scalar array via {@code Nd4j.scalar}.
     *
     * @param d the value
     * @return the scalar array
     */
    protected INDArray createScalar(double d) {
        INDArray scalar = Nd4j.scalar(d);
        return scalar;
    }

    /**
     * Counts the run of size-1 dimensions at the end of the shape. For example, shape
     * {@code [10, 1, 1]} gives 2 and shape {@code [1, 10, 1]} gives 1.
     * <p>
     * Dimension 0 is never counted (the scan stops at index 1), as in the previous implementation.
     *
     * @return the number of trailing size-1 dimensions
     */
    @Override
    public int getTrailingOnes() {
        int numTrailingOnes = 0;
        for (int i = rank() - 1; i > 0; i--) {
            // Stop at the first non-1 dimension: ones further left are not "trailing".
            if (size(i) != 1)
                break;
            numTrailingOnes++;
        }

        return numTrailingOnes;
    }

    /**
     * Counts the run of size-1 dimensions at the start of the shape. For example, shape
     * {@code [1, 10, 1]} gives 1 and shape {@code [1, 1, 10]} gives 2.
     *
     * @return the number of leading size-1 dimensions
     */
    @Override
    public int getLeadingOnes() {
        int numLeadingOnes = 0;
        for (int i = 0; i < rank(); i++) {
            // Stop at the first non-1 dimension: ones further right are not "leading".
            if (size(i) != 1)
                break;
            numLeadingOnes++;
        }

        return numLeadingOnes;
    }

    /**
     * Returns the slice at position {@code slice} along {@code dimension}, obtained through
     * {@code get} with a point index on that dimension and "all" on every other dimension.
     *
     * @param slice     slice index; a negative value counts back from {@code size(dimension)}
     *                  ({@code -1} is the last slice)
     * @param dimension the dimension to slice along
     * @return the requested slice
     * @throws IllegalArgumentException if {@code slice} is outside
     *                                  {@code [-size(dimension), size(dimension))}, or this array
     *                                  is rank 0 and {@code slice} is not 0
     */
    @Override
    public INDArray slice(long slice, int dimension) {
        Nd4j.getCompressor().autoDecompress(this);

        long slices = size(dimension);
        if (slice >= slices)
            throw new IllegalArgumentException("Illegal slice " + slice);

        if (jvmShapeInfo.rank == 0) {
            if (slice == 0)
                return createScalarForIndex(slice, true);
            else
                throw new IllegalArgumentException("Can't slice a 0-d NDArray");

        }

        if (slice < 0) {
            // Wrap by the size of the sliced dimension (not the rank), matching how getScalar
            // wraps negative indices by the size of the indexed dimension.
            long wrapped = slice + slices;
            if (wrapped < 0)
                throw new IllegalArgumentException("Illegal slice " + slice);
            slice = wrapped;
        }

        INDArrayIndex[] indexes = new INDArrayIndex[rank()];
        indexes[dimension] = NDArrayIndex.point(slice);
        for (int i = 0; i < rank(); i++) {
            if (i != dimension)
                indexes[i] = NDArrayIndex.all();
        }
        return get(indexes);

    }

    /**
     * Returns a rank-0 array backed by a one-element buffer created over this array's data at
     * the offset of the given indexes.
     *
     * @param indexes per-dimension indexes; negative values are wrapped by the size of the
     *                corresponding dimension. The caller's array is not modified.
     * @return the scalar array
     * @throws ND4JIllegalStateException if more indexes are given than this array has dimensions
     */
    @Override
    public INDArray getScalar(int[] indexes) {
        if (indexes.length > rank())
            throw new ND4JIllegalStateException("Indexes can't be longer then array rank");

        // Normalise on a copy so negative entries in the caller's array are left untouched.
        int[] normalized = indexes.clone();
        for (int i = 0; i < normalized.length; i++) {
            if (normalized[i] < 0)
                normalized[i] += this.size(i);
        }
        long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, normalized);
        val buffer = Nd4j.createBuffer(this.data(), idx, 1);
        val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1, 'c', this.dataType(), false);
        return Nd4j.createArrayFromShapeBuffer(buffer, shape);
    }

    /**
     * Returns a rank-0 array backed by a one-element buffer created over this array's data at
     * the offset of the given indexes.
     *
     * @param indexes per-dimension indexes; negative values are wrapped by the size of the
     *                corresponding dimension. An array passed by the caller is not modified.
     * @return the scalar array
     * @throws ND4JIllegalStateException if more indexes are given than this array has dimensions
     */
    @Override
    public INDArray getScalar(long... indexes) {
        if (indexes.length > rank())
            throw new ND4JIllegalStateException("Indexes can't be longer then array rank");

        // Normalise on a copy: varargs may be an array the caller still holds.
        long[] normalized = indexes.clone();
        for (int i = 0; i < normalized.length; i++) {
            if (normalized[i] < 0)
                normalized[i] += this.size(i);
        }

        long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, normalized);
        val buffer = Nd4j.createBuffer(this.data(), idx, 1);
        val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', this.dataType(), false);
        return Nd4j.createArrayFromShapeBuffer(buffer, shape);
    }

    /**
     * Reverse scalar division into a newly allocated result with this array's shape and
     * ordering; the data type is chosen by {@code Shape.pickPairwiseDataType}.
     *
     * @param n the scalar operand
     * @return the newly allocated result array
     */
    @Override
    public INDArray rdiv(Number n) {
        // Validate up front, as rsub/div/mul/sub/add(Number) do, before allocating the result.
        validateNumericalArray("rdiv", false);
        return rdivi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering()));
    }

    /**
     * In-place reverse scalar division: this array is used as the result target.
     *
     * @param n the scalar operand
     * @return the result of the two-argument {@code rdivi} with this array as the result
     */
    @Override
    public INDArray rdivi(Number n) {
        INDArray quotient = rdivi(n, this);
        return quotient;
    }

    /**
     * Reverse scalar subtraction into a newly allocated result with this array's shape and
     * ordering; the data type is chosen by {@code Shape.pickPairwiseDataType}.
     *
     * @param n the scalar operand
     * @return the newly allocated result array
     */
    @Override
    public INDArray rsub(Number n) {
        validateNumericalArray("rsub", false);
        INDArray out = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering());
        return rsubi(n, out);
    }

    /**
     * In-place reverse scalar subtraction: this array is used as the result target.
     *
     * @param n the scalar operand
     * @return the result of the two-argument {@code rsubi} with this array as the result
     */
    @Override
    public INDArray rsubi(Number n) {
        INDArray difference = rsubi(n, this);
        return difference;
    }

    /**
     * Scalar division into a newly allocated result with this array's shape and ordering;
     * the data type is chosen by {@code Shape.pickPairwiseDataType}.
     *
     * @param n the scalar divisor
     * @return the newly allocated result array
     */
    @Override
    public INDArray div(Number n) {
        validateNumericalArray("div", false);
        INDArray out = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering());
        return divi(n, out);
    }

    /**
     * In-place scalar division: this array is used as the result target.
     *
     * @param n the scalar divisor
     * @return the result of the two-argument {@code divi} with this array as the result
     */
    @Override
    public INDArray divi(Number n) {
        INDArray quotient = divi(n, this);
        return quotient;
    }

    /**
     * Scalar multiplication into a newly allocated result with this array's shape and ordering;
     * the data type is chosen by {@code Shape.pickPairwiseDataType}.
     *
     * @param n the scalar factor
     * @return the newly allocated result array
     */
    @Override
    public INDArray mul(Number n) {
        validateNumericalArray("mul", false);
        INDArray out = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering());
        return muli(n, out);
    }

    /**
     * In-place scalar multiplication: this array is used as the result target.
     *
     * @param n the scalar factor
     * @return the result of the two-argument {@code muli} with this array as the result
     */
    @Override
    public INDArray muli(Number n) {
        INDArray product = muli(n, this);
        return product;
    }

    /**
     * Scalar subtraction into a newly allocated result with this array's shape and ordering.
     * The result data type is chosen by {@code Shape.pickPairwiseDataType}, as in the sibling
     * {@code add}/{@code mul}/{@code div}/{@code rsub} scalar methods.
     *
     * @param n the scalar to subtract
     * @return the newly allocated result array
     */
    @Override
    public INDArray sub(Number n) {
        validateNumericalArray("sub", false);
        // Previously the result was always allocated with this array's own data type, unlike
        // every other Number-based arithmetic method in this class.
        return subi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering()));
    }

    /**
     * In-place scalar subtraction: this array is used as the result target.
     *
     * @param n the scalar to subtract
     * @return the result of the two-argument {@code subi} with this array as the result
     */
    @Override
    public INDArray subi(Number n) {
        INDArray difference = subi(n, this);
        return difference;
    }

    /**
     * Scalar addition into a newly allocated result with this array's shape and ordering;
     * the data type is chosen by {@code Shape.pickPairwiseDataType}.
     *
     * @param n the scalar to add
     * @return the newly allocated result array
     */
    @Override
    public INDArray add(Number n) {
        validateNumericalArray("add", false);
        INDArray out = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering());
        return addi(n, out);
    }

    /**
     * In-place scalar addition: this array is used as the result target.
     *
     * @param n the scalar to add
     * @return the result of the two-argument {@code addi} with this array as the result
     */
    @Override
    public INDArray addi(Number n) {
        INDArray sum = addi(n, this);
        return sum;
    }

    /**
     * Tiles this array: the result is reshaped to
     * {@code [rows() * shape[0], columns() * shape[1]]}.
     *
     * @param shape repetition counts; elements 0 and 1 are read
     * @return the tiled array
     */
    @Override
    public INDArray repmat(long[] shape) {
        Nd4j.getCompressor().autoDecompress(this);
        final long rows = rows() * shape[0];
        final long cols = columns() * shape[1];

        // Flatten to a single row, repeat along dimension 0, then repeat again after regrouping.
        INDArray flattened = reshape(1, length());
        INDArray repeatedOnce = flattened.repeat(0, shape[0]);
        INDArray regrouped = repeatedOnce.reshape(rows, columns());
        INDArray ret = regrouped.repeat(0, shape[1]);
        return ret.reshape(rows, cols);
    }

    /**
     * {@code int[]} variant of {@link #repmat(long[])}: converts the shape to {@code long[]}
     * and delegates.
     *
     * @param shape repetition counts
     * @return the tiled array
     */
    @Deprecated
    @Override
    public INDArray repmat(int[] shape) {
        return repmat(ArrayUtil.toLongArray(shape));
    }

    /**
     * Runs the native "repeat" custom op on this array and returns a new output array.
     *
     * @param dimension the dimension argument; appended as the last integer argument of the op
     * @param repeats   repeat counts; converted to ints before being passed to the op
     * @return the newly created output array
     */
    @Override
    public INDArray repeat(int dimension, long... repeats) {
        Nd4j.getCompressor().autoDecompress(this);

        CustomOp op = DynamicCustomOp.builder("repeat")
                .addInputs(this)
                .addIntegerArguments(ArrayUtil.toInts(repeats))     //TODO int cast
                .build();
        // Native op: last iarg is dimension
        op.addIArgument(dimension);

        // Allocate the output from the shape the op itself reports.
        LongShapeDescriptor outputDescriptor = op.calculateOutputShape().get(0);
        INDArray out = Nd4j.create(outputDescriptor);
        op.addOutputArgument(out);

        Nd4j.exec(op);
        return out;
    }

    /**
     * Writes {@code toPut} into row {@code row}. When this array is a row vector and
     * {@code toPut} is a vector, the whole array is assigned instead.
     *
     * @param row   the row index
     * @param toPut the values to write
     * @return the result of {@code assign} or {@code put}
     */
    @Override
    public INDArray putRow(long row, INDArray toPut) {
        if (!(isRowVector() && toPut.isVector())) {
            return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut);
        }
        return assign(toPut);
    }

    /**
     * Writes {@code toPut} into column {@code column}. When this array is a column vector and
     * {@code toPut} is a vector, the whole array is assigned instead.
     *
     * @param column the column index
     * @param toPut  the values to write
     * @return the result of {@code assign} or {@code put}
     */
    @Override
    public INDArray putColumn(int column, INDArray toPut) {
        Nd4j.getCompressor().autoDecompress(this);

        if (!(isColumnVector() && toPut.isVector())) {
            return put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut);
        }
        return assign(toPut);
    }

    /**
     * Reads the element at linear index {@code i} as a {@link Number}: floating-point types are
     * read with {@code getDouble}, integer and boolean types with {@code getLong}.
     *
     * @param i the linear index
     * @return the element value
     * @throws UnsupportedOperationException for UTF8, COMPRESSED, UNKNOWN or any other data type
     */
    @Override
    public Number getNumber(long i){
        switch (dataType()){
            // Floating-point types
            case BFLOAT16:
            case HALF:
            case FLOAT:
            case DOUBLE:
                return getDouble(i);
            // Signed / unsigned integer types and booleans
            case BOOL:
            case BYTE:
            case UBYTE:
            case SHORT:
            case UINT16:
            case INT:
            case UINT32:
            case LONG:
            case UINT64:
                return getLong(i);
            // UTF8, COMPRESSED, UNKNOWN and anything else
            default:
                throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType());
        }
    }

    /**
     * Reads the element at the given multi-dimensional index as a boxed {@link Number}.
     * Floating point data types (DOUBLE, FLOAT, HALF, BFLOAT16) are read via {@code getDouble(long...)};
     * boolean, signed and unsigned integer data types via {@code getLong(long...)}. The accepted data
     * types are the same as for the linear-index overload {@code getNumber(long)}.
     *
     * @param idx index of the element, one entry per dimension
     * @return the element as a {@code Double} or {@code Long}
     * @throws UnsupportedOperationException for UTF8, COMPRESSED, UNKNOWN or any other data type
     */
    @Override
    public Number getNumber(long... idx){
        switch (dataType()){
            case DOUBLE:
            case FLOAT:
            case HALF:
            case BFLOAT16:
                return getDouble(idx);
            case LONG:
            case INT:
            case SHORT:
            case UBYTE:
            case BYTE:
            case BOOL:
            case UINT64:
            case UINT32:
            case UINT16:
                return getLong(idx);
            case UTF8:
            case COMPRESSED:
            case UNKNOWN:
            default:
                throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType());
        }
    }

    /**
     * Reads the element at linear index {@code i} as a double.
     * Index 0 is read straight from the data buffer; any other index is first converted to
     * per-dimension subscripts according to this array's ordering.
     *
     * @param i linear index of the element
     * @return the element value
     * @throws IllegalArgumentException if {@code i} is not smaller than {@code length()}
     */
    @Override
    public double getDouble(long i) {
        Nd4j.getCompressor().autoDecompress(this);
        Preconditions.checkState(!isEmpty(), "Unable to get value from empty array");

        final long total = length();
        if (i >= total) {
            throw new IllegalArgumentException("Unable to get linear index " + i + ": values is greater than length (" + total + ")");
        }

        autoProcessScalarCall();

        if (i == 0) {
            return data().getDouble(i);
        }

        // Translate the linear index into subscripts, honouring c/f ordering
        final long[] subscripts;
        if (ordering() == 'c') {
            subscripts = Shape.ind2subC(this, i);
        } else {
            subscripts = Shape.ind2sub(this, i);
        }
        Shape.assertShapeLessThan(subscripts, shape());
        return getDouble(subscripts);
    }

    /**
     * Two-index convenience overload: packs {@code i} and {@code j} into an index array and
     * delegates to {@code getDouble(long...)}.
     */
    @Override
    public double getDouble(long i, long j) {
        final long[] indices = {i, j};
        return getDouble(indices);
    }

    /**
     * Reads the element at linear index {@code i} as a double and narrows it to float.
     */
    @Override
    public float getFloat(long i) {
        final double value = getDouble(i);
        return (float) value;
    }

    /**
     * Reads the element at {@code (i, j)} as a double and narrows it to float.
     */
    @Override
    public float getFloat(long i, long j) {
        final double value = getDouble(i, j);
        return (float) value;
    }

    /**
     * Transposes this array by permuting its dimensions into reverse order.
     *
     * @return the result of {@code permute} with the reversed dimension order
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if the rank is below 2
     */
    @Override
    public INDArray transpose() {
        Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this);

        // [rank-1, ..., 1, 0]
        final int[] reversedAxes = ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()));
        return permute(reversedAxes);
    }

    /**
     * Transposes this array in place by reversing the order of its dimensions.
     *
     * PLEASE NOTE: unlike {@link #transpose()}, this method delegates to {@code permutei}, which
     * replaces the shape information of this array and returns this same instance, not a copy.
     *
     * @return this array, after its dimensions have been reversed
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if the rank is below 2
     */
    @Override
    public INDArray transposei() {
        Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this);

        return permutei(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank())));
    }

    /**
     * Creates an array over {@code data} with the given shape and strides, offset 0 and
     * this array's ordering.
     */
    protected INDArray create(DataBuffer data, int[] shape, int[] strides) {
        final char currentOrder = ordering();
        return Nd4j.create(data, shape, strides, 0, currentOrder);
    }

    /**
     * Int-shape overload: widens {@code newShape} to {@code long[]} and delegates to
     * {@code reshape(char, long...)}.
     */
    @Deprecated
    @Override
    public INDArray reshape(char order, int... newShape) {
        final long[] widenedShape = ArrayUtil.toLongArray(newShape);
        return reshape(order, widenedShape);
    }

    /**
     * Reshapes with the given order without enforcing a view
     * (delegates with {@code enforceView = false}).
     */
    @Override
    public INDArray reshape(char order, long... newShape) {
        final boolean enforceView = false;
        return reshape(order, enforceView, newShape);
    }

    /**
     * Reshapes this array to {@code newShape} using the given element order.
     * <p>
     * A single negative entry in {@code newShape} is inferred from {@code length()} and the product of the
     * entries that are {@code >= 1}. A view is returned when {@code Shape.newShapeNoCopy} succeeds;
     * otherwise the data is duplicated, unless {@code enforceView} is set, in which case an exception is thrown.
     * <p>
     * The {@code newShape} array passed by the caller is never modified.
     *
     * @param order       element order requested for the reshape ('c' or 'f')
     * @param enforceView if true, fail instead of returning a copy when a view is not possible
     * @param newShape    target shape; null/empty is only accepted for a length-1 array with element-wise stride 1
     * @return the reshaped array (view or copy)
     * @throws ND4JIllegalStateException if no shape is given, if the resolved shape's length differs from
     *                                   {@code length()}, or if a view is enforced but not possible
     */
    @Override
    public INDArray reshape(char order, boolean enforceView, long... newShape){
        Nd4j.getCompressor().autoDecompress(this);

        // special case for empty reshape
        if (this.length() == 1 && (newShape == null || newShape.length == 0) && this.elementWiseStride() == 1) {
            return Nd4j.create(this.data(), new int[0], new int[0], 0);
        }

        if (newShape == null || newShape.length < 1)
            throw new ND4JIllegalStateException(
                    "Can't reshape(long...) without shape arguments. Got empty shape instead.");

        // Resolve the target shape on a private copy so the caller's array is never written to
        long[] shape = newShape.clone();

        // TODO: maybe toFlatten() makes more sense here?
        // reshape(-1) special case
        if (shape.length == 1 && shape[0] == -1)
            shape[0] = this.length();

        // Infer the first negative dimension from the remaining positive ones. Only the first negative
        // entry is resolved; any further negative entry is left in place and rejected by the length check below.
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] < 0) {
                // long accumulator: the product of the known dimensions can exceed Integer.MAX_VALUE
                long knownProduct = 1;
                for (int j = 0; j < shape.length; j++)
                    if (shape[j] >= 1)
                        knownProduct *= shape[j];

                shape[i] = Math.abs(length() / knownProduct);
                break;
            }
        }

        long prod = ArrayUtil.prodLong(shape);

        if (prod != this.length()) {
            throw new ND4JIllegalStateException("New shape length doesn't match original length: [" + prod + "] vs [" + this.length() + "]. Original shape: "+Arrays.toString(this.shape())+" New Shape: "+Arrays.toString(newShape));
        }

        // First try to reshape without copying any data
        INDArray reshapeAttempt = Shape.newShapeNoCopy(this, shape, order == 'f');
        if (reshapeAttempt != null) {
            // kinda strange get/set usage
            //  reshapeAttempt.setOrder(Shape.getOrder(reshapeAttempt));
            return reshapeAttempt;
        }

        if(enforceView){
            throw new ND4JIllegalStateException("Unable to reshape array as view, called with enforceView=true. " +
                    "Use enforceView=false to return a copy instead, or call reshape on a non-strided array. Array shape info: " + this.shapeInfoToString().replaceAll("\n",""));
        }

        // No view possible: fall back to a copy
        if (order != ordering()) {
            INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
            ret.setData(dup(order).data());
            return ret;
        } else if (this.isEmpty()) {
            return Nd4j.create(this.dataType(), shape);
        } else {
            INDArray ret = this.dup(order);
            return Nd4j.create(ret.data(), shape);
        }
    }

    /**
     * Reads a double straight from the data buffer at {@code offset}; no checks are performed here.
     */
    @Override
    public double getDoubleUnsafe(long offset) {
        final double raw = data().getDouble(offset);
        return raw;
    }

    /**
     * Writes {@code value} straight into the data buffer at {@code offset}; no checks are performed here.
     *
     * @return this array
     */
    @Override
    public INDArray putScalarUnsafe(long offset, double value) {
        autoProcessScalarCall();

        data().put(offset, value);

        return this;
    }

    /**
     * Reshapes to a {@code rows x columns} array with the given order
     * (delegates to {@code reshape(char, long...)}).
     */
    @Override
    public INDArray reshape(char order, int rows, int columns) {
        final long[] matrixShape = {rows, columns};
        return reshape(order, matrixShape);
    }

    /**
     * Reshapes this ndarray to the given dimensions using the order returned by {@code Nd4j.order()}.
     * Invalid shapes are reported by the overload this method delegates to.
     *
     * One dimension may be -1; it is then inferred from the other dimensions
     * and the length of the ndarray.
     *
     * @param shape the target shape
     * @return the reshaped ndarray
     */

    @Override
    public INDArray reshape(int[] shape) {
        final char defaultOrder = Nd4j.order();
        return reshape(defaultOrder, shape);
    }

    /**
     * Reshapes this ndarray to {@code shape} using the order returned by {@code Nd4j.order()}.
     */
    @Override
    public INDArray reshape(long... shape) {
        final char defaultOrder = Nd4j.order();
        return reshape(defaultOrder, shape);
    }

    /**
     * Executes a {@code Prod} op on this array over the given dimensions.
     *
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray prod(boolean keepDims, int... dimension) {
        validateNumericalArray("prod", false);
        Prod productOp = new Prod(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(productOp);
    }

    /**
     * Same as {@code prod(false, dimension)}.
     */
    @Override
    public INDArray prod(int... dimension) {
        final boolean keepDims = false;
        return prod(keepDims, dimension);
    }

    /**
     * Executes a {@code Mean} op on this array over the given dimensions.
     *
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray mean(boolean keepDims, int... dimension) {
        validateNumericalArray("mean", false);
        Mean meanOp = new Mean(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(meanOp);
    }

    /**
     * Same as {@code mean(false, dimension)}.
     */
    @Override
    public INDArray mean(int... dimension) {
        final boolean keepDims = false;
        return mean(keepDims, dimension);
    }

    /**
     * Executes an {@code AMean} op on this array over the given dimensions.
     */
    @Override
    public INDArray amean(int... dimension) {
        validateNumericalArray("amean", false);
        AMean ameanOp = new AMean(this, dimension);
        return Nd4j.getExecutioner().exec(ameanOp);
    }

    /**
     * Executes a {@code Mean} op on this array over the given dimensions, passing {@code result}
     * to the op.
     *
     * @param result    array handed to the op (must not be null)
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) {
        validateNumericalArray("mean", false);
        Mean meanOp = new Mean(this, result, keepDims, dimension);
        return Nd4j.getExecutioner().exec(meanOp);
    }

    /**
     * Same as {@code mean(result, false, dimension)}.
     */
    @Override
    public INDArray mean(@NonNull INDArray result, int... dimension) {
        final boolean keepDims = false;
        return mean(result, keepDims, dimension);
    }

    /**
     * Executes a {@code Variance} op on this array over the given dimensions.
     */
    @Override
    public INDArray var(int... dimension) {
        validateNumericalArray("var", false);
        Variance varianceOp = new Variance(this, dimension);
        return Nd4j.getExecutioner().exec(varianceOp);
    }

    /**
     * Executes a {@code Variance} op on this array over the given dimensions.
     *
     * @param biasCorrected forwarded to the op
     * @param dimension     dimensions forwarded to the op
     */
    @Override
    public INDArray var(boolean biasCorrected, int... dimension) {
        validateNumericalArray("var", false);
        Variance varianceOp = new Variance(this, biasCorrected, dimension);
        return Nd4j.getExecutioner().exec(varianceOp);
    }

    /**
     * Executes a {@code Max} op on this array over the given dimensions.
     *
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray max(boolean keepDims, int... dimension) {
        validateNumericalArray("max", false);
        Max maxOp = new Max(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(maxOp);
    }

    /**
     * Same as {@code max(false, dimension)}.
     */
    @Override
    public INDArray max(int... dimension) {
        final boolean keepDims = false;
        return max(keepDims, dimension);
    }

    /**
     * Executes an {@code AMax} op on this array over the given dimensions.
     */
    @Override
    public INDArray amax(int... dimension) {
        validateNumericalArray("amax", false);
        AMax amaxOp = new AMax(this, dimension);
        return Nd4j.getExecutioner().exec(amaxOp);
    }

    /**
     * Executes a {@code Min} op on this array over the given dimensions.
     *
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray min(boolean keepDims, int... dimension) {
        validateNumericalArray("min", false);
        Min minOp = new Min(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(minOp);
    }

    /**
     * Same as {@code min(false, dimension)}.
     */
    @Override
    public INDArray min(int... dimension) {
        final boolean keepDims = false;
        return min(keepDims, dimension);
    }

    /**
     * Executes an {@code AMin} op on this array over the given dimensions.
     */
    @Override
    public INDArray amin(int... dimension) {
        validateNumericalArray("amin", false);
        AMin aminOp = new AMin(this, dimension);
        return Nd4j.getExecutioner().exec(aminOp);
    }

    /**
     * Executes a {@code Sum} op on this array over the given dimensions.
     */
    @Override
    public INDArray sum(int... dimension) {
        validateNumericalArray("sum", true);
        Sum sumOp = new Sum(this, dimension);
        return Nd4j.getExecutioner().exec(sumOp);
    }

    /**
     * Executes a {@code Sum} op on this array over the given dimensions, with no result array supplied
     * ({@code null} is passed in its place).
     *
     * @param keepDim   forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray sum(boolean keepDim, int... dimension) {
        validateNumericalArray("sum", true);
        Sum sumOp = new Sum(this, null, keepDim, dimension);
        return Nd4j.getExecutioner().exec(sumOp);
    }

    /**
     * Executes an {@code Entropy} op on this array over the given dimensions.
     */
    @Override
    public INDArray entropy(int... dimension) {
        validateNumericalArray("entropy", false);
        Entropy entropyOp = new Entropy(this, dimension);
        return Nd4j.getExecutioner().exec(entropyOp);
    }

    /**
     * Executes a {@code ShannonEntropy} op on this array over the given dimensions.
     */
    @Override
    public INDArray shannonEntropy(int... dimension) {
        validateNumericalArray("shannonEntropy", false);
        ShannonEntropy shannonOp = new ShannonEntropy(this, dimension);
        return Nd4j.getExecutioner().exec(shannonOp);
    }

    /**
     * Executes a {@code LogEntropy} op on this array over the given dimensions.
     */
    @Override
    public INDArray logEntropy(int... dimension) {
        validateNumericalArray("logEntropy", false);
        LogEntropy logEntropyOp = new LogEntropy(this, dimension);
        return Nd4j.getExecutioner().exec(logEntropyOp);
    }

    /**
     * Executes a {@code Sum} op on this array over the given dimensions, passing {@code result}
     * to the op.
     *
     * @param result    array handed to the op (must not be null)
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) {
        validateNumericalArray("sum", true);
        Sum sumOp = new Sum(this, result, keepDims, dimension);
        return Nd4j.getExecutioner().exec(sumOp);
    }

    /**
     * Same as {@code sum(result, false, dimension)}.
     */
    @Override
    public INDArray sum(@NonNull INDArray result, int... dimension) {
        final boolean keepDims = false;
        return sum(result, keepDims, dimension);
    }

    /**
     * Same as {@code norm1(false, dimension)}.
     */
    @Override
    public INDArray norm1(int... dimension) {
        final boolean keepDims = false;
        return norm1(keepDims, dimension);
    }

    /**
     * Executes a {@code Norm1} op on this array over the given dimensions.
     *
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray norm1(boolean keepDims, int... dimension) {
        validateNumericalArray("norm1", false);
        Norm1 norm1Op = new Norm1(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(norm1Op);
    }

    /**
     * Same as {@code std(true, dimension)}, i.e. with bias correction enabled.
     */
    @Override
    public INDArray std(int... dimension) {
        final boolean biasCorrected = true;
        return std(biasCorrected, dimension);
    }

    /**
     * Same as {@code std(biasCorrected, false, dimension)}.
     */
    @Override
    public INDArray std(boolean biasCorrected, int... dimension) {
        final boolean keepDims = false;
        return std(biasCorrected, keepDims, dimension);
    }

    /**
     * Executes a {@code StandardDeviation} op on this array over the given dimensions.
     *
     * @param biasCorrected forwarded to the op
     * @param keepDims      forwarded to the op
     * @param dimension     dimensions forwarded to the op
     */
    @Override
    public INDArray std(boolean biasCorrected, boolean keepDims, int... dimension) {
        validateNumericalArray("std", false);
        StandardDeviation stdOp = new StandardDeviation(this, biasCorrected, keepDims, dimension);
        return Nd4j.getExecutioner().exec(stdOp);
    }

    /**
     * Executes a {@code StandardDeviation} op on this array with no dimensions specified and
     * returns element 0 of the result.
     *
     * @param biasCorrected forwarded to the op
     * @return element 0 of the op result, as a double
     */
    @Override
    public Number stdNumber(boolean biasCorrected) {
        validateNumericalArray("stdNumber", false);
        StandardDeviation stdOp = new StandardDeviation(this, biasCorrected);
        INDArray reduced = Nd4j.getExecutioner().exec(stdOp);
        return reduced.getDouble(0);
    }

    /**
     * Executes a {@code Norm2} op on this array over the given dimensions.
     *
     * @param keepDims  forwarded to the op
     * @param dimension dimensions forwarded to the op
     */
    @Override
    public INDArray norm2(boolean keepDims, int... dimension) {
        validateNumericalArray("norm2", false);
        Norm2 norm2Op = new Norm2(this, keepDims, dimension);
        return Nd4j.getExecutioner().exec(norm2Op);
    }

    /**
     * Same as {@code norm2(false, dimension)}.
     */
    @Override
    public INDArray norm2(int... dimension) {
        final boolean keepDims = false;
        return norm2(keepDims, dimension);
    }

    /**
     * Number of columns: {@code size(1)} for a matrix, 1 for a column-vector shape and
     * {@code length()} for a row-vector shape.
     *
     * @throws IllegalStateException if none of the above applies
     */
    @Override
    public int columns() {
        if (isMatrix()) {
            return (int) size(1);
        }
        if (Shape.isColumnVectorShape(shape())) {
            return 1;
        }
        if (Shape.isRowVectorShape(shape())) {
            return (int) length();
        }

        throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
    }

    /**
     * Number of rows: {@code size(0)} for a matrix, 1 for a row-vector shape and
     * {@code length()} for a column-vector shape.
     *
     * @throws IllegalStateException if none of the above applies
     */
    @Override
    public int rows() {
        if (isMatrix()) {
            return (int) size(0);
        }
        if (Shape.isRowVectorShape(shape())) {
            return 1;
        }
        if (Shape.isColumnVectorShape(shape())) {
            return (int) length();
        }

        throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid");
    }

    /**
     * Flattens this array to one dimension of size {@code length()} in the requested ordering.
     * The array is duplicated in that ordering first unless it already has the requested ordering
     * and default strides for its shape.
     *
     * @param ordering requested element ordering
     */
    @Override
    public INDArray ravel(char ordering) {
        Nd4j.getCompressor().autoDecompress(this);

        // A dup is needed when the ordering differs or the strides are not the default ones
        if (ordering != this.ordering() || !Shape.hasDefaultStridesForShape(this)) {
            return dup(ordering).reshape(ordering, length());
        }
        return reshape(ordering, length());
    }

    /**
     * Reshapes this array to a single dimension of size {@code length()}.
     */
    @Override
    public INDArray ravel() {
        final long totalElements = length();
        return reshape(totalElements);
    }

    /**
     * Appends the vectors of this array to {@code list}: a vector adds itself, anything else
     * recurses into each of its slices.
     *
     * @param list list that receives the vectors
     */
    @Override
    public void sliceVectors(List list) {
        if (isVector()) {
            list.add(this);
            return;
        }

        for (int sliceIdx = 0; sliceIdx < slices(); sliceIdx++) {
            slice(sliceIdx).sliceVectors(list);
        }
    }

    /**
     * Reshapes this array to {@code newRows x newColumns} (delegates to {@code reshape(long...)}).
     */
    @Override
    public INDArray reshape(long newRows, long newColumns) {
        final long[] matrixShape = {newRows, newColumns};
        return reshape(matrixShape);
    }

    /**
     * Returns column {@code c} of this 2D array as the tensor along dimension 0.
     * A column vector returns itself for {@code c == 0}.
     *
     * @param c column index
     * @throws IllegalArgumentException if this is a column vector and {@code c > 0}, or if the rank is not 2
     */
    @Override
    public INDArray getColumn(long c) {
        Nd4j.getCompressor().autoDecompress(this);

        if (isColumnVector()) {
            if (c == 0) {
                return this;
            }
            if (c > 0) {
                throw new IllegalArgumentException("Illegal index for column");
            }
        }

        Preconditions.checkArgument(this.rank() == 2, "getColumn() can be called on 2D arrays only");
        return tensorAlongDimension(c, 0);
    }

    /**
     * Returns column {@code c}; with {@code keepDim} the column is reshaped to {@code [length, 1]}.
     */
    @Override
    public INDArray getColumn(long c, boolean keepDim) {
        final INDArray column = getColumn(c);
        if (keepDim) {
            return column.reshape(column.length(), 1);
        }
        return column;
    }

    /**
     * Collects the requested rows of this array into a new array.
     * <p>
     * For a vector the call is delegated to {@code Nd4j.pullRows(this, 1, rindices)}. For a matrix a new
     * {@code [rindices.length, columns()]} array is allocated and filled row by row from {@code getRow}.
     *
     * @param rindices indices of the rows to fetch
     * @return an array holding the selected rows
     * @throws IllegalArgumentException if this array is neither a matrix nor a vector
     */
    @Override
    public INDArray getRows(int[] rindices) {
        Nd4j.getCompressor().autoDecompress(this);

        if (!isMatrix() && !isVector())
            throw new IllegalArgumentException("Unable to get rows from a non matrix or vector");
        if (isVector())
            return Nd4j.pullRows(this, 1, rindices);
        else {
            INDArray ret = Nd4j.createUninitialized(this.dataType(), new long[] {rindices.length, columns()});
            for (int i = 0; i < rindices.length; i++)
                ret.putRow(i, getRow(rindices[i]));
            return ret;
        }
    }

    @Override
    public INDArray get(INDArrayIndex... indexes) {
        Nd4j.getCompressor().autoDecompress(this);

        int numPoint = 0;
        int numInterval = 0;
        int numAll = 0;
        int numNewAxis = 0;
        int numSpecified = 0;
        for(INDArrayIndex i : indexes){
            if(i instanceof PointIndex){
                numPoint++;
            } else if(i instanceof NDArrayIndexAll){
                numAll++;
            } else if(i instanceof IntervalIndex){
                numInterval++;
            } else if(i instanceof NewAxis){
                numNewAxis++;
            } else if(i instanceof SpecifiedIndex){
                numSpecified++;
            } else {
                throw new IllegalStateException("Unknown index: " + i);
            }
        }

        // Padding remaining dimensions with all() index if too few indices provided
        if (indexes.length - numNewAxis < this.rank()) {
            val newIndexes = new INDArrayIndex[this.rank() + numNewAxis];
            for (int e = 0; e < indexes.length; e++)
                newIndexes[e] = indexes[e];

            for (int e = indexes.length; e < newIndexes.length; e++) {
                numAll++;
                newIndexes[e] = NDArrayIndex.all();
            }

            indexes = newIndexes;
        }

        Preconditions.checkState((numPoint + numInterval + numAll + numSpecified) == rank(), "Illegal set of indices for array: need at least" +
                " %s point/interval/all/specified indices for rank %s array (%ndShape), got indices %s", rank(), rank(), this, indexes);

        int outRank = rank() + numNewAxis - numPoint;
        Preconditions.checkState(outRank >= 0, "Illegal set of indices for array: %ndShape, %s", this, indexes);


        //To work out sub-array, we need to work out 3 things: offset, shape and strides. We calculate all of these
        long[] outShape = new long[outRank];
        long[] outStrides = new long[outRank];
        long offset = offset();                     //Start with existing offset if view

        int outIdx = 0;     //Axis number counter for output array
        int inIdx = 0;      //Axis number counter for input array
        for( int i=0; i= size(inIdx)) {
                    throw new IllegalStateException("Indices are out of range: Cannot get interval index " + indexes[i] +
                            " on array with size(" + inIdx + ")=" + size(inIdx) + ". Array shape: " + Arrays.toString(shape()) +
                            ", indices: " + Arrays.toString(indexes));
                }
                long stride = ii.stride();
                long length = (endInc - start)/stride + 1;

                offset += ii.offset() * stride(inIdx);
                outShape[outIdx] = length;
                outStrides[outIdx] = ii.stride() * stride(inIdx);
                inIdx++;
                outIdx++;
            } else if(indexes[i] instanceof NewAxis) {
                //New axis: appends a 1 in shape. Axis not present in input, but is present in output
                outShape[outIdx] = 1;
                if (outIdx > 0) { //Stride doesn't matter for 1 size axis anyway...
                    outStrides[outIdx] = outStrides[outIdx - 1];
                } else {
                    outStrides[outIdx] = 1;
                }
                outIdx++;
            } else if(indexes[i] instanceof SpecifiedIndex){
                //Specified index: axis present in both input and output
                SpecifiedIndex si = (SpecifiedIndex)indexes[i];
                outShape[outIdx++] = si.length();
                inIdx++;
                //Don't care about strides for specified index, as result won't be a view
            } else {
                throw new IllegalStateException("Unknown index type: " + i);    //Should never happen
            }
        }


        //Note: If we have specified indices, we can't return a view. Instead, we copy the specified sub-arrays from
        // the input array to the output array.
        //How? Create the output array, then do loop over the specified indices only, and copy sub-arrays for all other axes
        if (numSpecified > 0) {
            INDArray out = Nd4j.create(dataType(), outShape);

            //Need to copy subsets here
            long[] specifiedSizes = new long[numSpecified];
            SpecifiedIndex[] si = new SpecifiedIndex[numSpecified];
            int j=0;
            for( int i=0; i replace with loop + point
            //    ii. new axis indices -> ignore/exclude (don't appear in input)
            //    iii. interval indices -> replace with all
            //(2) Get from output: requested indices, except for:
            //    i. point indices -> ignore/exclude (don't appear in output)
            //    ii. new axis indices -> replace with point(0)


            INDArrayIndex[] pointIdxsIn = new INDArrayIndex[indexes.length - numNewAxis];       //Indices for source (this array)
            int[] specifiedAxisIn = new int[numSpecified];
            int specCount = 0;
            j = 0;
            for( int i=0; i 0)
            throw new IllegalArgumentException("Illegal index for row: requested row " + r + " but this.size(0)=" + this.size(0));

        Preconditions.checkArgument(rank() == 2, "getRow() can be called on 2D arrays only");
        Preconditions.checkArgument(r < rows(), "Row index must be smaller than total number of rows");

        return tensorAlongDimension(r, 1);
    }

    /**
     * Returns row {@code r}; with {@code keepDim} the row is reshaped to {@code [1, length]}.
     */
    @Override
    public INDArray getRow(long r, boolean keepDim) {
        final INDArray selectedRow = getRow(r);
        if (keepDim) {
            return selectedRow.reshape(1, selectedRow.length());
        }
        return selectedRow;
    }

    /**
     * Compares this array with {@code o}, using {@code eps} as the tolerance for floating point values.
     * <p>
     * Arrays are unequal if {@code o} is not an {@code INDArray}, or if rank, length, emptiness, data type
     * or (for non-vector, non-scalar arrays) shape differ. Empty arrays compare by shape, UTF8 arrays
     * by their strings, scalars by direct value comparison; everything else is compared with an
     * {@code EqualsWithEps} op.
     *
     * @param o   object to compare against
     * @param eps tolerance used for floating point comparisons
     * @return true if the arrays are considered equal
     */
    public boolean equalsWithEps(Object o, double eps) {
        Nd4j.getCompressor().autoDecompress(this);

        // instanceof is false for null as well
        if (!(o instanceof INDArray)) {
            return false;
        }

        INDArray n = (INDArray) o;
        Nd4j.getCompressor().autoDecompress(n);

        if (n == this) {
            return true;
        }

        if (this.rank() != n.rank() || this.length() != n.length() || this.isEmpty() != n.isEmpty()) {
            return false;
        }

        // Both sides agree on emptiness at this point: empty arrays are equal iff their shapes are
        if (this.isEmpty()) {
            return Shape.shapeEquals(this.shape(), n.shape());
        }

        if (this.dataType() != n.dataType()) {
            return false;
        }

        // meh
        // String arrays: element-wise string comparison (data types are already known to match)
        if (this.dataType() == DataType.UTF8) {
            for (long e = 0; e < this.length(); e++) {
                val mine = this.getString(e);
                val theirs = n.getString(e);

                if (!mine.equals(theirs)) {
                    return false;
                }
            }

            return true;
        }

        //epsilon equals
        if (isScalar() && n.isScalar()) {
            if (isZ()) {
                return getLong(0) == n.getLong(0);
            }
            if (isR()) {
                val mine = getDouble(0);
                val theirs = n.getDouble(0);

                if (Double.isNaN(mine) != Double.isNaN(theirs)) {
                    return false;
                }

                return Math.abs(mine - theirs) < eps;
            }
            if (isB()) {
                return getInt(0) == n.getInt(0);
            }
        } else if (isVector() && n.isVector()) {
            val vectorOp = new EqualsWithEps(this, n, eps);
            Nd4j.exec(vectorOp);
            val vectorDiff = vectorOp.z().getDouble(0);

            return vectorDiff < 0.5;
        }

        if (!Arrays.equals(this.shape(), n.shape())) {
            return false;
        }

        if (!Shape.shapeEquals(shape(), n.shape())) {
            return false;
        }

        if (slices() != n.slices()) {
            return false;
        }

        // The comparison is the same whether or not the two orderings match
        EqualsWithEps compareOp = new EqualsWithEps(this, n, eps);
        Nd4j.getExecutioner().exec(compareOp);
        double diff = compareOp.z().getDouble(0);

        return diff < 0.5;
    }

    @Override
    public boolean equalShapes(@NonNull INDArray other){
        if(isEmpty() != other.isEmpty())
            return false;
        if(rank() != other.rank())
            return false;
        for( int i=0; i= rank())
            throw new IllegalArgumentException("Invalid size: cannot get size of dimension " + dimension + " for rank "
                    + rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")");

        return jvmShapeInfo.shape[dimension];
    }

    /**
     * Returns the rank stored in this array's JVM-side shape information.
     */
    @Override
    public int rank() {
        final int storedRank = jvmShapeInfo.rank;
        return storedRank;
    }

    /**
     * Returns the number of elements: 0 for an empty array, otherwise the length stored in the
     * JVM-side shape information.
     */
    @Override
    public long length() {
        return isEmpty() ? 0 : jvmShapeInfo.length;
    }

    /**
     * Broadcasts this array into {@code result}.
     * <p>
     * Note the return value: if the shapes are already equal, this array is returned and {@code result}
     * is left untouched; if this array is a scalar, a newly created array of {@code result}'s shape is
     * returned instead of {@code result}. In all other cases {@code result} is filled and returned:
     * row vectors are written slice by slice, column vectors column by column, and anything else
     * through a {@code Tile} op (on a dup of this array if it is a view).
     *
     * @param result array whose shape is the broadcast target and which receives the values
     * @return the broadcast array (see above for which instance)
     * @throws IllegalArgumentException if the shapes are found to be incompatible
     */
    @Override
    public INDArray broadcast(INDArray result) {
        Nd4j.getCompressor().autoDecompress(this);

        val shape = result.shape();

        if (Shape.shapeEquals(shape, shape()))
            return this;

        // if we're on scalar, we can just create new array
        if (this.isScalar())
            return Nd4j.createUninitialized(this.dataType(), shape).assign(this.getDouble(0));

        // Compare trailing dimensions pairwise: each pair must match or one of them must be 1.
        // NOTE(review): the loop runs shape.length - 1 times, so the leading dimension of the target
        // is never compared - confirm whether that is intended.
        boolean compatible = true;
        int count = shape.length - 1;
        int thisCount = jvmShapeInfo.rank - 1;
        for (int i = shape.length - 1; i > 0; i--) {
            if (count < 0 || thisCount < 0)
                break;
            if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) {
                compatible = false;
                break;
            }

            count--;
            thisCount--;
        }

        if (!compatible)
            throw new IllegalArgumentException("Incompatible broadcast from " + Arrays.toString(shape()) + " to "
                    + Arrays.toString(shape));

        if (isRowVector()) {
            //number of times to repeat each value
            for (int i = 0; i < result.slices(); i++) {
                result.putSlice(i, this);
            }
        } else if (isColumnVector()) {
            for (int i = 0; i < result.columns(); i++) {
                result.putColumn(i, this);
            }
        }

        else {
            // Tile factors: repeat along size-1 dimensions and along dimensions this array does not have
            int[] repeat = new int[shape.length];
            for(int i = 0; i < shape.length; i++) {
                if(i < rank()) {
                    if(size(i) == 1)
                        repeat[i] = (int) shape[i];
                    else {
                        repeat[i] = 1;
                    }
                }

                else {
                    repeat[i] = (int) shape[i];
                }
            }

            if (this.isView()) {
                Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this.dup(this.ordering())},new INDArray[]{result},repeat));
            } else
                Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this},new INDArray[]{result},repeat));
        }
        return result;

    }

    /**
     * Broadcasts this array to {@code shape}: allocates an uninitialized array of that shape with this
     * array's data type and ordering, then delegates to {@code broadcast(INDArray)}.
     */
    @Override
    public INDArray broadcast(long... shape) {
        INDArray target = Nd4j.createUninitialized(this.dataType(), shape, this.ordering());
        return broadcast(target);
    }

    /**
     * Int-order overload: widens {@code newOrder} to {@code long[]} and delegates to the
     * long-based {@code dimShuffle}.
     */
    @Deprecated
    @Override
    public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) {
        final long[] widenedOrder = ArrayUtil.toLongArray(newOrder);
        return dimShuffle(rearrange, widenedOrder, broadCastable);
    }

    /**
     * Dimshuffle: an extension of permute that adds the ability
     * to broadcast various dimensions.
     * <p>
     * See theano for more examples.
     * This will only accept integers and xs.
     * <p>
* An x indicates a dimension should be broadcasted rather than permuted. * * @param rearrange the dimensions to swap to * @return the newly permuted array */ @Override public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) { Nd4j.getCompressor().autoDecompress(this); if (broadCastable.length != jvmShapeInfo.rank) throw new IllegalArgumentException( "The broadcastable dimensions must be the same length as the current shape"); boolean broadcast = false; Set set = new HashSet<>(); for (int i = 0; i < rearrange.length; i++) { set.add(rearrange[i]); if (rearrange[i] instanceof Integer) { Integer j = (Integer) rearrange[i]; if (j >= broadCastable.length) throw new IllegalArgumentException( "Illegal dimension, dimension must be < broadcastable.length (aka the real dimensions"); } else if (rearrange[i] instanceof Character) { Character c = (Character) rearrange[i]; if (c != 'x') throw new IllegalArgumentException("Illegal input: Must be x"); broadcast = true; } else throw new IllegalArgumentException("Only characters and integers allowed"); } //just do permute if (!broadcast) { int[] ret = new int[rearrange.length]; for (int i = 0; i < ret.length; i++) ret[i] = (Integer) rearrange[i]; return permute(ret); } else { List drop = new ArrayList<>(); for (int i = 0; i < broadCastable.length; i++) { if (!set.contains(i)) { if (broadCastable[i]) drop.add(i); else throw new IllegalArgumentException( "We can't drop the given dimension because its not broadcastable"); } } //list of dimensions to keep int[] shuffle = new int[broadCastable.length]; int count = 0; for (int i = 0; i < rearrange.length; i++) { if (rearrange[i] instanceof Integer) { shuffle[count++] = (Integer) rearrange[i]; } } List augment = new ArrayList<>(); for (int i = 0; i < rearrange.length; i++) { if (rearrange[i] instanceof Character) augment.add(i); } Integer[] augmentDims = augment.toArray(new Integer[1]); count = 0; int dropIdx = 0; int[] newShape = new int[shuffle.length + 
drop.size()]; for (int i = 0; i < newShape.length; i++) { if (i < shuffle.length) { newShape[count++] = shuffle[i]; } else newShape[count++] = drop.get(dropIdx++); } INDArray ret; //TODO is this correct? This was old behaviour before adding permute input check if(newShape.length == this.rank()){ ret = permute(newShape); } else { ret = dup(); } List newDims = new ArrayList<>(); long[] shape = Arrays.copyOfRange(ret.shape(), 0, shuffle.length); for (int i = 0; i < shape.length; i++) { newDims.add(shape[i]); } for (int i = 0; i < augmentDims.length; i++) { newDims.add(augmentDims[i], 1L); } long[] toReshape = ArrayUtil.toArrayLong(newDims); ret = ret.reshape(toReshape); return ret; } } @Override public INDArray permute(int... rearrange) { Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + " got arguments %s for rank %s array. Number of arguments must equal array rank", rearrange, rank()); Nd4j.getCompressor().autoDecompress(this); boolean alreadyInOrder = true; //IntBuffer shapeInfo = shapeInfo(); int rank = jvmShapeInfo.rank; for (int i = 0; i < rank; i++) { if (rearrange[i] != i) { alreadyInOrder = false; break; } } if (alreadyInOrder) return this; checkArrangeArray(rearrange); val newShape = doPermuteSwap(shape(), rearrange); val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); INDArray value = create(data(), newShape, newStride, offset(), newOrder); return value; } @Override public INDArray permutei(int... rearrange) { Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + " got arguments %s for rank %s array. 
Number of arguments must equal array rank", rearrange, rank()); boolean alreadyInOrder = true; val shapeInfo = shapeInfo(); int rank = jvmShapeInfo.rank; for (int i = 0; i < rank; i++) { if (rearrange[i] != i) { alreadyInOrder = false; break; } } if (alreadyInOrder) return this; checkArrangeArray(rearrange); val newShape = doPermuteSwap(shape(), rearrange); val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); val ews = shapeInfo.get(2 * rank + 2); val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); setShapeInformation(si); if (shapeInfo.get(2 * rank + 2) > 0) { //for the backend to work - no ews for permutei //^^ not true anymore? Not sure here. Marking this for raver setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, 0, newOrder, dataType(), isEmpty())); } //this.shape = null; //this.stride = null; return this; } @Deprecated protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { val ret = new long[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape.get(rearrange[i]); } return ret; } @Deprecated protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { int[] ret = new int[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape.get(rearrange[i]); } return ret; } @Deprecated protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { int[] ret = new int[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape.getInt(rearrange[i]); } return ret; } protected long[] doPermuteSwap(long[] shape, int[] rearrange) { val ret = new long[rearrange.length]; for (int i = 0; i < rearrange.length; i++) { ret[i] = shape[rearrange[i]]; } return ret; } protected void checkArrangeArray(int[] arr) { Preconditions.checkArgument(arr.length == jvmShapeInfo.rank, "Invalid rearrangement: number of arrangement (%s) != rank (%s)", 
arr.length, jvmShapeInfo.rank); for (int i = 0; i < arr.length; i++) { if (arr[i] >= arr.length) throw new IllegalArgumentException("The specified dimensions can't be swapped. Given element " + i + " was >= number of dimensions"); if (arr[i] < 0) throw new IllegalArgumentException("Invalid dimension: " + i + " : negative value"); } for (int i = 0; i < arr.length; i++) { for (int j = 0; j < arr.length; j++) { if (i != j && arr[i] == arr[j]) throw new IllegalArgumentException("Permute array must have unique elements"); } } } protected void autoProcessScalarCall() { /* if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC) OpProfiler.getInstance().processScalarCall();*/ } /** * Checks whether the matrix is a vector. */ @Override public boolean isVector() { if (jvmShapeInfo.rank == 1) return true; return isRowVector() || isColumnVector(); } @Override public boolean isVectorOrScalar() { return isVector() || isScalar(); } @Override public boolean isSquare() { return isMatrix() && rows() == columns(); } @Override public boolean isRowVector() { return (rank() == 2 && rows() == 1) && length() > 1 || rank() == 1 && length() > 1; } @Override public boolean isColumnVector() { return rank() == 2 && columns() == 1 && length() > 1; } @Override public boolean isColumnVectorOrScalar() { return isColumnVector() || isScalar(); } @Override public boolean isRowVectorOrScalar() { return isRowVector() || isScalar(); } /** * Generate string representation of the matrix. * Printing will switch to scientific notation on a per element basis * - when abs value is greater than or equal to 10000 * - when abs value is less than or equal to 0.0001 and not zero * * If the number of elements in the array is greater than 1000 (by default) only the first and last three elements * in a dimension are included. 
This can be changed globally using {@link NDArrayStrings#setMaxPrintElements(long)} * * */ @Override public String toString() { return toString(new NDArrayStrings()); } @Override public String toString(@NonNull NDArrayStrings options) { if(wasClosed()) return ""; if (!isCompressed() && !preventUnpack) return options.format(this); else if (isCompressed() && compressDebug) return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered."; else if (preventUnpack) return "Array string unpacking is disabled."; return options.format(this); } @Override public String toString(long maxElements, boolean forceSummarize, int precision){ return toString(new NDArrayStrings(maxElements, forceSummarize, precision)); } @Override public String toStringFull(){ return toString(Long.MAX_VALUE, false, -1 * dataType().precision()); } @Override public Object element() { if (!isScalar()) throw new IllegalStateException("Unable to retrieve element from non scalar matrix"); if (data.dataType() == DataType.FLOAT) return data.getFloat(0); return data.getDouble(0); } @Override public INDArray remainder(INDArray denominator) { if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { return remainder(denominator, Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(), denominator.shape()))); } else return remainder(denominator, this.ulike()); } @Override public INDArray remainder(INDArray denominator, INDArray result) { validateNumericalArray("remainder", false); Preconditions.checkArgument(Shape.areShapesBroadcastable(this.shape(), denominator.shape()),"Shapes must be broadcastable"); val op = new RemainderOp(this, denominator, result); Nd4j.getExecutioner().exec(op); return result; } @Override public INDArray remainder(Number denominator) { return remainder(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); } @Override public INDArray remainder(Number denominator, INDArray 
result) { validateNumericalArray("remainder", false); ScalarRemainder op = new ScalarRemainder(this, null, result, denominator); Nd4j.getExecutioner().exec(op); return result; } @Override public INDArray remainderi(INDArray denominator) { validateNumericalArray("remainderi", false); RemainderOp op = new RemainderOp(this, denominator, this); Nd4j.getExecutioner().exec(op); return this; } @Override public INDArray remainderi(Number denominator) { validateNumericalArray("remainderi", false); ScalarRemainder op = new ScalarRemainder(this, null, this, denominator); Nd4j.getExecutioner().exec(op); return this; } @Override public INDArray fmod(INDArray denominator) { validateNumericalArray("fmod", false); if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { return fmod(denominator, Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), Shape.broadcastOutputShape(this.shape(), denominator.shape()))); } else return fmod(denominator, this.ulike()); } @Override public INDArray fmod(INDArray denominator, INDArray result) { validateNumericalArray("fmod", false); if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { val outShape = Shape.broadcastOutputShape(this.shape(), denominator.shape()); Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); Nd4j.exec(new FloorModOp(new INDArray[]{this, denominator}, new INDArray[]{result})); return result; } else { FModOp op = new FModOp(this, denominator, result); Nd4j.getExecutioner().exec(op); return result; } } @Override public INDArray fmod(Number denominator) { return fmod(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); } @Override public INDArray fmod(Number denominator, INDArray result) { validateNumericalArray("fmod", false); ScalarFMod op = new ScalarFMod(this, null, result, denominator); Nd4j.getExecutioner().exec(op); return result; } @Override public INDArray 
fmodi(INDArray denominator) { validateNumericalArray("fmodi", false); FModOp op = new FModOp(this, denominator, this); Nd4j.getExecutioner().exec(op); return this; } @Override public INDArray fmodi(Number denominator) { validateNumericalArray("fmodi", false); ScalarFMod op = new ScalarFMod(this, null, this, denominator); Nd4j.getExecutioner().exec(op); return this; } @Override public Iterator iterator() { return new FirstAxisIterator(this); } @Override public long originalOffset() { if (data().originalOffset() >= Integer.MAX_VALUE) throw new IllegalArgumentException("Original offset of buffer can not be >= Integer.MAX_VALUE"); return data().originalOffset(); } private void readObject(ObjectInputStream s) { try { s.defaultReadObject(); read(s); } catch (Exception e) { throw new RuntimeException(e); } } private void writeObject(ObjectOutputStream out) throws IOException { out.defaultWriteObject(); write(out); } //Custom serialization for Java serialization protected void write(ObjectOutputStream out) throws IOException { if (this.isView()) { //As per Nd4j.write, duplicate before writing to the output stream //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here //Furthermore, because we only want to save the *actual* data for a view (not the full data), the shape info // (mainly strides, offset, element-wise stride) may be different in the duped array vs. 
the view array INDArray copy = this.dup(); copy.shapeInfoDataBuffer().write(out); copy.data().write(out); } else { shapeInformation.write(out); data().write(out); } } //Custom deserialization for Java serialization protected void read(ObjectInputStream s) { val headerShape = BaseDataBuffer.readHeader(s); shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())]); shapeInformation.read(s, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getRight()); setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); val headerData = BaseDataBuffer.readHeader(s); data = Nd4j.createBuffer(headerData.getRight(), headerData.getMiddle(), false); data().read(s, headerData.getLeft(), headerData.getMiddle(), headerData.getRight()); } @Override public INDArray argMax(int... dimension) { return Nd4j.argMax(this, dimension); } @Override public boolean isAttached() { if (isEmpty()) return false; Preconditions.checkArgument(!(data == null && !isEmpty()), "Array has no buffer!"); return data.isAttached() || (data.underlyingDataBuffer() != null && data.underlyingDataBuffer().isAttached()) || (data.originalDataBuffer() != null && data.originalDataBuffer().isAttached()); } @Override public boolean isInScope() { if (!isAttached()) return true; return data.isInScope(); } @Override public INDArray detach() { if (!isAttached()) return this; WorkspaceUtils.assertValidArray(this, "Cannot detach INDArray"); Nd4j.getExecutioner().commit(); /* two options here 1) we're within some workspace 2) we're out of any workspace */ if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { if (!isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { INDArray copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); copy.assign(this); 
Nd4j.getExecutioner().commit(); return copy; } } else { MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); Nd4j.getMemoryManager().setCurrentWorkspace(null); INDArray copy = null; if (!isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); //this.dup(this.ordering()); } else { copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); copy.assign(this); Nd4j.getExecutioner().commit(); } Nd4j.getMemoryManager().setCurrentWorkspace(workspace); return copy; } } @Override public INDArray leverage() { WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); if (!isAttached()) return this; MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); if (workspace == null) { return this.detach(); } MemoryWorkspace parentWorkspace = workspace.getParentWorkspace(); if (this.data.getParentWorkspace() == parentWorkspace) return this; // if there's no parent ws - just detach if (parentWorkspace == null) return this.detach(); else { Nd4j.getExecutioner().commit(); // temporary set parent ws as current ws Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace); INDArray copy = null; if (!this.isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { copy = this.dup(this.ordering()); Nd4j.getExecutioner().commit(); } // restore current ws Nd4j.getMemoryManager().setCurrentWorkspace(workspace); return copy; } } @Override public INDArray leverageTo(String id) { return leverageTo(id, false); } 
@Override public INDArray leverageTo(String id, boolean enforceExistence) throws Nd4jNoSuchWorkspaceException { WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); if (!isAttached()) return this; if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { if(enforceExistence){ throw new Nd4jNoSuchWorkspaceException(id); } else { return this; } } MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id); if (this.data.getParentWorkspace() == target) return this; Nd4j.getMemoryManager().setCurrentWorkspace(target); INDArray copy = null; if (!this.isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { copy = this.dup(this.ordering()); Nd4j.getExecutioner().commit(); } Nd4j.getMemoryManager().setCurrentWorkspace(current); return copy; } public INDArray leverageOrDetach(String id){ if(!isAttached()){ return this; } if(!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id)){ return detach(); } return leverageTo(id); } @Override public INDArray migrate() { return migrate(false); } @Override public INDArray migrate(boolean detachOnNoWs){ WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); if (current == null) { if(detachOnNoWs){ return detach(); } else { return this; } } INDArray copy = null; if (!this.isView()) { Nd4j.getExecutioner().commit(); DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); Nd4j.getMemoryManager().memcpy(buffer, this.data()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); } else { copy = this.dup(this.ordering()); Nd4j.getExecutioner().commit(); } 
return copy; } @Override public Number percentileNumber(Number quantile) { validateNumericalArray("percentileNumber", false); if (quantile.intValue() < 0 || quantile.intValue() > 100) throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); if (isScalar()) return this.getDouble(0); INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true); return getPercentile(quantile, sorted); } @Override public Number medianNumber() { validateNumericalArray("medianNumber", false); if(isScalar()) return getNumber(0); return percentileNumber(50); } @Override public INDArray median(int... dimension) { validateNumericalArray("median", false); //Check edge case: size 1 element. No dimension == full array if(dimension.length == 0){ return Nd4j.scalar(dataType(), medianNumber().doubleValue()); } long shapeProd = 1; for (int d : dimension) { shapeProd *= size(d); } if (shapeProd == 1) { long[] newShape = ArrayUtil.removeIndex(shape(), dimension); return dup('c').reshape('c', newShape); } return percentile(50, dimension); } protected double getPercentile(Number quantile, INDArray sorted) { validateNumericalArray("getPercentile", false); if (quantile.intValue() == 0) return sorted.getDouble(0); else if (quantile.intValue() == 100) return sorted.getDouble(sorted.length() - 1); double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); if (pos < 1) return sorted.getDouble(0); else if (pos >= sorted.length()) return sorted.getDouble(sorted.length() - 1); double fposition = FastMath.floor(pos); int position = (int)fposition; double diff = pos - fposition; double lower = sorted.getDouble(position-1); double upper = sorted.getDouble(position); return lower + diff * (upper - lower); } @Override public INDArray percentile(Number quantile, int... 
dimension) { validateNumericalArray("percentile", false); if (quantile.doubleValue() < 0 || quantile.doubleValue() > 100) throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); if (isScalar()) return Nd4j.scalar(this.getDouble(0)); INDArray sorted = Nd4j.getNDArrayFactory().sort(this.dup(this.ordering()), false, dimension); // there's no practical sense doing this on GPU, stride will be just size of TAD. INDArray ret = Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), sorted.tensorsAlongDimension(dimension)); for (int i = 0; i < ret.length(); i++) { ret.putScalar(i, getPercentile(quantile, sorted.tensorAlongDimension(i, dimension))); } return ret; } protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer); @Override public int toFlatArray(FlatBufferBuilder builder) { if(isView()){ return dup(this.ordering()).toFlatArray(builder); } int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong()); int buffer = this.isEmpty() ? 0 : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) : FlatArray.createBufferVector(builder, this.data().asBytes()); val type = this.isEmpty() ? 
FlatBuffersMapper.getDataTypeAsByte(this.dataType()) : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType()); int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE); return array; } protected static DataTypeEx convertType(DataType type) { if (type == DataType.HALF) { return DataTypeEx.FLOAT16; } else if (type == DataType.FLOAT) { return DataTypeEx.FLOAT; } else if (type == DataType.DOUBLE) { return DataTypeEx.DOUBLE; } else if(type == DataType.INT) { return DataTypeEx.INT8; } else if(type == DataType.LONG) { return DataTypeEx.INT16; } else throw new IllegalStateException("Unknown dataType: [" + type + "]"); } @Override public boolean isEmpty() { return Shape.isEmpty(jvmShapeInfo.javaShapeInformation); } @Override public long[] shapeInfoJava() { return jvmShapeInfo.javaShapeInformation; } @Override public DataType dataType() { if (data != null) return data.dataType(); val e = Shape.extras(jvmShapeInfo.javaShapeInformation); if (e != 0) { val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); if (t != DataType.UNKNOWN) return t; } return DataType.UNKNOWN; } @Override public boolean isR() { val dtype = dataType(); return dtype == DataType.FLOAT || dtype == DataType.DOUBLE || dtype == DataType.HALF || dtype == DataType.BFLOAT16; } @Override public boolean isZ() { return !isR() && !isB() && !isS(); } @Override public boolean isB() { return dataType() == DataType.BOOL; } @Override public boolean isS() { return dataType() == DataType.UTF8; } @Override public INDArray castTo(DataType dataType) { if(dataType == dataType()) //No-op if correct datatype return this; if(isEmpty() && rank() == 0){ return Nd4j.empty(dataType); } val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); result.assign(this); return result; } @Override public boolean all() { val r = Nd4j.getExecutioner().exec(new All(this)); return r.getDouble(0) != 0.0; } @Override public boolean any() { val r = Nd4j.getExecutioner().exec(new 
Any(this)); return r.getDouble(0) != 0.0; } @Override public boolean none() { return !any(); } /** * Validate that the operation is being applied on a numerical array (not boolean or utf8). * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays * @param opName Operation name to print in the exception */ protected void validateNumericalArray(String opName, boolean allowEmpty){ if(dataType() == DataType.BOOL || dataType() == DataType.UTF8) throw new IllegalStateException("Cannot apply operation " + opName + " to array with " + dataType() + " datatype. Array shape: " + Arrays.toString(shape())); if(!allowEmpty && isEmpty()) throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType()); } @Override public boolean closeable() { if (released || isAttached()) return false; // empty arrays have no buffer at all if (isEmpty()) return true; if (isView()) return false; return data.closeable(); } @Override public void close() { // empty arrays have no buffer at all if (released || isEmpty()) return; Nd4j.getExecutioner().commit(); if (!closeable()) throw new ND4JIllegalStateException("Can't release this INDArray"); data.close(); released = true; } @Override public INDArray like() { return Nd4j.create(this.dataType(), this.shape(), Nd4j.getStrides(this.shape(), this.ordering()), this.ordering()); } @Override public INDArray ulike() { return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); } @Override public boolean wasClosed() { // data can be null if that's empty array if (released || (data() != null && data().wasClosed())) return true; return false; } @Override public long getId(){ return arrayId; } }