// ai.djl.ndarray.NDArray Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;
import ai.djl.Device;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.internal.NDFormat;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.Float16Utils;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
/**
* An interface representing an n-dimensional array.
*
* <p>NDArray is the core data structure for all mathematical computations. An NDArray represents a
* multidimensional, fixed-size homogeneous array. It has very similar behaviour to the NumPy Python
* package with the addition of efficient computing. To understand how to manage the NDArray
* lifecycle, please refer to the NDArray Memory Management Guide.
*/
public interface NDArray extends NDResource, BytesSupplier {
/**
 * Rebuilds an {@code NDArray} from its encoded byte form.
 *
 * <p>The actual decoding is delegated to the supplied {@link NDManager}, which also owns the
 * resulting array.
 *
 * @param manager {@link NDManager} used to create this {@code NDArray}
 * @param byteArray data used to decode
 * @return decoded {@code NDArray}
 */
static NDArray decode(NDManager manager, byte[] byteArray) {
    NDArray decoded = manager.decode(byteArray);
    return decoded;
}
/**
 * Returns the name of this {@code NDArray}.
 *
 * @return the name of this {@code NDArray}, as previously assigned via {@link #setName(String)}
 */
String getName();
/**
 * Sets the name of this {@code NDArray}.
 *
 * @param name the name of this {@code NDArray}
 */
void setName(String name);
/**
 * Returns the unique identifier of this {@code NDArray}.
 *
 * @return the unique identifier of this {@code NDArray}
 */
String getUid();
/**
 * Returns the {@link DataType} of this {@code NDArray}.
 *
 * <p>{@link DataType} is a definition of the precision level of the {@code NDArray}. All values
 * inside the same {@code NDArray} have the same {@link DataType}.
 *
 * @return the {@link DataType} of this {@code NDArray}
 */
DataType getDataType();
/**
 * Returns the {@link Device} of this {@code NDArray}.
 *
 * <p>The {@link Device} class contains the information about where this {@code NDArray} is
 * stored in memory, like CPU/GPU.
 *
 * @return the {@link Device} of this {@code NDArray}
 */
Device getDevice();
/**
 * Returns the {@link Shape} of this {@code NDArray}.
 *
 * <p>{@link Shape} defines how this {@code NDArray} is represented multi-dimensionally.
 *
 * @return the {@link Shape} of this {@code NDArray}
 */
Shape getShape();
/**
 * Returns the {@link SparseFormat} of this {@code NDArray}.
 *
 * <p>Dense arrays report {@link SparseFormat#DENSE}; see {@link #isSparse()}.
 *
 * @return the {@link SparseFormat} of this {@code NDArray}
 */
SparseFormat getSparseFormat();
/**
 * Tells whether this {@code NDArray} uses a sparse storage layout.
 *
 * <p>Every {@link SparseFormat} other than {@link SparseFormat#DENSE} counts as sparse.
 *
 * @return {@code true} if this {@code NDArray} is a {@link SparseNDArray}
 */
default boolean isSparse() {
    SparseFormat format = getSparseFormat();
    boolean dense = format == SparseFormat.DENSE;
    return !dense;
}
/**
 * Tells whether this {@code NDArray} is a scalar, i.e. whether its {@link Shape} is empty.
 *
 * <p>The decision is delegated to {@link Shape#isScalar()}.
 *
 * @return {@code true} if this {@code NDArray} is a scalar {@code NDArray} with empty {@link
 *     Shape}
 */
default boolean isScalar() {
    Shape shape = getShape();
    return shape.isScalar();
}
/**
 * Serializes this {@code NDArray} into a byte array.
 *
 * <p>The encoding is produced by {@code NDSerializer}; use {@link #decode(NDManager, byte[])} to
 * read it back.
 *
 * @return byte array
 */
default byte[] encode() {
    byte[] encoded = NDSerializer.encode(this);
    return encoded;
}
/**
 * Moves this {@code NDArray} to a different {@link Device}.
 *
 * @param device the {@link Device} to be set
 * @param copy set {@code true} if you want to return a copy of the existing {@code NDArray}
 * @return the result {@code NDArray} with the new {@link Device}
 */
NDArray toDevice(Device device, boolean copy);
/**
 * Converts this {@code NDArray} to a different {@link DataType}.
 *
 * @param dataType the {@link DataType} to be set
 * @param copy set {@code true} if you want to return a copy of the existing {@code NDArray}
 * @return the result {@code NDArray} with the new {@link DataType}
 */
NDArray toType(DataType dataType, boolean copy);
/**
 * Attaches a gradient {@code NDArray} to this {@code NDArray} and marks it so {@link
 * ai.djl.training.GradientCollector#backward(NDArray)} can compute the gradient with respect to
 * it.
 *
 * @param requiresGrad whether this {@code NDArray} requires a gradient
 */
void setRequiresGradient(boolean requiresGrad);
/**
 * Returns the gradient {@code NDArray} attached to this {@code NDArray}.
 *
 * <p>A gradient is attached via {@link #setRequiresGradient(boolean)}.
 *
 * @return the gradient {@code NDArray}
 * @throws NullPointerException when gradient is not initialized
 */
NDArray getGradient();
/**
 * Returns {@code true} if the gradient calculation is required for this {@code NDArray}.
 *
 * @return {@code true} if the gradient calculation is required for this {@code NDArray}, else
 *     {@code false}
 */
boolean hasGradient();
/**
 * Returns an NDArray equal to this one that stops gradient propagation through it.
 *
 * @return an NDArray equal to this that stops gradient propagation through it
 */
NDArray stopGradient();
/**
 * Returns an NDArray equal to this one whose incoming gradient is multiplied by a constant.
 *
 * <p>The result is computed as {@code scale * x + stopGradient(x) * (1 - scale)}. The forward
 * value is therefore {@code x}, while only the first term lets gradient flow back, so that
 * gradient is scaled by {@code scale}.
 *
 * @param scale how much to magnify the gradient propagated to this
 * @return an NDArray equal to this that magnifies the gradient propagated to this by a constant
 */
default NDArray scaleGradient(double scale) {
    NDArray scaled = mul(scale);
    NDArray detachedRemainder = stopGradient().mul(1 - scale);
    return scaled.add(detachedRemainder);
}
/**
 * Returns the number of entries along one axis of this {@code NDArray}.
 *
 * <p>The value is read from this array's {@link Shape}.
 *
 * @param axis the axis to return the size for
 * @return the size of this {@code NDArray} along a given axis
 */
default long size(int axis) {
    Shape shape = getShape();
    return shape.size(axis);
}
/**
 * Returns the total element count of this {@code NDArray}, as reported by its {@link Shape}.
 *
 * @return the number of elements in this {@code NDArray}
 */
default long size() {
    Shape shape = getShape();
    return shape.size();
}
/**
 * {@inheritDoc}
 *
 * <p>Equivalent to {@code toByteBuffer(false)}: a direct buffer is not requested.
 */
@Override
default ByteBuffer toByteBuffer() {
    final boolean tryDirect = false;
    return toByteBuffer(tryDirect);
}
/**
 * Returns the {@code ByteBuffer} presentation of the object.
 *
 * <p>If the returned ByteBuffer is a DirectByteBuffer, it shares the same native memory as the
 * NDArray. The native memory will be deleted when the NDArray is closed.
 *
 * <p>Not all engines support returning a DirectByteBuffer.
 *
 * @param tryDirect use DirectBuffer if possible
 * @return the {@code ByteBuffer} presentation of the object
 */
ByteBuffer toByteBuffer(boolean tryDirect);
/**
 * Copies the contents of this {@code NDArray} into a new double array.
 *
 * <p>Only {@link DataType#FLOAT64} arrays are accepted; no numeric conversion is performed.
 *
 * @return a double array
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
 */
default double[] toDoubleArray() {
    DataType actual = getDataType();
    if (actual == DataType.FLOAT64) {
        DoubleBuffer source = toByteBuffer(true).asDoubleBuffer();
        double[] values = new double[source.remaining()];
        source.get(values);
        return values;
    }
    throw new IllegalStateException(
            "DataType mismatch, Required double" + " Actual " + actual);
}
/**
 * Copies the contents of this {@code NDArray} into a new float array.
 *
 * <p>FLOAT32 data is copied as-is. FLOAT16 data is decoded into float values by {@link
 * Float16Utils}. Any other {@link DataType} is rejected.
 *
 * @return a float array
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
 */
default float[] toFloatArray() {
    DataType type = getDataType();
    if (type == DataType.FLOAT32) {
        FloatBuffer source = toByteBuffer(true).asFloatBuffer();
        float[] values = new float[source.remaining()];
        source.get(values);
        return values;
    }
    if (type == DataType.FLOAT16) {
        return Float16Utils.fromByteBuffer(toByteBuffer());
    }
    throw new IllegalStateException("DataType mismatch, Required float, Actual " + type);
}
/**
 * Converts this {@code NDArray} to a short array.
 *
 * @return a short array
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} is not {@link
 *     DataType#INT16}
 */
default short[] toShortArray() {
    DataType dType = getDataType();
    if (dType != DataType.INT16) {
        // report the type actually required (short), not int
        throw new IllegalStateException("DataType mismatch, Required short, Actual " + dType);
    }
    ShortBuffer sb = toByteBuffer(true).asShortBuffer();
    short[] ret = new short[sb.remaining()];
    sb.get(ret);
    return ret;
}
/**
 * Converts this {@code NDArray} of unsigned 16-bit values to an int array.
 *
 * <p>Each element is zero-extended, so the results lie in the range {@code [0, 65535]}.
 *
 * @return an int array holding the unsigned short values
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} is not {@link
 *     DataType#UINT16}
 */
default int[] toUnsignedShortArray() {
    DataType dType = getDataType();
    if (dType != DataType.UINT16) {
        throw new IllegalStateException(
                "DataType mismatch, Required unsigned short, Actual " + dType);
    }
    ShortBuffer sb = toByteBuffer(true).asShortBuffer();
    int[] ret = new int[sb.remaining()];
    for (int i = 0; i < ret.length; ++i) {
        // Java shorts are signed; widen without sign extension
        ret[i] = Short.toUnsignedInt(sb.get());
    }
    return ret;
}
/**
 * Copies the contents of this {@code NDArray} into a new int array.
 *
 * <p>Both INT32 and UINT32 arrays are accepted. UINT32 elements are returned with their raw 32
 * bits, so values above {@link Integer#MAX_VALUE} appear negative; use {@link
 * #toUnsignedIntArray()} to get them zero-extended.
 *
 * @return an int array
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
 */
default int[] toIntArray() {
    DataType dType = getDataType();
    boolean is32Bit = dType == DataType.INT32 || dType == DataType.UINT32;
    if (!is32Bit) {
        throw new IllegalStateException(
                "DataType mismatch, Required int" + " Actual " + dType);
    }
    IntBuffer source = toByteBuffer(true).asIntBuffer();
    int[] values = new int[source.remaining()];
    source.get(values);
    return values;
}
/**
 * Converts this {@code NDArray} of unsigned 32-bit values to a long array.
 *
 * <p>Each element is zero-extended, so the results lie in the range {@code [0, 4294967295]}.
 *
 * @return a long array holding the unsigned int values
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} is not {@link
 *     DataType#UINT32}
 */
default long[] toUnsignedIntArray() {
    DataType dType = getDataType();
    if (dType != DataType.UINT32) {
        throw new IllegalStateException(
                "DataType mismatch, Required unsigned int, Actual " + dType);
    }
    IntBuffer ib = toByteBuffer(true).asIntBuffer();
    long[] ret = new long[ib.remaining()];
    for (int i = 0; i < ret.length; ++i) {
        // Java ints are signed; widen without sign extension
        ret[i] = Integer.toUnsignedLong(ib.get());
    }
    return ret;
}
/**
 * Converts this {@code NDArray} to a long array.
 *
 * <p>Both INT64 and UINT64 arrays are accepted, mirroring how {@link #toIntArray()} accepts
 * UINT32. UINT64 elements are returned with their raw 64 bits, so values above {@link
 * Long#MAX_VALUE} appear negative; use e.g. {@link Long#toUnsignedString(long)} or {@link
 * Long#compareUnsigned(long, long)} to treat them as unsigned.
 *
 * @return a long array
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} is neither {@link
 *     DataType#INT64} nor {@link DataType#UINT64}
 */
default long[] toLongArray() {
    DataType dType = getDataType();
    if (dType != DataType.INT64 && dType != DataType.UINT64) {
        throw new IllegalStateException(
                "DataType mismatch, Required long" + " Actual " + dType);
    }
    LongBuffer lb = toByteBuffer(true).asLongBuffer();
    long[] ret = new long[lb.remaining()];
    lb.get(ret);
    return ret;
}
/**
 * Returns the raw bytes of this {@code NDArray}.
 *
 * <p>No {@link DataType} check is performed; the bytes are returned whatever the element type.
 * When the buffer is array-backed and spans its entire backing array, that backing array is
 * returned directly instead of a copy.
 *
 * @return a byte array
 */
default byte[] toByteArray() {
    ByteBuffer source = toByteBuffer(true);
    boolean wholeBackingArray =
            source.hasArray() && source.remaining() == source.array().length;
    if (!wholeBackingArray) {
        byte[] copy = new byte[source.remaining()];
        source.get(copy);
        return copy;
    }
    return source.array();
}
/**
 * Reads every byte of this {@code NDArray} as an unsigned 8-bit value.
 *
 * <p>No {@link DataType} check is performed; each byte is zero-extended to an int in {@code [0,
 * 255]}.
 *
 * @return a uint8 array
 */
default int[] toUint8Array() {
    ByteBuffer source = toByteBuffer(true);
    int[] unsigned = new int[source.remaining()];
    int i = 0;
    while (source.hasRemaining()) {
        unsigned[i++] = Byte.toUnsignedInt(source.get());
    }
    return unsigned;
}
/**
 * Copies the contents of this boolean {@code NDArray} into a new boolean array.
 *
 * <p>Each element occupies one byte; any non-zero byte is read as {@code true}.
 *
 * @return a boolean array
 * @throws IllegalStateException when {@link DataType} of this {@code NDArray} mismatches
 */
default boolean[] toBooleanArray() {
    DataType dType = getDataType();
    if (dType != DataType.BOOLEAN) {
        throw new IllegalStateException(
                "DataType mismatch, Required boolean" + " Actual " + dType);
    }
    ByteBuffer source = toByteBuffer(true);
    int count = source.remaining();
    boolean[] flags = new boolean[count];
    for (int idx = 0; idx < count; idx++) {
        flags[idx] = source.get() != 0;
    }
    return flags;
}
/**
 * Converts this {@code NDArray} to a String array, decoding the text as UTF-8.
 *
 * <p>This method is only applicable to String-typed NDArrays and is not intended for printing.
 *
 * @return Array of Strings
 */
default String[] toStringArray() {
    Charset utf8 = StandardCharsets.UTF_8;
    return toStringArray(utf8);
}
/**
 * Converts this {@code NDArray} to a String array with the specified charset.
 *
 * <p>This method is only applicable to String-typed NDArrays and is not intended for printing.
 *
 * @param charset the charset used to decode the strings
 * @return Array of Strings
 */
String[] toStringArray(Charset charset);
/**
 * Converts this {@code NDArray} to a Number array based on its {@link DataType}.
 *
 * <p>The boxed element type follows the data type: {@code Float} for FLOAT16/FLOAT32, {@code
 * Double} for FLOAT64, {@code Short} for INT16, {@code Integer} for UINT16/INT32/UINT8, {@code
 * Long} for UINT32/INT64/UINT64 and {@code Byte} for BOOLEAN/INT8. UINT16, UINT32 and UINT8
 * values are zero-extended. UINT64 values are returned with their raw 64 bits, so values above
 * {@link Long#MAX_VALUE} appear negative.
 *
 * @return a Number array
 * @throws IllegalStateException if the {@link DataType} is not supported
 */
@SuppressWarnings("PMD.AvoidArrayLoops")
default Number[] toArray() {
    switch (getDataType()) {
        case FLOAT16:
        case FLOAT32:
            float[] floatArray = toFloatArray();
            return IntStream.range(0, floatArray.length)
                    .mapToObj(i -> floatArray[i])
                    .toArray(Number[]::new);
        case FLOAT64:
            return Arrays.stream(toDoubleArray()).boxed().toArray(Double[]::new);
        case INT16:
            short[] buf = toShortArray();
            Short[] sbuf = new Short[buf.length];
            for (int i = 0; i < buf.length; ++i) {
                sbuf[i] = buf[i];
            }
            return sbuf;
        case UINT16:
            return Arrays.stream(toUnsignedShortArray()).boxed().toArray(Integer[]::new);
        case INT32:
            return Arrays.stream(toIntArray()).boxed().toArray(Integer[]::new);
        case UINT32:
            return Arrays.stream(toUnsignedIntArray()).boxed().toArray(Long[]::new);
        case INT64:
            return Arrays.stream(toLongArray()).boxed().toArray(Long[]::new);
        case UINT64:
            // Read the raw buffer directly instead of calling toLongArray(), so this case does
            // not depend on that method's data type check accepting UINT64.
            LongBuffer lb = toByteBuffer(true).asLongBuffer();
            Long[] ulongs = new Long[lb.remaining()];
            for (int i = 0; i < ulongs.length; ++i) {
                ulongs[i] = lb.get();
            }
            return ulongs;
        case BOOLEAN:
        case INT8:
            ByteBuffer bb = toByteBuffer();
            Byte[] ret = new Byte[bb.remaining()];
            for (int i = 0; i < ret.length; ++i) {
                ret[i] = bb.get();
            }
            return ret;
        case UINT8:
            return Arrays.stream(toUint8Array()).boxed().toArray(Integer[]::new);
        default:
            throw new IllegalStateException("Unsupported DataType: " + getDataType());
    }
}
/**
 * Sets this {@code NDArray} value from a {@link Buffer}.
 *
 * <p>The typed-array {@code set} overloads in this interface wrap their input in the matching
 * {@link Buffer} subtype and delegate here.
 *
 * @param buffer the input buffered data
 */
void set(Buffer buffer);
/**
 * Overwrites this {@code NDArray} with the given floats.
 *
 * <p>The array is wrapped (not copied) in a {@link FloatBuffer} and passed to {@link
 * #set(Buffer)}.
 *
 * @param data the array of floats to set
 */
default void set(float[] data) {
    FloatBuffer wrapped = FloatBuffer.wrap(data);
    set(wrapped);
}
/**
 * Overwrites this {@code NDArray} with the given ints.
 *
 * <p>The array is wrapped (not copied) in an {@link IntBuffer} and passed to {@link
 * #set(Buffer)}.
 *
 * @param data the array of integers to set
 */
default void set(int[] data) {
    IntBuffer wrapped = IntBuffer.wrap(data);
    set(wrapped);
}
/**
 * Overwrites this {@code NDArray} with the given doubles.
 *
 * <p>The array is wrapped (not copied) in a {@link DoubleBuffer} and passed to {@link
 * #set(Buffer)}.
 *
 * @param data the array of doubles to set
 */
default void set(double[] data) {
    DoubleBuffer wrapped = DoubleBuffer.wrap(data);
    set(wrapped);
}
/**
 * Overwrites this {@code NDArray} with the given longs.
 *
 * <p>The array is wrapped (not copied) in a {@link LongBuffer} and passed to {@link
 * #set(Buffer)}.
 *
 * @param data the array of longs to set
 */
default void set(long[] data) {
    LongBuffer wrapped = LongBuffer.wrap(data);
    set(wrapped);
}
/**
 * Overwrites this {@code NDArray} with the given bytes.
 *
 * <p>The array is wrapped (not copied) in a {@link ByteBuffer} and passed to {@link
 * #set(Buffer)}.
 *
 * @param data the array of bytes to set
 */
default void set(byte[] data) {
    ByteBuffer wrapped = ByteBuffer.wrap(data);
    set(wrapped);
}
/**
 * Writes {@code value} into the region of this {@code NDArray} selected by {@code index}.
 *
 * <p>The work is delegated to the internal indexer obtained for this array's manager.
 *
 * @param index the locations to update
 * @param value the value to replace with. Can broadcast if given smaller dimensions than the
 *     index
 */
default void set(NDIndex index, NDArray value) {
    this.getNDArrayInternal()
            .getIndexer(this.getManager())
            .set(this, index, value);
}
/**
 * Fills the region of this {@code NDArray} selected by {@code index} with a single number.
 *
 * <p>The work is delegated to the internal indexer obtained for this array's manager.
 *
 * @param index the locations to update
 * @param value the value to replace with
 */
default void set(NDIndex index, Number value) {
    this.getNDArrayInternal()
            .getIndexer(this.getManager())
            .set(this, index, value);
}
/**
 * Sets the specified index by applying a function to the values currently stored there.
 *
 * <p>The function receives the partial {@code NDArray} returned by {@link #get(NDIndex)}, and
 * its result is written back with {@link #set(NDIndex, NDArray)}.
 *
 * @param index the locations to update
 * @param function the function to change the value
 */
default void set(NDIndex index, Function<NDArray, NDArray> function) {
    NDArray array = get(index);
    // With the type arguments in place, apply() yields an NDArray, so this resolves to
    // set(NDIndex, NDArray); with a raw Function it would return Object and match no overload.
    set(index, function.apply(array));
}
/**
 * Writes a single number at the positions selected by a boolean mask or integer index array.
 *
 * <p>The array is substituted into the {@code "{}"} placeholder of an {@link NDIndex}.
 *
 * @param index the boolean or integer {@code NDArray} that indicates what to get
 * @param value the value to replace with
 */
default void set(NDArray index, Number value) {
    NDIndex selection = new NDIndex("{}", index);
    set(selection, value);
}
/**
 * Writes {@code value} into the single element addressed by {@code index}.
 *
 * <p>The work is delegated to the internal indexer obtained for this array's manager.
 *
 * @param index the single index to update
 * @param value the value to replace with
 * @throws IllegalArgumentException thrown if the index does not correspond to a single element
 */
default void setScalar(NDIndex index, Number value) {
    this.getNDArrayInternal()
            .getIndexer(this.getManager())
            .setScalar(this, index, value);
}
/**
 * Returns the section of this {@code NDArray} selected by {@code index}.
 *
 * <p>The result is created by this array's own manager.
 *
 * @param index the section of this {@code NDArray} to return
 * @return the partial {@code NDArray}
 */
default NDArray get(NDIndex index) {
    return this.get(this.getManager(), index);
}
/**
 * Returns the section of this {@code NDArray} selected by {@code index}, created by the given
 * manager.
 *
 * @param manager the manager used to create the arrays
 * @param index the section of this {@code NDArray} to return
 * @return the partial {@code NDArray}
 */
default NDArray get(NDManager manager, NDIndex index) {
    return this.getNDArrayInternal()
            .getIndexer(manager)
            .get(this, index);
}
/**
 * Returns the elements selected by a boolean mask or integer index array.
 *
 * <p>The array is substituted into the {@code "{}"} placeholder of an {@link NDIndex}.
 *
 * @param index the boolean or integer {@code NDArray} that indicates what to get
 * @return the partial {@code NDArray}
 */
default NDArray get(NDArray index) {
    NDIndex selection = new NDIndex("{}", index);
    return get(selection);
}
/**
 * Returns a partial {@code NDArray} described by an index string.
 *
 * @param indices the indices used to indicate what to get
 * @param args arguments to replace the variable "{}" in the indices string. Can be an integer,
 *     long, boolean {@link NDArray}, or integer {@link NDArray}.
 * @return the partial {@code NDArray}
 * @see NDIndex#NDIndex(String, Object...)
 */
default NDArray get(String indices, Object... args) {
    NDIndex parsed = new NDIndex(indices, args);
    return get(parsed);
}
/**
 * Returns a partial {@code NDArray} addressed by one index per leading dimension.
 *
 * @param indices the indices with each index corresponding to the dimensions and negative
 *     indices starting from the end
 * @return the partial {@code NDArray}
 */
default NDArray get(long... indices) {
    NDIndex position = new NDIndex(indices);
    return get(position);
}
/**
 * Returns a partial {@code NDArray} addressed by one index per leading dimension, created by the
 * given manager.
 *
 * @param manager the manager used to create the arrays
 * @param indices the indices with each index corresponding to the dimensions and negative
 *     indices starting from the end
 * @return the partial {@code NDArray}
 */
default NDArray get(NDManager manager, long... indices) {
    NDIndex position = new NDIndex(indices);
    return get(manager, position);
}
/**
 * Returns a partial {@code NDArray} pointed by the indexed array.
 *
 * <pre>
 * out[i][j][k] = input[index[i][j][k]][j][k] # if axis == 0
 * out[i][j][k] = input[i][index[i][j][k]][k] # if axis == 1
 * out[i][j][k] = input[i][j][index[i][j][k]] # if axis == 2
 * </pre>
 *
 * @param index picks the elements of an NDArray to the same position as index
 * @param axis the entries of index are indices of axis
 * @return the partial {@code NDArray} of the same shape as index
 */
NDArray gather(NDArray index, int axis);
/**
 * Returns a partial {@code NDArray} pointed by the indexed array.
 *
 * <pre>
 * Given NDArray data and NDArray idx, idx has the following structure:
 * \( idx = [ idx[0, ...], idx[1, ...],..., idx[indexingDepth,...] ] \)
 * corresponding to x, y, z index, i.e. [idx_x, idx_y, idx_z, ...].
 * </pre>
 *
 * <p>So indexingDepth is smaller than or equal to data.shape[0]. If indexingDepth is smaller
 * than data.shape[0], for instance data.shape[0]=3, i.e. x,y,z, but indexingDepth = 2, i.e.
 * [idx_x, idx_y], then the tail co-rank = data.shape[0] - indexingDepth will be kept.
 *
 * <p>With it, the output shape = idx_y.shape appended by data.shape[indexingDepth:]. See
 * mx.symbol.gather_nd.
 *
 * @param index picks the elements of an NDArray to the same position as index
 * @return the partial {@code NDArray} of the same shape as index
 */
NDArray gatherNd(NDArray index);
/**
 * Returns the elements at the given linear (flattened) indices; the output has the same shape
 * as {@code index}.
 *
 * <p>The result is created by this array's own manager.
 *
 * @param index picks the elements of an NDArray and output to the same entry as in index
 * @return the partial {@code NDArray} of the same shape as index
 */
default NDArray take(NDArray index) {
    NDManager owner = getManager();
    return take(owner, index);
}
/**
 * Returns a partial {@code NDArray} pointed by index according to linear indexing; the output
 * is of the same shape as index.
 *
 * @param manager the manager used to create the arrays
 * @param index picks the elements of an NDArray and output to the same entry as in index
 * @return the partial {@code NDArray} of the same shape as index
 */
NDArray take(NDManager manager, NDArray index);
/**
 * Sets the entries of {@code NDArray} pointed by index, according to linear indexing, to be the
 * numbers in value.
 *
 * <p>Value has to be of the same shape as index.
 *
 * @param index select the entries of an {@code NDArray}
 * @param value numbers to assign to the indexed entries
 * @return the NDArray with updated values
 */
NDArray put(NDArray index, NDArray value);
/**
 * Writes all values from the tensor value into self at the indices specified in the index
 * tensor.
 *
 * <p>This is the reverse of the operation described in {@link #gather(NDArray, int)}:
 *
 * <pre>
 * self[index[i][j][k]][j][k] = value[i][j][k] # if axis == 0
 * self[i][index[i][j][k]][k] = value[i][j][k] # if axis == 1
 * self[i][j][index[i][j][k]] = value[i][j][k] # if axis == 2
 * </pre>
 *
 * <p>See also PyTorch's {@code torch.Tensor.scatter_}.
 *
 * @param axis the axis along which to index
 * @param index the indices of elements to scatter, can be either empty or of the same
 *     dimensionality as value. When empty, the operation returns self unchanged
 * @param value the source element(s) to scatter
 * @return the NDArray with updated values
 */
NDArray scatter(NDArray index, NDArray value, int axis);
/**
 * Returns a scalar {@code NDArray} corresponding to a single element.
 *
 * <p>If the indices select more or fewer than one element, the intermediate result is closed
 * before the exception is thrown so it is not leaked.
 *
 * @param indices the indices of the scalar to return. Must return only a single element
 * @return a scalar {@code NDArray} corresponding to the element
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default NDArray getScalar(long... indices) {
    NDArray value = get(new NDIndex(indices));
    if (value.size() != 1) {
        // the caller never receives this array, so release it here
        value.close();
        throw new IllegalArgumentException("The supplied Index does not produce a scalar");
    }
    return value;
}
/**
 * Reads one long element from this {@code NDArray}.
 *
 * <p>The temporary scalar array is closed before returning.
 *
 * @param indices the indices of the long element to return
 * @return the element in the specified index as a long
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default long getLong(long... indices) {
    try (NDArray element = getScalar(indices)) {
        long[] values = element.toLongArray();
        return values[0];
    }
}
/**
 * Reads one double element from this {@code NDArray}.
 *
 * <p>The temporary scalar array is closed before returning.
 *
 * @param indices the indices of the double element to return
 * @return the element in the specified index as a double
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default double getDouble(long... indices) {
    try (NDArray element = getScalar(indices)) {
        double[] values = element.toDoubleArray();
        return values[0];
    }
}
/**
 * Reads one float element from this {@code NDArray}.
 *
 * <p>The temporary scalar array is closed before returning.
 *
 * @param indices the indices of the float element to return
 * @return the element in the specified index as a float
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default float getFloat(long... indices) {
    try (NDArray element = getScalar(indices)) {
        float[] values = element.toFloatArray();
        return values[0];
    }
}
/**
 * Reads one int element from this {@code NDArray}.
 *
 * <p>The temporary scalar array is closed before returning.
 *
 * @param indices the indices of the int element to return
 * @return the element in the specified index as an integer
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default int getInt(long... indices) {
    try (NDArray element = getScalar(indices)) {
        int[] values = element.toIntArray();
        return values[0];
    }
}
/**
 * Reads one byte element from this {@code NDArray}.
 *
 * <p>The temporary scalar array is closed before returning.
 *
 * @param indices the indices of the byte element to return
 * @return the element in the specified index as a byte
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default byte getByte(long... indices) {
    try (NDArray element = getScalar(indices)) {
        byte[] raw = element.toByteArray();
        return raw[0];
    }
}
/**
 * Reads one byte element and returns it as an unsigned 8-bit value in {@code [0, 255]}.
 *
 * @param indices the indices of the unsigned 8 bits integer element to return
 * @return the element in the specified index as a uint8
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default int getUint8(long... indices) {
    byte signed = getByte(indices);
    return Byte.toUnsignedInt(signed);
}
/**
 * Reads one boolean element from this {@code NDArray}.
 *
 * <p>The temporary scalar array is closed before returning.
 *
 * @param indices the indices of the boolean element to return
 * @return the element in the specified index as a boolean
 * @throws IllegalArgumentException thrown if the result is not a single element
 */
default boolean getBoolean(long... indices) {
    try (NDArray element = getScalar(indices)) {
        boolean[] flags = element.toBooleanArray();
        return flags[0];
    }
}
/**
 * Copies the contents of this {@code NDArray} into {@code array}.
 *
 * <p>The raw bytes of this array are written via {@link #set(Buffer)} on the target.
 *
 * @param array this {@code NDArray} prepared to be copied to
 */
default void copyTo(NDArray array) {
    ByteBuffer contents = toByteBuffer();
    array.set(contents);
}
/**
 * Creates an independent copy of this {@code NDArray}.
 *
 * <p>The copy has the same shape, data type, device and name, and is created by this array's
 * manager.
 *
 * @return a copy of this {@code NDArray}
 */
default NDArray duplicate() {
    NDManager owner = getManager();
    NDArray copy = owner.create(getShape(), getDataType(), getDevice());
    copy.setName(getName());
    copyTo(copy);
    return copy;
}
/**
 * Returns the portion of this {@code NDArray} selected by a boolean mask along the first axis.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(3, 2));
 * jshell&gt; NDArray mask = manager.create(new boolean[] {true, false, true});
 * jshell&gt; array.booleanMask(mask);
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 *  [5., 6.],
 * ]
 * </pre>
 *
 * @param index boolean {@code NDArray} mask
 * @return the result {@code NDArray}
 */
default NDArray booleanMask(NDArray index) {
    final int firstAxis = 0;
    return booleanMask(index, firstAxis);
}
/**
 * Returns the portion of this {@code NDArray} given the index boolean {@code NDArray} along the
 * given axis.
 *
 * @param index boolean {@code NDArray} mask
 * @param axis an integer that represents the axis of {@code NDArray} to mask from
 * @return the result {@code NDArray}
 */
NDArray booleanMask(NDArray index, int axis);
/**
 * Sets all elements outside the sequence to a constant value.
 *
 * <p>This function takes an n-dimensional input array of the form [batch_size,
 * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
 * sequenceLength} is used to handle variable-length sequences. sequence_length should be an
 * input array of positive ints of dimension [batch_size].
 *
 * @param sequenceLength used to handle variable-length sequences
 * @param value the constant value to be set
 * @return the result {@code NDArray}
 */
NDArray sequenceMask(NDArray sequenceLength, float value);
/**
 * Sets all elements outside the sequence to 0.
 *
 * <p>This function takes an n-dimensional input array of the form [batch_size,
 * max_sequence_length, ....] and returns an array of the same shape. Parameter {@code
 * sequenceLength} is used to handle variable-length sequences. sequence_length should be an
 * input array of positive ints of dimension [batch_size].
 *
 * @param sequenceLength used to handle variable-length sequences
 * @return the result {@code NDArray}
 */
NDArray sequenceMask(NDArray sequenceLength);
/**
 * Returns an {@code NDArray} of zeros with the same {@link Shape}, {@link DataType} and {@link
 * Device} as this {@code NDArray}.
 *
 * <p>NOTE(review): the {@link SparseFormat} is not forwarded to the manager here; confirm
 * whether sparse inputs are expected to yield sparse outputs.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell&gt; array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 *  [3., 4., 5.],
 * ]
 * jshell&gt; array.zerosLike();
 * ND: (2, 3) cpu() float32
 * [[0., 0., 0.],
 *  [0., 0., 0.],
 * ]
 * </pre>
 *
 * @return a {@code NDArray} filled with zeros
 */
default NDArray zerosLike() {
    NDManager owner = getManager();
    Shape shape = getShape();
    DataType dataType = getDataType();
    Device device = getDevice();
    return owner.zeros(shape, dataType, device);
}
/**
 * Returns an {@code NDArray} of ones with the same {@link Shape}, {@link DataType} and {@link
 * Device} as this {@code NDArray}.
 *
 * <p>NOTE(review): the {@link SparseFormat} is not forwarded to the manager here; confirm
 * whether sparse inputs are expected to yield sparse outputs.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell&gt; array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 *  [3., 4., 5.],
 * ]
 * jshell&gt; array.onesLike();
 * ND: (2, 3) cpu() float32
 * [[1., 1., 1.],
 *  [1., 1., 1.],
 * ]
 * </pre>
 *
 * @return a {@code NDArray} filled with ones
 */
default NDArray onesLike() {
    NDManager owner = getManager();
    Shape shape = getShape();
    DataType dataType = getDataType();
    Device device = getDevice();
    return owner.ones(shape, dataType, device);
}
/**
 * Returns an uninitialized {@code NDArray} with the same {@link Shape}, {@link DataType} and
 * {@link Device} as this {@code NDArray}.
 *
 * <p>NOTE(review): the {@link SparseFormat} is not forwarded to the manager here; confirm
 * whether sparse inputs are expected to yield sparse outputs.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell&gt; array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 *  [3., 4., 5.],
 * ]
 * jshell&gt; array.like(); // uninitialized NDArray
 * ND: (2, 3) cpu() float32
 * [[ 9.80908925e-45,  0.00000000e+00,  0.00000000e+00],
 *  [ 0.00000000e+00,  7.61595174e-07,  2.80259693e-44],
 * ]
 * </pre>
 *
 * @return the result {@code NDArray}
 */
default NDArray like() {
    // Pass the data type and device explicitly (as duplicate() does). create(Shape) alone does
    // not receive this array's DataType (presumably it falls back to the manager's default).
    return getManager().create(getShape(), getDataType(), getDevice());
}
////////////////////////////////////////
////////////////////////////////////////
// Operations
////////////////////////////////////////
////////////////////////////////////////
////////////////////////////////////////
// Operations: Element Comparison
////////////////////////////////////////
/**
 * Returns {@code true} if all elements in this {@code NDArray} are equal to the {@link Number}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.ones(new Shape(2, 3));
 * jshell&gt; array.contentEquals(1); // return true instead of boolean NDArray
 * true
 * </pre>
 *
 * @param number the number to compare
 * @return the boolean result
 */
boolean contentEquals(Number number);
/**
 * Returns {@code true} if all elements in this {@code NDArray} are equal to the other {@link
 * NDArray}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array1 = manager.arange(6f).reshape(2, 3);
 * jshell&gt; NDArray array2 = manager.create(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new Shape(2, 3));
 * jshell&gt; array1.contentEquals(array2); // return true instead of boolean NDArray
 * true
 * </pre>
 *
 * @param other the other {@code NDArray} to compare
 * @return the boolean result
 */
boolean contentEquals(NDArray other);
/**
 * Checks two {@code NDArray}s for equal shapes.
 *
 * <p>Shapes are considered equal if:
 *
 * <ul>
 *   <li>both {@code NDArray}s have equal rank, and
 *   <li>size(0)...size(rank()-1) are equal for both {@code NDArray}s.
 * </ul>
 *
 * <p>The comparison is delegated to {@link Shape#equals(Object)}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array1 = manager.ones(new Shape(1, 2, 3));
 * jshell&gt; NDArray array2 = manager.create(new Shape(1, 2, 3));
 * jshell&gt; array1.shapeEquals(array2); // return true instead of boolean NDArray
 * true
 * </pre>
 *
 * @param other the other {@code NDArray}
 * @return {@code true} if the {@link Shape}s are the same
 */
default boolean shapeEquals(NDArray other) {
    Shape mine = getShape();
    Shape theirs = other.getShape();
    return mine.equals(theirs);
}
/**
 * Returns {@code true} if two {@code NDArray}s are element-wise equal within the default
 * tolerances ({@code rtol = 1e-5}, {@code atol = 1e-8}), treating NaN as unequal to NaN.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-7});
 * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
 * jshell&gt; array1.allClose(array2); // return false instead of boolean NDArray
 * false
 * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-8});
 * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
 * jshell&gt; array1.allClose(array2); // return true instead of boolean NDArray
 * true
 * </pre>
 *
 * @param other the {@code NDArray} to compare with
 * @return the boolean result
 */
default boolean allClose(NDArray other) {
    double relativeTolerance = 1e-5;
    double absoluteTolerance = 1e-08;
    boolean nanEqualsNan = false;
    return allClose(other, relativeTolerance, absoluteTolerance, nanEqualsNan);
}
/**
 * Returns {@code true} if two {@code NDArray}s are element-wise equal within a tolerance.
 *
 * <p>Arrays of different shapes are never close. Finite elements {@code a} (this) and {@code b}
 * (other) are close when {@code |a - b| <= atol + rtol * |b|}. Infinite elements are only close
 * to the identical infinity (same sign). NaN is never close to anything, unless {@code equalNan}
 * is set, in which case NaN is close to NaN.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-7});
 * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-8});
 * jshell&gt; array1.allClose(array2, 1e-05, 1e-08, false); // return false instead of boolean NDArray
 * false
 * jshell&gt; NDArray array1 = manager.create(new double[] {1e10, 1e-8});
 * jshell&gt; NDArray array2 = manager.create(new double[] {1.00001e10, 1e-9});
 * jshell&gt; array1.allClose(array2, 1e-05, 1e-08, false); // return true instead of boolean NDArray
 * true
 * jshell&gt; NDArray array1 = manager.create(new float[] {1f, Float.NaN});
 * jshell&gt; NDArray array2 = manager.create(new float[] {1f, Float.NaN});
 * jshell&gt; array1.allClose(array2, 1e-05, 1e-08, true); // return true instead of boolean NDArray
 * true
 * </pre>
 *
 * @param other the {@code NDArray} to compare with
 * @param rtol the relative tolerance parameter
 * @param atol the absolute tolerance parameter
 * @param equalNan whether to compare NaNs as equal. If {@code true}, NaNs in this {@code
 *     NDArray} will be considered equal to NaNs in the other {@code NDArray}
 * @return the boolean result
 */
default boolean allClose(NDArray other, double rtol, double atol, boolean equalNan) {
    if (!shapeEquals(other)) {
        return false;
    }
    Number[] actualArray = toArray();
    Number[] expectedArray = other.toArray();
    for (int i = 0; i < actualArray.length; i++) {
        double a = actualArray[i].doubleValue();
        double b = expectedArray[i].doubleValue();
        boolean aNan = Double.isNaN(a);
        boolean bNan = Double.isNaN(b);
        if (aNan || bNan) {
            if (equalNan && aNan && bNan) {
                continue;
            }
            return false;
        }
        if (Double.isInfinite(a) || Double.isInfinite(b)) {
            // The tolerance formula below breaks down here: e.g. a = 1, b = +Inf gives
            // Inf > Inf, which is false, so the pair would wrongly pass. Require exact equality.
            if (a != b) {
                return false;
            }
            continue;
        }
        if (Math.abs(a - b) > atol + rtol * Math.abs(b)) {
            return false;
        }
    }
    return true;
}
/**
* Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
*
* Examples
*
*
* jshell> NDArray array = manager.ones(new Shape(1));
* jshell> array.eq(1);
* ND: (1) cpu() boolean
* [ true]
*
*
* @param n the number to compare
* @return the boolean {@code NDArray} for element-wise "Equals" comparison
*/
NDArray eq(Number n);
/**
* Returns the boolean {@code NDArray} for element-wise "Equals" comparison.
*
* Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {0f, 1f, 3f});
* jshell> NDArray array2 = manager.arange(3f);
* jshell> array1.eq(array2);
* ND: (3) cpu() boolean
* [ true, true, false]
*
*
* @param other the {@code NDArray} to compare
* @return the boolean {@code NDArray} for element-wise "Equals" comparison
*/
NDArray eq(NDArray other);
/**
* Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(4f).reshape(2, 2);
* jshell> array.neq(1);
* ND: (2, 2) cpu() boolean
* [[ true, false],
* [ true, true],
* ]
*
*
* @param n the number to compare
* @return the boolean {@code NDArray} for element-wise "Not equals" comparison
*/
NDArray neq(Number n);
/**
* Returns the boolean {@code NDArray} for element-wise "Not equals" comparison.
*
* Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {1f, 3f});
* jshell> array1.neq(array2);
* ND: (2) cpu() boolean
* [false, true]
* jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {1f, 3f, 1f, 4f}, new Shape(2, 2));
* jshell> array1.neq(array2); // broadcasting
* ND: (2, 2) cpu() boolean
* [[false, true],
* [false, true],
* ]
*
*
* @param other the {@code NDArray} to compare
* @return the boolean {@code NDArray} for element-wise "Not equals" comparison
*/
NDArray neq(NDArray other);
/**
* Returns the boolean {@code NDArray} for element-wise "Greater" comparison.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {4f, 2f});
* jshell> array.gt(2f);
* ND: (2) cpu() boolean
* [ true, false]
*
*
* @param n the number to compare
* @return the boolean {@code NDArray} for element-wise "Greater" comparison
*/
NDArray gt(Number n);
/**
* Returns the boolean {@code NDArray} for element-wise "Greater Than" comparison.
*
* Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {4f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
 * jshell> array1.gt(array2);
* ND: (2) cpu() boolean
* [ true, false]
*
*
* @param other the {@code NDArray} to compare
 * @return the boolean {@code NDArray} for element-wise "Greater Than" comparison
*/
NDArray gt(NDArray other);
/**
 * Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
 *
 * Examples
 *
 *
* jshell> NDArray array = manager.create(new float[] {4f, 2f});
* jshell> array.gte(2f);
* ND: (2) cpu() boolean
* [ true, true]
*
*
* @param n the number to compare
* @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
*/
NDArray gte(Number n);
/**
* Returns the boolean {@code NDArray} for element-wise "Greater or equals" comparison.
*
* Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {4f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
* jshell> array1.gte(array2);
* ND: (2) cpu() boolean
* [ true, true]
*
*
 * @param other the {@code NDArray} to compare
 * @return the boolean {@code NDArray} for element-wise "Greater or equals" comparison
*/
NDArray gte(NDArray other);
/**
* Returns the boolean {@code NDArray} for element-wise "Less" comparison.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.lt(2f);
* ND: (2) cpu() boolean
* [ true, false]
*
*
* @param n the number to compare
* @return the boolean {@code NDArray} for element-wise "Less" comparison
*/
NDArray lt(Number n);
/**
* Returns the boolean {@code NDArray} for element-wise "Less" comparison.
*
* Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
* jshell> array1.lt(array2);
* ND: (2) cpu() boolean
* [ true, false]
*
*
* @param other the {@code NDArray} to compare
* @return the boolean {@code NDArray} for element-wise "Less" comparison
*/
NDArray lt(NDArray other);
/**
* Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.lte(2f);
* ND: (2) cpu() boolean
* [ true, true]
*
*
* @param n the number to compare
* @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
*/
NDArray lte(Number n);
/**
* Returns the boolean {@code NDArray} for element-wise "Less or equals" comparison.
*
* Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {2f, 2f});
* jshell> array1.lte(array2);
* ND: (2) cpu() boolean
* [ true, true]
*
*
* @param other the {@code NDArray} to compare
* @return the boolean {@code NDArray} for element-wise "Less or equals" comparison
*/
NDArray lte(NDArray other);
////////////////////////////////////////
// Operations: Element Arithmetic
////////////////////////////////////////
/**
* Adds a number to this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.add(2f);
* ND: (2) cpu() float32
* [3., 4.]
*
*
* @param n the number to add
* @return the result {@code NDArray}
*/
NDArray add(Number n);
/**
* Adds other {@code NDArray}s to this {@code NDArray} element-wise.
*
* The shapes of this {@code NDArray} and other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
* jshell> NDArray array2 = manager.arange(3f);
* jshell> array1.add(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[ 0., 2., 4.],
* [ 3., 5., 7.],
* [ 6., 8., 10.],
* ]
*
*
* @param other the other {@code NDArray}s to add
* @return the result {@code NDArray}
* @throws IllegalArgumentException others arrays must have at least one element
*/
NDArray add(NDArray other);
/**
* Subtracts a number from this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.sub(2f);
* ND: (2) cpu() float32
* [-1., 0.]
*
*
 * @param n the number to subtract
* @return the result {@code NDArray}
*/
NDArray sub(Number n);
/**
* Subtracts the other {@code NDArray} from this {@code NDArray} element-wise.
*
* The shapes of this {@code NDArray} and other {@code NDArray}s must be broadcastable.
*
*
Examples
*
*
 * jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
 * jshell> NDArray array2 = manager.arange(3f);
* jshell> array1.sub(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[0., 0., 0.],
* [3., 3., 3.],
* [6., 6., 6.],
* ]
*
*
 * @param other the other {@code NDArray} to subtract
* @return the result {@code NDArray}
*/
NDArray sub(NDArray other);
/**
* Multiplies this {@code NDArray} by a number element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.mul(3f);
* ND: (2) cpu() float32
* [3., 6.]
*
*
* @param n the number to multiply by
* @return the result {@code NDArray}
*/
NDArray mul(Number n);
/**
* Multiplies this {@code NDArray} by other {@code NDArray}s element-wise.
*
* The shapes of this {@code NDArray} and other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
* jshell> NDArray array2 = manager.arange(3f);
* jshell> array1.mul(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[ 0., 1., 4.],
* [ 0., 4., 10.],
* [ 0., 7., 16.],
* ]
*
*
* @param other the other {@code NDArray}s to multiply by
* @return the result {@code NDArray}
* @throws IllegalArgumentException others arrays must have at least one element
*/
NDArray mul(NDArray other);
/**
* Divides this {@code NDArray} by a number element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(5f);
* jshell> array.div(4f);
* ND: (5) cpu() float32
* [0. , 0.25, 0.5 , 0.75, 1. ]
*
*
* @param n the number to divide by
* @return the result {@code NDArray}
*/
NDArray div(Number n);
/**
* Divides this {@code NDArray} by the other {@code NDArray} element-wise.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
* jshell> NDArray array2 = manager.ones(new Shape(3)).mul(10);
* jshell> array1.div(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[0. , 0.1, 0.2],
* [0.3, 0.4, 0.5],
* [0.6, 0.7, 0.8],
* ]
*
*
* @param other the other {@code NDArray} to divide by
* @return the result {@code NDArray}
*/
NDArray div(NDArray other);
/**
* Returns element-wise remainder of division.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(7f);
* jshell> array.mod(5f);
* ND: (7) cpu() float32
* [0., 1., 2., 3., 4., 0., 1.]
*
*
* @param n the divisor number
* @return the result {@code NDArray}
*/
NDArray mod(Number n);
/**
* Returns element-wise remainder of division.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {4f, 7f});
* jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
* jshell> array1.mod(array2);
* ND: (2) cpu() float32
* [0., 1.]
*
*
* @param other the divisor {@code NDArray}
* @return the result {@code NDArray}
*/
NDArray mod(NDArray other);
/**
* Takes the power of this {@code NDArray} with a number element-wise.
*
* Examples
*
*
 * jshell> NDArray array = manager.arange(6f);
 * jshell> array.pow(3f);
* ND: (6) cpu() float32
* [ 0., 1., 8., 27., 64., 125.]
*
*
* @param n the number to take the power with
* @return the result {@code NDArray}
*/
NDArray pow(Number n);
/**
* Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array1 = manager.arange(6f).reshape(3, 2);
* jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
* jshell> array1.pow(array2); // broadcasting
* ND: (3, 2) cpu() float32
* [[ 0., 1.],
* [ 4., 27.],
* [ 16., 125.],
* ]
*
*
* @param other the other {@code NDArray} to take the power with
* @return the result {@code NDArray}
*/
NDArray pow(NDArray other);
/**
* Computes this * log(other).
*
 * @param other the other {@code NDArray}
* @return the result {@code NDArray}
*/
NDArray xlogy(NDArray other);
/**
* Adds a number to this {@code NDArray} element-wise in place.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.addi(2f);
* ND: (2) cpu() float32
* [3., 4.]
* jshell> array;
* ND: (2) cpu() float32
* [3., 4.]
*
*
* @param n the number to add
* @return the result {@code NDArray}
*/
NDArray addi(Number n);
/**
* Adds other {@code NDArray}s to this {@code NDArray} element-wise in place.
*
* The shapes of this {@code NDArray} and other {@code NDArray}s must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {1f, 2f});
* jshell> NDArray array2 = manager.create(new float[] {3f, 4f});
* jshell> array1.addi(array2);
* ND: (2) cpu() float32
* [4., 6.]
 * jshell> array1;
* ND: (2) cpu() float32
* [4., 6.]
*
*
* @param other the other {@code NDArray}s to add
* @return the result {@code NDArray}
*/
NDArray addi(NDArray other);
/**
* Subtracts a number from this {@code NDArray} element-wise in place.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.subi(2f);
* ND: (2) cpu() float32
* [-1., 0.]
* jshell> array;
* ND: (2) cpu() float32
* [-1., 0.]
*
*
* @param n the number to subtract
* @return the result {@code NDArray}
*/
NDArray subi(Number n);
/**
* Subtracts the other {@code NDArray} from this {@code NDArray} element-wise in place.
*
* The shapes of this {@code NDArray} and other {@code NDArray}s must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
* jshell> NDArray array2 = manager.arange(3f);
* jshell> array1.subi(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[0., 0., 0.],
* [3., 3., 3.],
* [6., 6., 6.],
* ]
 * jshell> array1;
 * ND: (3, 3) cpu() float32
* [[0., 0., 0.],
* [3., 3., 3.],
* [6., 6., 6.],
* ]
*
*
 * @param other the other {@code NDArray} to subtract
* @return the result {@code NDArray}
*/
NDArray subi(NDArray other);
/**
* Multiplies this {@code NDArray} by a number element-wise in place.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 2f});
* jshell> array.muli(3f);
* ND: (2) cpu() float32
* [3., 6.]
* jshell> array;
* ND: (2) cpu() float32
* [3., 6.]
*
*
* @param n the number to multiply by
* @return the result {@code NDArray}
*/
NDArray muli(Number n);
/**
* Multiplies this {@code NDArray} by other {@code NDArray} element-wise in place.
*
* The shapes of this {@code NDArray} and other {@code NDArray}s must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
* jshell> NDArray array2 = manager.arange(3f);
* jshell> array1.muli(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[ 0., 1., 4.],
* [ 0., 4., 10.],
* [ 0., 7., 16.],
* ]
* jshell> array1;
* ND: (3, 3) cpu() float32
* [[ 0., 1., 4.],
* [ 0., 4., 10.],
* [ 0., 7., 16.],
* ]
*
*
* @param other the other NDArrays to multiply with
* @return the result {@code NDArray}
*/
NDArray muli(NDArray other);
/**
* Divides this {@code NDArray} by a number element-wise in place.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(5f);
* jshell> array.divi(4f);
* ND: (5) cpu() float32
* [0. , 0.25, 0.5 , 0.75, 1. ]
* jshell> array;
* ND: (5) cpu() float32
* [0. , 0.25, 0.5 , 0.75, 1. ]
*
*
* @param n the number to divide values by
* @return the array after applying division operation
*/
NDArray divi(Number n);
/**
* Divides this {@code NDArray} by the other {@code NDArray} element-wise in place.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(9f).reshape(3, 3);
* jshell> NDArray array2 = manager.ones(new Shape(3)).mul(10);
* jshell> array1.divi(array2); // broadcasting
* ND: (3, 3) cpu() float32
* [[0. , 0.1, 0.2],
* [0.3, 0.4, 0.5],
* [0.6, 0.7, 0.8],
* ]
 * jshell> array1;
 * ND: (3, 3) cpu() float32
* [[0. , 0.1, 0.2],
* [0.3, 0.4, 0.5],
* [0.6, 0.7, 0.8],
* ]
*
*
* @param other the other {@code NDArray} to divide by
* @return the result of the divide
*/
NDArray divi(NDArray other);
/**
* Returns element-wise remainder of division in place.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(7f);
* jshell> array.modi(5f);
* ND: (7) cpu() float32
* [0., 1., 2., 3., 4., 0., 1.]
* jshell> array;
* ND: (7) cpu() float32
* [0., 1., 2., 3., 4., 0., 1.]
*
*
* @param n the divisor number
* @return the result {@code NDArray}
*/
NDArray modi(Number n);
/**
 * Returns element-wise remainder of division in place.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {4f, 7f});
* jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
* jshell> array1.modi(array2);
* ND: (2) cpu() float32
* [0., 1.]
* jshell> array1;
* ND: (2) cpu() float32
* [0., 1.]
*
*
* @param other the divisor {@code NDArray}
* @return the result of the divide
*/
NDArray modi(NDArray other);
/**
* Takes the power of this {@code NDArray} with a number element-wise in place.
*
* Examples
*
*
 * jshell> NDArray array = manager.arange(6f);
 * jshell> array.powi(3f);
* ND: (6) cpu() float32
* [ 0., 1., 8., 27., 64., 125.]
* jshell> array;
* ND: (6) cpu() float32
* [ 0., 1., 8., 27., 64., 125.]
*
*
* @param n the number to raise the power to
* @return the result {@code NDArray}
*/
NDArray powi(Number n);
/**
* Takes the power of this {@code NDArray} with the other {@code NDArray} element-wise in place.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.arange(6f).reshape(3, 2);
* jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
* jshell> array1.powi(array2); // broadcasting
* ND: (3, 2) cpu() float32
* [[ 0., 1.],
* [ 4., 27.],
* [ 16., 125.],
* ]
* jshell> array1;
* ND: (3, 2) cpu() float32
* [[ 0., 1.],
* [ 4., 27.],
* [ 16., 125.],
* ]
*
*
* @param other the other {@code NDArray} to take the power with
* @return the result {@code NDArray}
*/
NDArray powi(NDArray other);
/**
* Returns the element-wise sign.
*
* @return the result {@code NDArray}
*/
NDArray sign();
/**
* Returns the element-wise sign in-place.
*
* @return the result {@code NDArray}
*/
NDArray signi();
/**
* Returns the maximum of this {@code NDArray} and a number element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
* jshell> array.maximum(3f);
* ND: (3) cpu() float32
* [3., 3., 4.]
*
*
* @param n the number to be compared
* @return the maximum of this {@code NDArray} and a number element-wise
*/
NDArray maximum(Number n);
/**
* Returns the maximum of this {@code NDArray} and the other {@code NDArray} element-wise.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
* jshell> NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
* jshell> array1.maximum(array2);
* ND: (3) cpu() float32
* [2., 5., 4.]
* jshell> NDArray array1 = manager.eye(2);
* jshell> NDArray array2 = manager.create(new float[] {0.5f, 2f});
* jshell> array1.maximum(array2); // broadcasting
* ND: (2, 2) cpu() float32
* [[1. , 2. ],
* [0.5, 2. ],
* ]
*
*
* @param other the {@code NDArray} to be compared
* @return the maximum of this {@code NDArray} and the other {@code NDArray} element-wise
*/
NDArray maximum(NDArray other);
/**
* Returns the minimum of this {@code NDArray} and a number element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {2f, 3f, 4f});
* jshell> array.minimum(3f);
* ND: (3) cpu() float32
* [2., 3., 3.]
*
*
* @param n the number to be compared
* @return the minimum of this {@code NDArray} and a number element-wise
*/
NDArray minimum(Number n);
/**
* Returns the minimum of this {@code NDArray} and the other {@code NDArray} element-wise.
*
* The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
*
*
Examples
*
*
* jshell> NDArray array1 = manager.create(new float[] {2f, 3f, 4f});
* jshell> NDArray array2 = manager.create(new float[] {1f, 5f, 2f});
* jshell> array1.minimum(array2);
* ND: (3) cpu() float32
* [1., 3., 2.]
* jshell> NDArray array1 = manager.eye(2);
* jshell> NDArray array2 = manager.create(new float[] {0.5f, 2f});
* jshell> array1.minimum(array2); // broadcasting
* ND: (2, 2) cpu() float32
* [[0.5, 0. ],
* [0. , 1. ],
* ]
*
*
* @param other the {@code NDArray} to be compared
* @return the minimum of this {@code NDArray} and the other {@code NDArray} element-wise
*/
NDArray minimum(NDArray other);
////////////////////////////////////////
// Operations: Basic Numeric
////////////////////////////////////////
/**
 * Returns the numerical negative {@code NDArray} element-wise.
 *
 * Examples
 *
 *
* jshell> NDArray array = manager.arange(5f);
* jshell> array.neg();
* ND: (5) cpu() float32
* [-0., -1., -2., -3., -4.]
*
*
* @return the result {@code NDArray}
*/
NDArray neg();
/**
 * Returns the numerical negative {@code NDArray} element-wise in place.
 *
 * Examples
 *
 *
* jshell> NDArray array = manager.arange(5f);
* jshell> array.negi();
* jshell> array;
* ND: (5) cpu() float32
* [-0., -1., -2., -3., -4.]
*
*
* @return the result {@code NDArray}
*/
NDArray negi();
/**
* Returns the absolute value of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {-1f, -2f});
* jshell> array.abs();
* ND: (2) cpu() float32
* [1., 2.]
*
*
* @return the result {@code NDArray}
*/
NDArray abs();
/**
* Returns the square of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {2f, -3f});
* jshell> array.square();
* ND: (2) cpu() float32
* [4., 9.]
*
*
* @return the result {@code NDArray}
*/
NDArray square();
/**
* Returns the square root of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {4f});
* jshell> array.sqrt();
* ND: (1) cpu() float32
* [2., ]
*
*
* @return the result {@code NDArray}
*/
NDArray sqrt();
/**
* Returns the cube-root of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, 8f, 27f});
* jshell> array.cbrt();
* ND: (3) cpu() float32
* [1., 2., 3.]
*
*
* @return the result {@code NDArray}
*/
NDArray cbrt();
/**
* Returns the floor of this {@code NDArray} element-wise.
*
* The floor of the scalar x is the largest integer i, such that i <= x.
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new float[] {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f});
* jshell> array.floor();
* ND: (7) cpu() float32
* [-2., -2., -1., 0., 1., 1., 2.]
*
*
* @return the result {@code NDArray}
*/
NDArray floor();
/**
* Returns the ceiling of this {@code NDArray} element-wise.
*
* The ceil of the scalar x is the smallest integer i, such that i >= x.
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new float[] {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f});
* jshell> array.ceil();
* ND: (7) cpu() float32
* [-1., -1., -0., 1., 2., 2., 2.]
*
*
* @return the result {@code NDArray}
*/
NDArray ceil();
/**
* Returns the round of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f});
* jshell> array.round();
* ND: (7) cpu() float32
* [-2., -2., -0., 0., 2., 2., 2.]
*
*
* @return the result {@code NDArray}
*/
NDArray round();
/**
* Returns the truncated value of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f});
* jshell> array.trunc();
* ND: (7) cpu() float32
* [-1., -1., -0., 0., 1., 1., 2.]
*
*
* @return the result {@code NDArray}
*/
NDArray trunc();
/**
* Returns the exponential value of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {0f, 2.5f});
* jshell> array.exp();
* ND: (2) cpu() float32
* [ 1. , 12.1825]
*
*
* @return the result {@code NDArray}
*/
NDArray exp();
/**
 * Returns the log of the absolute value of the gamma function of this {@code NDArray}
 * element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {0.5f, 1f, 1.5f});
* jshell> array.gammaln();
 * ND: (3) cpu() float32
* [ 0.5724, 0.0000, -0.1208]
*
*
* @return the result {@code NDArray}
*/
NDArray gammaln();
/**
* Returns the natural logarithmic value of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {0f, 2.5f});
* jshell> array.log();
* ND: (2) cpu() float32
* [ -inf, 0.9163]
*
*
* @return the result {@code NDArray}
*/
NDArray log();
/**
* Returns the base 10 logarithm of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1000f, 1f, 150f});
* jshell> array.log10();
* ND: (3) cpu() float32
* [3. , 0. , 2.1761]
*
*
* @return the result {@code NDArray}
*/
NDArray log10();
/**
* Returns the base 2 logarithm of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new float[] {8, 1f, 5f});
* jshell> array.log2();
* ND: (3) cpu() float32
* [3. , 0. , 2.3219]
*
*
* @return the result {@code NDArray}
*/
NDArray log2();
/**
* Returns the trigonometric sine of this {@code NDArray} element-wise.
*
* The input should be in radians (2 Pi radians equals 360 degrees).
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new float[] {0f, 30f, 45f, 60f, 90f});
* jshell> array = array.mul(Math.PI).div(180f);
* jshell> array.sin();
* ND: (5) cpu() float32
* [0. , 0.5 , 0.7071, 0.866 , 1. ]
*
*
* @return the result {@code NDArray}
*/
NDArray sin();
/**
* Returns the trigonometric cosine of this {@code NDArray} element-wise.
*
* The input should be in radians (2 Pi radians equals 360 degrees).
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new double[] {0, Math.PI/2, Math.PI});
* jshell> array.cos();
* ND: (3) cpu() float64
* [ 1.0000000e+00, 6.1232340e-17, -1.0000000e+00],
*
*
* @return the result {@code NDArray}
*/
NDArray cos();
/**
* Returns the trigonometric tangent of this {@code NDArray} element-wise.
*
* The input should be in radians (2 Pi radians equals 360 degrees).
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new double[] {-Math.PI, Math.PI/2, Math.PI});
* jshell> array.tan();
* ND: (3) cpu() float64
* [ 1.2246468e-16, 1.6331239e+16, -1.2246468e-16],
*
*
* @return the result {@code NDArray}
*/
NDArray tan();
/**
* Returns the inverse trigonometric sine of this {@code NDArray} element-wise.
*
* The input should be in the range [-1, 1]. The output is in the closed interval of [-Pi/2,
* Pi/2].
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, -1f, 0f});
* jshell> array.asin();
 * ND: (3) cpu() float32
* [ 1.5708, -1.5708, 0. ]
*
*
* @return the result {@code NDArray}
*/
NDArray asin();
/**
* Returns the inverse trigonometric cosine of this {@code NDArray} element-wise.
*
 * The input should be in the range [-1, 1]. The output is in the closed interval of [0,
 * Pi].
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new float[] {1f, -1f});
* jshell> array.acos();
 * ND: (2) cpu() float32
* [0. , 3.1416]
*
*
* @return the result {@code NDArray}
*/
NDArray acos();
/**
* Returns the inverse trigonometric tangent of this {@code NDArray} element-wise.
*
 * The input can be any real number. The output is in the closed interval of [-Pi/2,
 * Pi/2].
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new float[] {0f, 1f});
* jshell> array.atan();
 * ND: (2) cpu() float32
* [0. , 0.7854]
*
*
* @return the result {@code NDArray}
*/
NDArray atan();
/**
* Returns the element-wise arc-tangent of this/other choosing the quadrant correctly.
*
* Examples
*
*
* jshell> NDArray x = manager.create(new float[] {0f, 1f});
* jshell> NDArray y = manager.create(new float[] {0f, -6f});
* jshell> x.atan2(y);
 * ND: (2) cpu() float32
* [0. , 2.9764]
*
*
* @param other The other {@code NDArray}
* @return the result {@code NDArray}
*/
NDArray atan2(NDArray other);
/**
* Returns the hyperbolic sine of this {@code NDArray} element-wise.
*
* sinh(x)=0.5*(exp(x) - exp(-x))
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new double[] {0, Math.PI});
* jshell> array.sinh();
* ND: (2) cpu() float64
* [ 0. , 11.5487]
*
*
* @return the result {@code NDArray}
*/
NDArray sinh();
/**
* Returns the hyperbolic cosine of this {@code NDArray} element-wise.
*
* cosh(x)=0.5*(exp(x)+exp(-x))
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new double[] {0, Math.PI});
* jshell> array.cosh();
* ND: (2) cpu() float64
* [ 1. , 11.592 ]
*
*
* @return the result {@code NDArray}
*/
NDArray cosh();
/**
* Returns the hyperbolic tangent of this {@code NDArray} element-wise.
*
* tanh(x)=sinh(x)/cosh(x)
*
*
Examples
*
*
* jshell> NDArray array = manager.create(new double[] {0, Math.PI});
* jshell> array.tanh();
* ND: (2) cpu() float64
* [0. , 0.9963]
*
*
* @return the result {@code NDArray}
*/
NDArray tanh();
/**
* Returns the inverse hyperbolic sine of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new double[] {Math.E, 10});
* jshell> array.asinh();
* ND: (2) cpu() float64
* [1.7254, 2.9982]
*
*
* @return the result {@code NDArray}
*/
NDArray asinh();
/**
* Returns the inverse hyperbolic cosine of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new double[] {Math.E, 10});
* jshell> array.acosh();
* ND: (2) cpu() float64
* [1.6575, 2.9932]
*
*
* @return the result {@code NDArray}
*/
NDArray acosh();
/**
* Returns the inverse hyperbolic tangent of this {@code NDArray} element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.create(new double[] {0, -0.5});
* jshell> array.atanh();
* ND: (2) cpu() float64
* [ 0. , -0.5493]
*
*
* @return the result {@code NDArray}
*/
NDArray atanh();
/**
* Converts this {@code NDArray} from radians to degrees element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(6f).mul(Math.PI / 3);
* jshell> array.toDegrees();
* ND: (6) cpu() float32
* [ 0., 60., 120., 180., 240., 300.]
*
*
* @return the result {@code NDArray}
*/
NDArray toDegrees();
/**
* Converts this {@code NDArray} from degrees to radians element-wise.
*
* Examples
*
*
* jshell> NDArray array = manager.arange(6f).mul(60);
* jshell> array.toRadians();
* ND: (6) cpu() float32
* [0. , 1.0472, 2.0944, 3.1416, 4.1888, 5.236 ]
*
*
* @return the result {@code NDArray}
*/
NDArray toRadians();
////////////////////////////////////////
// Operations: Reduction
////////////////////////////////////////
/**
 * Returns the maximum of all elements of this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(4f).reshape(2,2);
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [2., 3.],
 * ]
 * jshell> array.max(); // Maximum of the flattened array
 * ND: () cpu() float32
 * 3.
 * jshell> array.max().getFloat(); // Use getFloat() to get native float
 * 3.0
 *
 * @return a scalar {@code NDArray} holding the maximum of this {@code NDArray}
 */
NDArray max();
/**
 * Returns the maximum of this {@code NDArray} along the given axes, dropping those axes from
 * the result.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(4f).reshape(2,2);
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [2., 3.],
 * ]
 * jshell> array.max(new int[]{0}); // Maximum along the first axis
 * ND: (2) cpu() float32
 * [2., 3.]
 * jshell> array.max(new int[]{1}); // Maximum along the second axis
 * ND: (2) cpu() float32
 * [1., 3.]
 *
 * @param axes the axes along which to operate
 * @return the maximum of this {@code NDArray}, with the reduced axes removed from the shape
 * @see NDArray#max(int[], boolean)
 */
default NDArray max(int[] axes) {
    // Reduced axes are squeezed out rather than kept as size-1 dimensions.
    final boolean keepDims = false;
    return max(axes, keepDims);
}
/**
 * Returns the maximum of this {@code NDArray} along given axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(4f).reshape(2,2);
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [2., 3.],
 * ]
 * jshell> array.max(new int[]{0}, true); // Maximum along the first axis and keep dimension
 * ND: (1, 2) cpu() float32
 * [[2., 3.],
 * ]
 * jshell> array.max(new int[]{1}, true); // Maximum along the second axis and keep dimension
 * ND: (2, 1) cpu() float32
 * [[1.],
 *  [3.],
 * ]
 *
 * @param axes the axes along which to operate
 * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
 *     false} to squeeze the values out of the output array.
 * @return the maximum of this {@code NDArray} along the given axes
 */
NDArray max(int[] axes, boolean keepDims);
/**
 * Returns the minimum of all elements of this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(4f).reshape(2,2);
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [2., 3.],
 * ]
 * jshell> array.min(); // Minimum of the flattened array
 * ND: () cpu() float32
 * 0.
 * jshell> array.min().getFloat(); // Use getFloat() to get native float
 * 0.0
 *
 * @return a scalar {@code NDArray} holding the minimum of this {@code NDArray}
 */
NDArray min();
/**
 * Returns the minimum of this {@code NDArray} along the given axes, dropping those axes from
 * the result.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(4f).reshape(2,2);
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [2., 3.],
 * ]
 * jshell> array.min(new int[]{0}); // Minimum along the first axis
 * ND: (2) cpu() float32
 * [0., 1.]
 * jshell> array.min(new int[]{1}); // Minimum along the second axis
 * ND: (2) cpu() float32
 * [0., 2.]
 *
 * @param axes the axes along which to operate
 * @return the minimum of this {@code NDArray}, with the reduced axes removed from the shape
 * @see NDArray#min(int[], boolean)
 */
default NDArray min(int[] axes) {
    // Delegate with keepDims disabled so the reduced axes disappear from the result.
    final boolean keepDims = false;
    return min(axes, keepDims);
}
/**
 * Returns the minimum of this {@code NDArray} along given axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(4f).reshape(2,2);
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [2., 3.],
 * ]
 * jshell> array.min(new int[]{0}, true); // Minimum along the first axis and keep dimension
 * ND: (1, 2) cpu() float32
 * [[0., 1.],
 * ]
 * jshell> array.min(new int[]{1}, true); // Minimum along the second axis and keep dimension
 * ND: (2, 1) cpu() float32
 * [[0.],
 *  [2.],
 * ]
 *
 * @param axes the axes along which to operate
 * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
 *     false} to squeeze the values out of the output array
 * @return the minimum of this {@code NDArray} along the given axes
 */
NDArray min(int[] axes, boolean keepDims);
/**
 * Returns the sum of all elements of this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0.5f, 1.5f});
 * jshell> array.sum();
 * ND: () cpu() float32
 * 2.
 * jshell> array.sum().getFloat(); // Use getFloat() to get native float
 * 2.0
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 0f, 5f}, new Shape(2, 2));
 * jshell> array.sum();
 * ND: () cpu() float32
 * 6.
 *
 * @return a scalar {@code NDArray} holding the sum of this {@code NDArray}
 */
NDArray sum();
/**
 * Returns the sum of this {@code NDArray} along the given axes, dropping those axes from the
 * result.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 0f, 5f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [0., 5.],
 * ]
 * jshell> array.sum(new int[] {0});
 * ND: (2) cpu() float32
 * [0., 6.]
 * jshell> array.sum(new int[] {1});
 * ND: (2) cpu() float32
 * [1., 5.]
 *
 * @param axes the axes along which to operate
 * @return the sum of this {@code NDArray}, with the reduced axes removed from the shape
 * @see NDArray#sum(int[], boolean)
 */
default NDArray sum(int[] axes) {
    // Squeeze the summed axes out of the result shape.
    final boolean keepDims = false;
    return sum(axes, keepDims);
}
/**
 * Returns the sum of this {@code NDArray} along given axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 0f, 5f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 *  [0., 5.],
 * ]
 * jshell> array.sum(new int[] {0}, true);
 * ND: (1, 2) cpu() float32
 * [[0., 6.],
 * ]
 * jshell> array.sum(new int[] {1}, true);
 * ND: (2, 1) cpu() float32
 * [[1.],
 *  [5.],
 * ]
 *
 * @param axes the axes along which to operate
 * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
 *     false} to squeeze the values out of the output array
 * @return the sum of this {@code NDArray} along the given axes
 */
NDArray sum(int[] axes, boolean keepDims);
/**
 * Returns the cumulative product of the elements of this {@code NDArray} along the given axis.
 *
 * For example, if this {@code NDArray} is a vector of size N, the result is also a vector of
 * size N with elements [x1, x1 * x2, x1 * x2 * x3, ...].
 *
 * @param axis the axis along which to compute the cumulative product
 * @return the cumulative product of this {@code NDArray} along {@code axis}
 */
NDArray cumProd(int axis);
/**
 * Returns the cumulative product of the elements of this {@code NDArray} along the given axis,
 * producing an output of the requested data type.
 *
 * For example, if this {@code NDArray} is a vector of size N, the result is also a vector of
 * size N with elements [x1, x1 * x2, x1 * x2 * x3, ...].
 *
 * @param axis the axis along which to compute the cumulative product
 * @param dataType the {@link DataType} of the returned {@code NDArray}
 * @return the cumulative product of this {@code NDArray} along {@code axis}
 */
NDArray cumProd(int axis, DataType dataType);
/**
 * Returns the product of all elements of this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {2f, 3f});
 * jshell> array.prod();
 * ND: () cpu() float32
 * 6.
 * jshell> array.prod().getFloat(); // Use getFloat to get native float
 * 6.0
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array.prod();
 * ND: () cpu() float32
 * 24.
 *
 * @return a scalar {@code NDArray} holding the product of this {@code NDArray}
 */
NDArray prod();
/**
 * Returns the product of this {@code NDArray} elements over the given axes, dropping those axes
 * from the result.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 *  [3., 4.],
 * ]
 * jshell> array.prod(new int[] {0});
 * ND: (2) cpu() float32
 * [3., 8.]
 * jshell> array.prod(new int[] {1});
 * ND: (2) cpu() float32
 * [ 2., 12.]
 *
 * @param axes the axes along which to operate
 * @return the product of this {@code NDArray}, with the reduced axes removed from the shape
 * @see NDArray#prod(int[], boolean)
 */
default NDArray prod(int[] axes) {
    // Reduced axes are not retained as size-1 dimensions.
    final boolean keepDims = false;
    return prod(axes, keepDims);
}
/**
 * Returns the product of this {@code NDArray} elements over the given axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 *  [3., 4.],
 * ]
 * jshell> array.prod(new int[] {0}, true);
 * ND: (1, 2) cpu() float32
 * [[3., 8.],
 * ]
 * jshell> array.prod(new int[] {1}, true);
 * ND: (2, 1) cpu() float32
 * [[ 2.],
 *  [12.],
 * ]
 *
 * @param axes the axes along which to operate
 * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
 *     false} to squeeze the values out of the output array
 * @return the product of this {@code NDArray} along the given axes
 */
NDArray prod(int[] axes, boolean keepDims);
/**
 * Returns the average of all elements of this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {2f, 3f});
 * jshell> array.mean();
 * ND: () cpu() float32
 * 2.5
 * jshell> array.mean().getFloat(); // Use getFloat() to get native float
 * 2.5
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array.mean();
 * ND: () cpu() float32
 * 2.5
 *
 * @return a scalar {@code NDArray} holding the average of this {@code NDArray}
 */
NDArray mean();
/**
 * Returns the average of this {@code NDArray} along the given axes, dropping those axes from
 * the result.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 *  [3., 4.],
 * ]
 * jshell> array.mean(new int[] {0});
 * ND: (2) cpu() float32
 * [2., 3.]
 * jshell> array.mean(new int[] {1});
 * ND: (2) cpu() float32
 * [1.5, 3.5]
 *
 * @param axes the axes along which to operate
 * @return the average of this {@code NDArray}, with the reduced axes removed from the shape
 * @see NDArray#mean(int[], boolean)
 */
default NDArray mean(int[] axes) {
    // Averaged axes are squeezed out of the result.
    final boolean keepDims = false;
    return mean(axes, keepDims);
}
/**
 * Returns the average of this {@code NDArray} along given axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 *  [3., 4.],
 * ]
 * jshell> array.mean(new int[] {0}, true);
 * ND: (1, 2) cpu() float32
 * [[2., 3.],
 * ]
 * jshell> array.mean(new int[] {1}, true);
 * ND: (2, 1) cpu() float32
 * [[1.5],
 *  [3.5],
 * ]
 *
 * @param axes the axes along which to operate
 * @param keepDims {@code true} to keep the specified axes as size 1 in the output array, {@code
 *     false} to squeeze the values out of the output array
 * @return the average of this {@code NDArray} along the given axes
 */
NDArray mean(int[] axes, boolean keepDims);
/**
 * Performs L2 normalization of this {@code NDArray} over dimension 1.
 *
 * This is shorthand for {@code normalize(2, 1, 1e-12)}: exponent 2, dimension 1, and an eps of
 * 1e-12 to avoid division by zero.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[1., 2., 3.],
 *  [4., 5., 6.],
 * ]
 * jshell> array.normalize();
 * ND: (2, 3) cpu() float32
 * [[0.2673, 0.5345, 0.8018],
 *  [0.4558, 0.5698, 0.6838],
 * ]
 *
 * @return the normalized {@code NDArray}
 * @see NDArray#normalize(double, long, double)
 */
default NDArray normalize() {
    return normalize(2, 1, 1e-12);
}
/**
 * Performs Lp normalization of this {@code NDArray} over the specified dimension.
 *
 * This is shorthand for {@code normalize(exponent, dim, 1e-12)}, using an eps of 1e-12 to
 * avoid division by zero.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[1., 2., 3.],
 *  [4., 5., 6.],
 * ]
 * jshell> array.normalize(2, 1);
 * ND: (2, 3) cpu() float32
 * [[0.2673, 0.5345, 0.8018],
 *  [0.4558, 0.5698, 0.6838],
 * ]
 *
 * @param exponent the exponent value in the norm formulation
 * @param dim the dimension to reduce
 * @return the normalized {@code NDArray}
 * @see NDArray#normalize(double, long, double)
 */
default NDArray normalize(double exponent, long dim) {
    return normalize(exponent, dim, 1e-12);
}
/**
 * Performs Lp normalization of this {@code NDArray} over the specified dimension.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1, 2, 3, 4, 5, 6}, new Shape(2, 3));
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[1., 2., 3.],
 *  [4., 5., 6.],
 * ]
 * jshell> array.normalize(2, 1, 1e-12);
 * ND: (2, 3) cpu() float32
 * [[0.2673, 0.5345, 0.8018],
 *  [0.4558, 0.5698, 0.6838],
 * ]
 *
 * @param exponent the exponent value in the norm formulation
 * @param dim the dimension to reduce
 * @param eps the small value to avoid division by zero
 * @return the normalized {@code NDArray}
 */
NDArray normalize(double exponent, long dim, double eps);
/**
 * Rotates this {@code NDArray} by 90 degrees in the plane specified by axes.
 *
 * The rotation direction is from the first axis towards the second axis.
 *
 * @param times the number of times the array is rotated by 90 degrees
 * @param axes the two axes defining the plane of rotation; they must be different
 * @return the rotated {@code NDArray}
 */
NDArray rotate90(int times, int[] axes);
/**
 * Returns the sum along the main diagonal of this {@code NDArray}.
 *
 * For a 2-D {@code NDArray} this is the sum of its diagonal. For higher-dimensional input, the
 * 2-D sub-arrays spanned by axes 0 and 1 are traced, and the {@link Shape} of the result equals
 * this {@code NDArray}'s shape with those two axes removed. Equivalent to {@code trace(0, 0,
 * 1)}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.eye(3);
 * jshell> array;
 * ND: (3, 3) cpu() float32
 * [[1., 0., 0.],
 *  [0., 1., 0.],
 *  [0., 0., 1.],
 * ]
 * jshell> array.trace();
 * ND: () cpu() float32
 * 3.
 * jshell> NDArray array = manager.arange(8f).reshape(2, 2, 2);
 * jshell> array;
 * ND: (2, 2, 2) cpu() float32
 * [[[0., 1.],
 *   [2., 3.],
 *  ],
 *  [[4., 5.],
 *   [6., 7.],
 *  ],
 * ]
 * jshell> array.trace();
 * ND: (2) cpu() float32
 * [6., 8.]
 *
 * @return the sum along diagonals of this {@code NDArray}
 */
default NDArray trace() {
    // Main diagonal (no offset) of the plane formed by the first two axes.
    final int offset = 0;
    final int firstAxis = 0;
    final int secondAxis = 1;
    return trace(offset, firstAxis, secondAxis);
}
/**
 * Returns the sum along a diagonal of this {@code NDArray} at the given offset.
 *
 * For a 2-D {@code NDArray} this is the sum of the elements a[i, i + offset] over all i. For
 * higher-dimensional input, the 2-D sub-arrays spanned by axes 0 and 1 are traced, and the
 * {@link Shape} of the result equals this {@code NDArray}'s shape with those two axes removed.
 * Equivalent to {@code trace(offset, 0, 1)}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.eye(3);
 * jshell> array;
 * ND: (3, 3) cpu() float32
 * [[1., 0., 0.],
 *  [0., 1., 0.],
 *  [0., 0., 1.],
 * ]
 * jshell> array.trace(1);
 * ND: () cpu() float32
 * 0.
 * jshell> NDArray array = manager.arange(8f).reshape(2, 2, 2);
 * jshell> array;
 * ND: (2, 2, 2) cpu() float32
 * [[[0., 1.],
 *   [2., 3.],
 *  ],
 *  [[4., 5.],
 *   [6., 7.],
 *  ],
 * ]
 * jshell> array.trace(1);
 * ND: (2) cpu() float32
 * [2., 3.]
 *
 * @param offset offset of the diagonal from the main diagonal; may be positive or negative
 * @return the sum along diagonals of this {@code NDArray}
 */
default NDArray trace(int offset) {
    // Trace the plane formed by the first two axes.
    final int firstAxis = 0;
    final int secondAxis = 1;
    return trace(offset, firstAxis, secondAxis);
}
/**
 * Returns the sum along diagonals of this {@code NDArray}.
 *
 * If this {@code NDArray} is 2-D, the sum along its diagonal with the given offset is
 * returned, i.e., the sum of elements a[i,i+offset] for all i. If this {@code NDArray} has more
 * than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D
 * sub-arrays whose traces are returned. The {@link Shape} of the resulting array is the same as
 * this {@code NDArray} with axis1 and axis2 removed.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(8f).reshape(2, 2, 2);
 * jshell> array;
 * ND: (2, 2, 2) cpu() float32
 * [[[0., 1.],
 *   [2., 3.],
 *  ],
 *  [[4., 5.],
 *   [6., 7.],
 *  ],
 * ]
 * jshell> array.trace(1,1,2);
 * ND: (2) cpu() float32
 * [1., 5.]
 *
 * @param offset offset of the diagonal from the main diagonal; may be positive or negative
 * @param axis1 the axis used as the first axis of the 2-D sub-arrays from which the diagonals
 *     are taken
 * @param axis2 the axis used as the second axis of the 2-D sub-arrays from which the diagonals
 *     are taken
 * @return the sum along diagonals of this {@code NDArray}
 */
NDArray trace(int offset, int axis1, int axis2);
////////////////////////////////////////
// Operations: Shapes and Arrays Manipulation
////////////////////////////////////////
/**
 * Splits this {@code NDArray} into {@code sections} equally sized sub-{@code NDArray}s along
 * the first axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(9f);
 * jshell> array.split(3).forEach(System.out::println);
 * ND: (3) cpu() float32
 * [0., 1., 2.]
 *
 * ND: (3) cpu() float32
 * [3., 4., 5.]
 *
 * ND: (3) cpu() float32
 * [6., 7., 8.]
 *
 * @param sections the number of equal parts this {@code NDArray} is divided into along axis 0
 * @return an {@link NDList} of {@code sections} {@code NDArray}s, each with this {@code
 *     NDArray}'s shape except that the size of axis 0 is divided by {@code sections}
 * @throws IllegalArgumentException if {@code sections} does not equally divide axis 0
 * @see NDArray#split(long, int)
 */
default NDList split(long sections) {
    return split(sections, 0);
}
/**
 * Splits this {@code NDArray} into multiple sub-{@code NDArray}s at the given indices along
 * the first axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(8f);
 * jshell> array.split(new long[] {3, 5, 6}).forEach(System.out::println);
 * ND: (3) cpu() float32
 * [0., 1., 2.]
 *
 * ND: (2) cpu() float32
 * [3., 4.]
 *
 * ND: (1) cpu() float32
 * [5.]
 *
 * ND: (2) cpu() float32
 * [6., 7.]
 *
 * @param indices the positions along axis 0 at which this {@code NDArray} is split. If an
 *     index exceeds the dimension of this {@code NDArray} along that axis, an empty sub-{@link
 *     NDArray} is returned correspondingly.
 * @return an {@link NDList} of the sub-{@code NDArray}s obtained by cutting this {@code
 *     NDArray} along axis 0 at the given indices; each keeps the original number of dimensions
 * @see NDArray#split(long[], int)
 */
default NDList split(long[] indices) {
    return split(indices, 0);
}
/**
 * Splits this {@code NDArray} into {@code sections} equally sized sub-{@code NDArray}s along
 * the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(18f).reshape(2, 9);
 * jshell> array;
 * ND: (2, 9) cpu() float32
 * [[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
 *  [ 9., 10., 11., 12., 13., 14., 15., 16., 17.],
 * ]
 * jshell> array.split(3, 1).forEach(System.out::println);
 * ND: (2, 3) cpu() float32
 * [[ 0.,  1.,  2.],
 *  [ 9., 10., 11.],
 * ]
 *
 * ND: (2, 3) cpu() float32
 * [[ 3.,  4.,  5.],
 *  [12., 13., 14.],
 * ]
 *
 * ND: (2, 3) cpu() float32
 * [[ 6.,  7.,  8.],
 *  [15., 16., 17.],
 * ]
 *
 * @param sections the number of equal parts this {@code NDArray} is divided into along {@code
 *     axis}; must be positive
 * @param axis the axis to split along; negative values count from the last axis
 * @return an {@link NDList} of {@code sections} {@code NDArray}s, each with this {@code
 *     NDArray}'s shape except that the size of {@code axis} is divided by {@code sections}
 * @throws IllegalArgumentException if {@code axis} is out of range, {@code sections} is not
 *     positive, or {@code sections} does not equally divide the given axis
 */
default NDList split(long sections, int axis) {
    long[] dims = getShape().getShape();
    // Normalize negative axes so callers can use the same -1 convention as the fft methods.
    int dim = axis < 0 ? axis + dims.length : axis;
    if (dim < 0 || dim >= dims.length) {
        throw new IllegalArgumentException(
                "axis " + axis + " is out of range for an array with " + dims.length + " dims");
    }
    // Guard before the modulo: sections == 0 would otherwise raise ArithmeticException and a
    // negative value would silently produce an empty index list.
    if (sections <= 0) {
        throw new IllegalArgumentException("sections must be positive, but got " + sections);
    }
    long axisSize = dims[dim];
    if (axisSize % sections != 0) {
        throw new IllegalArgumentException("array split does not result in an equal division");
    }
    long sectionSize = axisSize / sections;
    // Start offsets of each section: 0, sectionSize, 2 * sectionSize, ...
    long[] indices = LongStream.range(0, sections).map(i -> i * sectionSize).toArray();
    return split(indices, dim);
}
/**
 * Splits this {@code NDArray} into multiple sub-{@code NDArray}s at the given indices along
 * the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(18f).reshape(2, 9);
 * jshell> array;
 * ND: (2, 9) cpu() float32
 * [[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
 *  [ 9., 10., 11., 12., 13., 14., 15., 16., 17.],
 * ]
 * jshell> array.split(new long[] {2,4,5}, 1).forEach(System.out::println);
 * ND: (2, 2) cpu() float32
 * [[ 0.,  1.],
 *  [ 9., 10.],
 * ]
 *
 * ND: (2, 2) cpu() float32
 * [[ 2.,  3.],
 *  [11., 12.],
 * ]
 *
 * ND: (2, 1) cpu() float32
 * [[ 4.],
 *  [13.],
 * ]
 *
 * ND: (2, 4) cpu() float32
 * [[ 5.,  6.,  7.,  8.],
 *  [14., 15., 16., 17.],
 * ]
 *
 * @param indices the positions along {@code axis} at which this {@code NDArray} is split. If
 *     an index exceeds the dimension of this {@code NDArray} along axis, an empty sub-array is
 *     returned correspondingly
 * @param axis the axis to split along
 * @return an {@link NDList} of the sub-{@code NDArray}s obtained by cutting this {@code
 *     NDArray} along {@code axis} at the given indices; each keeps the original number of
 *     dimensions
 */
NDList split(long[] indices, int axis);
/**
 * Flattens this {@code NDArray} into a 1-D {@code NDArray} in row-major order.
 *
 * To flatten in column-major order, first transpose this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[]{1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array.flatten();
 * ND: (4) cpu() float32
 * [1., 2., 3., 4.]
 *
 * @return a 1-D {@code NDArray} of equal size
 */
NDArray flatten();
/**
 * Partially flattens this {@code NDArray} by merging the dimensions from {@code startDim} to
 * {@code endDim} (both inclusive) into a single dimension.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[]{1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f}, new Shape(2, 2, 2));
 * jshell> array.flatten(0, 1);
 * ND: (4, 2) cpu() float32
 * [[1., 2.],
 *  [3., 4.],
 *  [5., 6.],
 *  [7., 8.],
 * ]
 *
 * @param startDim the first dim to flatten, inclusive
 * @param endDim the last dim to flatten, inclusive
 * @return a partially flattened {@code NDArray}
 */
NDArray flatten(int startDim, int endDim);
/**
 * Computes the one-dimensional discrete Fourier Transform along the last axis.
 *
 * @param length length of the transformed axis of the output
 * @return the truncated or zero-padded input, transformed along the last axis
 * @see NDArray#fft(long, long)
 */
default NDArray fft(long length) {
    // -1 selects the last axis.
    final long lastAxis = -1;
    return fft(length, lastAxis);
}
/**
 * Computes the one-dimensional discrete Fourier Transform.
 *
 * @param length length of the transformed axis of the output
 * @param axis axis over which to compute the FFT
 * @return the truncated or zero-padded input, transformed along the axis indicated by {@code
 *     axis}
 */
NDArray fft(long length, long axis);
/**
 * Computes the one-dimensional inverse discrete Fourier transform.
 *
 * @param length length of the transformed axis of the output
 * @param axis axis over which to compute the IFFT
 * @return the truncated or zero-padded input, transformed along the axis indicated by {@code
 *     axis}
 */
NDArray ifft(long length, long axis);
/**
 * Computes the one-dimensional inverse discrete Fourier transform along the last axis.
 *
 * @param length length of the transformed axis of the output
 * @return the truncated or zero-padded input, transformed along the last axis
 * @see NDArray#ifft(long, long)
 */
default NDArray ifft(long length) {
    // -1 selects the last axis.
    final long lastAxis = -1;
    return ifft(length, lastAxis);
}
/**
 * Computes the one-dimensional Fourier transform of real-valued input along the last axis.
 *
 * @param length length of the transformed axis of the output
 * @return the truncated or zero-padded input, transformed along the last axis
 * @see NDArray#rfft(long, long)
 */
default NDArray rfft(long length) {
    // -1 selects the last axis.
    final long lastAxis = -1;
    return rfft(length, lastAxis);
}
/**
 * Computes the one-dimensional Fourier transform of real-valued input.
 *
 * @param length length of the transformed axis of the output
 * @param axis axis over which to compute the FFT
 * @return the truncated or zero-padded input, transformed along the axis indicated by {@code
 *     axis}
 */
NDArray rfft(long length, long axis);
/**
 * Computes the one-dimensional inverse of the Fourier transform of real-valued input.
 *
 * @param length length of the transformed axis of the output
 * @param axis axis over which to compute the IRFFT
 * @return the truncated or zero-padded input, transformed along the axis indicated by {@code
 *     axis}
 */
NDArray irfft(long length, long axis);
/**
 * Computes the one-dimensional inverse of the Fourier transform of real-valued input along the
 * last axis.
 *
 * @param length length of the transformed axis of the output
 * @return the truncated or zero-padded input, transformed along the last axis
 * @see NDArray#irfft(long, long)
 */
default NDArray irfft(long length) {
    // -1 selects the last axis.
    final long lastAxis = -1;
    return irfft(length, lastAxis);
}
/**
 * Computes the Short Time Fourier Transform (STFT) without normalizing the result.
 *
 * Equivalent to {@code stft(nFft, hopLength, center, window, false, returnComplex)}.
 *
 * @param nFft size of Fourier transform
 * @param hopLength the distance between neighboring sliding window frames; a common choice is
 *     floor(nFft / 4)
 * @param center whether to pad input on both sides
 * @param window the window to apply, e.g. a Hanning window
 * @param returnComplex whether to return a complex tensor, or a real tensor with an extra last
 *     dimension for the real and imaginary components
 * @return an {@code NDArray} containing the STFT result
 */
default NDArray stft(
        long nFft, long hopLength, boolean center, NDArray window, boolean returnComplex) {
    // The short overload never normalizes.
    final boolean normalized = false;
    return stft(nFft, hopLength, center, window, normalized, returnComplex);
}
/**
 * Computes the Short Time Fourier Transform (STFT).
 *
 * @param nFft size of Fourier transform
 * @param hopLength the distance between neighboring sliding window frames; a common choice is
 *     floor(nFft / 4)
 * @param center whether to pad input on both sides
 * @param window the window to apply, e.g. a Hanning window
 * @param normalize controls whether to return the normalized STFT results
 * @param returnComplex whether to return a complex tensor, or a real tensor with an extra last
 *     dimension for the real and imaginary components
 * @return an {@code NDArray} containing the STFT result
 */
NDArray stft(
long nFft,
long hopLength,
boolean center,
NDArray window,
boolean normalize,
boolean returnComplex);
/**
 * Computes the two-dimensional Discrete Fourier Transform.
 *
 * @param sizes sizes of the transformed axes of the output; the input is zero-padded or
 *     trimmed to these sizes
 * @param axes axes over which to compute the 2D-FFT
 * @return the truncated or zero-padded input, transformed along the given axes
 */
NDArray fft2(long[] sizes, long[] axes);
/**
 * Computes the two-dimensional Discrete Fourier Transform over the last two axes.
 *
 * @param sizes sizes of the transformed axes of the output; the input is zero-padded or
 *     trimmed to these sizes
 * @return the truncated or zero-padded input, transformed along the last two axes
 * @see NDArray#fft2(long[], long[])
 */
default NDArray fft2(long[] sizes) {
    // Negative indices address the second-to-last and last axes.
    long[] lastTwoAxes = {-2, -1};
    return fft2(sizes, lastTwoAxes);
}
/**
 * Computes the two-dimensional inverse Discrete Fourier Transform.
 *
 * @param sizes sizes of the transformed axes of the output; the input is zero-padded or
 *     trimmed to these sizes
 * @param axes axes over which to compute the 2D inverse FFT
 * @return the truncated or zero-padded input, transformed along the given axes
 */
NDArray ifft2(long[] sizes, long[] axes);
/**
 * Computes the two-dimensional inverse Discrete Fourier Transform over the last two axes.
 *
 * @param sizes sizes of the transformed axes of the output; the input is zero-padded or
 *     trimmed to these sizes
 * @return the truncated or zero-padded input, transformed along the last two axes
 * @see NDArray#ifft2(long[], long[])
 */
default NDArray ifft2(long[] sizes) {
    // Negative indices address the second-to-last and last axes.
    long[] lastTwoAxes = {-2, -1};
    return ifft2(sizes, lastTwoAxes);
}
/**
 * Pads this {@code NDArray} with the given {@link Shape}.
 *
 * Examples
 *
 * NDArray array = manager.zeros(new Shape(3, 3, 4, 2));
 * NDArray padded = array.pad(new Shape(1, 1), 0); // pad last dim by 1 on each side
 * padded.getShape() => (3, 3, 4, 4)
 *
 * @param padding the padding amounts, given as pairs; must have an even number of elements.
 *     As the example shows, the first pair applies to the last dimension.
 * @param value the value to pad with
 * @return a padded {@code NDArray}
 * @throws IllegalArgumentException thrown if the given {@code padding} does not match the size
 *     of the current shape
 */
NDArray pad(Shape padding, double value);
/**
 * Reshapes this {@code NDArray} to the shape described by the given dimensions.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(6f);
 * jshell> array;
 * ND: (6) cpu() float32
 * [0., 1., 2., 3., 4., 5.]
 * jshell> array.reshape(2, 3);
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 *  [3., 4., 5.],
 * ]
 *
 * @param newShape the dimensions to reshape into; their product must equal the current size
 * @return a reshaped {@code NDArray}
 * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
 *     the current shape
 * @see NDArray#reshape(Shape)
 */
default NDArray reshape(long... newShape) {
    Shape target = new Shape(newShape);
    return reshape(target);
}
/**
 * Reshapes this {@code NDArray} to the given {@link Shape}.
 *
 * You can reshape it to match another NDArray by calling {@code a.reshape(b.getShape())}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(6f);
 * jshell> array;
 * ND: (6) cpu() float32
 * [0., 1., 2., 3., 4., 5.]
 * jshell> array.reshape(new Shape(2, 3));
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 *  [3., 4., 5.],
 * ]
 *
 * @param shape the {@link Shape} to reshape into. Must have equal size to the current shape
 * @return a reshaped {@code NDArray}
 * @throws IllegalArgumentException thrown if the given {@link Shape} does not match the size of
 *     the current shape
 */
NDArray reshape(Shape shape);
/**
 * Expands the {@link Shape} of a {@code NDArray}.
 *
 * Inserts a new axis of size 1 that will appear at the axis position in the expanded {@code
 * NDArray} shape.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f});
 * jshell> array;
 * ND: (2) cpu() float32
 * [1., 2.]
 * jshell> array.expandDims(0);
 * ND: (1, 2) cpu() float32
 * [[1., 2.],
 * ]
 * jshell> array.expandDims(1);
 * ND: (2, 1) cpu() float32
 * [[1.],
 *  [2.],
 * ]
 *
 * @param axis the position in the expanded axes where the new axis is placed
 * @return the result {@code NDArray}. The number of dimensions is one greater than that of this
 *     {@code NDArray}
 */
NDArray expandDims(int axis);
/**
 * Removes every dimension of size 1 from this {@code NDArray}'s {@link Shape}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
 * jshell> array;
 * ND: (1, 3, 1) cpu() float32
 * [[[0.],
 *   [1.],
 *   [2.],
 *  ],
 * ]
 * jshell> array.squeeze();
 * ND: (3) cpu() float32
 * [0., 1., 2.]
 *
 * @return a result {@code NDArray} of same size and data without singleton dimensions
 * @see NDArray#squeeze(int[])
 */
default NDArray squeeze() {
    long[] dims = getShape().getShape();
    // Collect, in ascending order, the indices of all size-1 axes.
    int[] singletonAxes = new int[dims.length];
    int count = 0;
    for (int axis = 0; axis < dims.length; ++axis) {
        if (dims[axis] == 1) {
            singletonAxes[count++] = axis;
        }
    }
    return squeeze(Arrays.copyOf(singletonAxes, count));
}
/**
 * Removes the size-1 dimension at the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
 * jshell> array;
 * ND: (1, 3, 1) cpu() float32
 * [[[0.],
 *   [1.],
 *   [2.],
 *  ],
 * ]
 * jshell> array.squeeze(0);
 * ND: (3, 1) cpu() float32
 * [[0.],
 *  [1.],
 *  [2.],
 * ]
 * jshell> array.squeeze(2);
 * ND: (1, 3) cpu() float32
 * [[0., 1., 2.],
 * ]
 *
 * @param axis the axis at which to remove the singleton dimension
 * @return a result {@code NDArray} of same size and data without the given axis in its shape
 * @throws IllegalArgumentException thrown if the given axis is not a singleton dimension
 * @see NDArray#squeeze(int[])
 */
default NDArray squeeze(int axis) {
    int[] axes = {axis};
    return squeeze(axes);
}
/**
 * Removes singleton dimensions at the given axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f}, new Shape(1, 3, 1));
 * jshell> array;
 * ND: (1, 3, 1) cpu() float32
 * [[[0.],
 *   [1.],
 *   [2.],
 *  ],
 * ]
 * jshell> array.squeeze(new int[] {0, 2});
 * ND: (3) cpu() float32
 * [0., 1., 2.]
 *
 * @param axes the axes at which to remove the singleton dimensions
 * @return a result {@code NDArray} of same size and data without the given axes in its shape
 * @throws IllegalArgumentException thrown if any of the given axes are not a singleton
 *     dimension
 */
NDArray squeeze(int[] axes);
/**
 * Returns the unique elements of the input tensor.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {3f, 1f, 2f, 3f, 1f, 2f, 1f, 3f, 2f}, new Shape(3, 3));
 * jshell> array;
 * ND: (3, 3) cpu() float32
 * [[3., 1., 2.],
 *  [3., 1., 2.],
 *  [1., 3., 2.],
 * ]
 * jshell> NDList output = array.unique(0, true, true, true);
 * jshell> output.get(0); // unique rows, sorted
 * ND: (2, 3) cpu() float32
 * [[1., 3., 2.],
 *  [3., 1., 2.],
 * ]
 * jshell> output.get(1); // inverse indices
 * ND: (3) cpu() int64
 * [ 1,  1,  0]
 * jshell> output.get(2); // counts
 * ND: (2) cpu() int64
 * [ 1,  2]
 *
 * @param dim the dimension to apply unique along, or {@code null} to apply it over the
 *     flattened input
 * @param sorted whether to sort the unique elements in ascending order before returning as
 *     output
 * @param returnInverse return the indices which, fed into the output unique array as indices,
 *     will recover the original array
 * @param returnCounts return the counts for each unique element
 * @return An {@code NDList} containing: output (Tensor): the output list of unique elements or
 *     low-rank tensors. inverse_indices (Tensor): (optional) if return_inverse is True, there
 *     will be an additional returned tensor (same shape as input) representing the indices for
 *     where elements in the original input map to in the output; otherwise, this function will
 *     only return a single tensor. counts (Tensor): (optional) if return_counts is True, there
 *     will be an additional returned tensor (same shape as output or output.size(dim), if dim
 *     was specified) representing the number of occurrences for each unique value or tensor.
 */
NDList unique(Integer dim, boolean sorted, boolean returnInverse, boolean returnCounts);
/**
 * Returns the unique elements of the input tensor, treating it as flattened.
 *
 * This is shorthand for {@link #unique(Integer, boolean, boolean, boolean)} with a
 * {@code null} dimension.
 *
 * @param sorted whether to sort the unique elements in ascending order before returning them
 * @param returnInverse whether to also return the indices which, fed into the output unique
 *     array as indices, recover the original (flattened) array
 * @param returnCounts whether to also return the number of occurrences of each unique element
 * @return an {@code NDList} containing, in order: the unique elements; then, only if
 *     {@code returnInverse} is {@code true}, the inverse indices; then, only if
 *     {@code returnCounts} is {@code true}, the counts of each unique element
 */
default NDList unique(boolean sorted, boolean returnInverse, boolean returnCounts) {
    // A null dimension selects the flattened variant of the general overload.
    final Integer flattened = null;
    return unique(flattened, sorted, returnInverse, returnCounts);
}
/**
 * Returns the sorted unique elements of the input tensor, treating it as flattened.
 *
 * This is shorthand for {@code unique(null, true, false, false)}: neither the inverse indices
 * nor the counts are requested.
 *
 * @return an {@code NDList} containing a single {@code NDArray}: the sorted unique elements of
 *     the flattened input
 */
default NDList unique() {
    return unique(null, true, false, false);
}
/**
 * Stacks this {@code NDArray} with another one along a new first axis.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new float[] {0f, 1f});
 * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
 * jshell> array1.stack(array2)
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 * [2., 3.],
 * ]
 *
 * @param array the input {@code NDArray}, which must have the same {@link Shape} as this
 *     {@code NDArray}
 * @return the stacked {@code NDArray}, which has one more dimension than the inputs
 * @see NDArray#stack(NDArray, int)
 */
default NDArray stack(NDArray array) {
    final int firstAxis = 0;
    return stack(array, firstAxis);
}
/**
 * Stacks this {@code NDArray} with another one along a new axis.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new float[] {0f, 1f});
 * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
 * jshell> array1.stack(array2, 0);
 * ND: (2, 2) cpu() float32
 * [[0., 1.],
 * [2., 3.],
 * ]
 * jshell> array1.stack(array2, 1);
 * ND: (2, 2) cpu() float32
 * [[0., 2.],
 * [1., 3.],
 * ]
 *
 * @param array the input {@code NDArray}, which must have the same {@link Shape} as this
 *     {@code NDArray}
 * @param axis the axis in the result {@code NDArray} along which the inputs are stacked
 * @return the stacked {@code NDArray}, which has one more dimension than the inputs
 */
default NDArray stack(NDArray array, int axis) {
    // The engine-specific implementation stacks "this" with every array in the list.
    NDList others = new NDList(array);
    return getNDArrayInternal().stack(others, axis);
}
/**
 * Concatenates this {@code NDArray} with another one along the first (existing) axis.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new float[] {0f, 1f});
 * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
 * jshell> array1.concat(array2)
 * ND: (4) cpu() float32
 * [0., 1., 2., 3.]
 *
 * @param array an {@code NDArray} with the same {@link Shape} as this {@code NDArray}, except
 *     along axis 0
 * @return the concatenated {@code NDArray}
 * @see NDArray#concat(NDArray, int)
 */
default NDArray concat(NDArray array) {
    final int firstAxis = 0;
    return concat(array, firstAxis);
}
/**
 * Concatenates this {@code NDArray} with another one along an existing axis.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new float[] {0f, 1f});
 * jshell> NDArray array2 = manager.create(new float[] {2f, 3f});
 * jshell> array1.concat(array2, 0);
 * ND: (4) cpu() float32
 * [0., 1., 2., 3.]
 *
 * @param array an {@code NDArray} with the same {@link Shape} as this {@code NDArray}, except
 *     along {@code axis}
 * @param axis the axis along which this {@code NDArray} will be joined
 * @return the concatenated {@code NDArray}
 */
default NDArray concat(NDArray array, int axis) {
    // The engine-specific implementation concatenates "this" with every array in the list.
    NDList others = new NDList(array);
    return getNDArrayInternal().concat(others, axis);
}
////////////////////////////////////////
// Operations: Logical Op
////////////////////////////////////////
/**
 * Returns the truth value of this {@code NDArray} AND the other {@code NDArray} element-wise.
 *
 * The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new boolean[] {true});
 * jshell> NDArray array2 = manager.create(new boolean[] {false});
 * jshell> array1.logicalAnd(array2);
 * ND: (1) cpu() boolean
 * [false]
 * jshell> array1 = manager.create(new boolean[] {true, false});
 * jshell> array2 = manager.create(new boolean[] {false, false});
 * jshell> array1.logicalAnd(array2);
 * ND: (2) cpu() boolean
 * [false, false]
 *
 * Combining comparison results:
 *
 * jshell> NDArray array = manager.arange(5f);
 * jshell> array.gt(1).logicalAnd(array.lt(4));
 * ND: (5) cpu() boolean
 * [false, false, true, true, false]
 *
 * @param other the other {@code NDArray} to operate on
 * @return the boolean {@code NDArray} of the logical AND operation applied to the elements of
 *     this {@code NDArray} and the other {@code NDArray}
 */
NDArray logicalAnd(NDArray other);
/**
 * Computes the truth value of this {@code NDArray} OR the other {@code NDArray} element-wise.
 *
 * The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new boolean[] {true});
 * jshell> NDArray array2 = manager.create(new boolean[] {false});
 * jshell> array1.logicalOr(array2);
 * ND: (1) cpu() boolean
 * [ true]
 * jshell> array1 = manager.create(new boolean[] {true, false});
 * jshell> array2 = manager.create(new boolean[] {false, false});
 * jshell> array1.logicalOr(array2);
 * ND: (2) cpu() boolean
 * [ true, false]
 *
 * Combining comparison results:
 *
 * jshell> NDArray array = manager.arange(5f);
 * jshell> array.lt(1).logicalOr(array.gt(3));
 * ND: (5) cpu() boolean
 * [ true, false, false, false, true]
 *
 * @param other the other {@code NDArray} to operate on
 * @return the boolean {@code NDArray} of the logical OR operation applied to the elements of
 *     this {@code NDArray} and the other {@code NDArray}
 */
NDArray logicalOr(NDArray other);
/**
 * Computes the truth value of this {@code NDArray} XOR the other {@code NDArray} element-wise.
 *
 * The shapes of this {@code NDArray} and the other {@code NDArray} must be broadcastable.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new boolean[] {true});
 * jshell> NDArray array2 = manager.create(new boolean[] {false});
 * jshell> array1.logicalXor(array2);
 * ND: (1) cpu() boolean
 * [ true]
 * jshell> array1 = manager.create(new boolean[] {true, false});
 * jshell> array2 = manager.create(new boolean[] {false, false});
 * jshell> array1.logicalXor(array2);
 * ND: (2) cpu() boolean
 * [ true, false]
 *
 * Combining comparison results:
 *
 * jshell> NDArray array = manager.arange(5f);
 * jshell> array.lt(1).logicalXor(array.gt(3));
 * ND: (5) cpu() boolean
 * [ true, false, false, false, true]
 *
 * @param other the other {@code NDArray} to operate on
 * @return the boolean {@code NDArray} of the logical XOR operation applied to the elements of
 *     this {@code NDArray} and the other {@code NDArray}
 */
NDArray logicalXor(NDArray other);
/**
 * Computes the truth value of NOT this {@code NDArray} element-wise.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new boolean[] {true});
 * jshell> array.logicalNot();
 * ND: (1) cpu() boolean
 * [ false]
 *
 * Negating a comparison result:
 *
 * jshell> NDArray array = manager.arange(5f);
 * jshell> array.lt(1).logicalNot();
 * ND: (5) cpu() boolean
 * [false, true, true, true, true]
 *
 * @return the boolean {@code NDArray} holding the element-wise negation of this
 *     {@code NDArray}
 */
NDArray logicalNot();
////////////////////////////////////////
// Operations: Other
////////////////////////////////////////
/**
 * Returns the indices that would sort this {@code NDArray} in ascending order along its last
 * axis.
 *
 * This performs an indirect sort; the returned {@code NDArray} of indices has the same
 * {@link Shape} as this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {3f, 1f, 2f});
 * jshell> array.argSort();
 * ND: (3) cpu() int64
 * [ 1, 2, 0]
 *
 * jshell> array = manager.create(new float[] {0f, 3f, 2f, 2f}, new Shape(2, 2));
 * jshell> array.argSort();
 * ND: (2, 2) cpu() int64
 * [[ 0, 1],
 * [ 0, 1],
 * ]
 *
 * @return an {@code NDArray} of indices into this {@code NDArray} along the last axis; its
 *     DataType is always {@link DataType#INT64}
 * @see NDArray#argSort(int, boolean)
 */
default NDArray argSort() {
    final int lastAxis = -1;
    return argSort(lastAxis, true);
}
/**
 * Returns the indices that would sort this {@code NDArray} in ascending order along the given
 * axis.
 *
 * This performs an indirect sort; the returned {@code NDArray} of indices has the same
 * {@link Shape} as this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 3f, 2f, 2f}, new Shape(2, 2));
 * jshell> array.argSort(0);
 * ND: (2, 2) cpu() int64
 * [[ 0, 1],
 * [ 1, 0],
 * ]
 * jshell> array.argSort(1);
 * ND: (2, 2) cpu() int64
 * [[ 0, 1],
 * [ 0, 1],
 * ]
 *
 * @param axis the axis to sort along
 * @return an {@code NDArray} of indices into this {@code NDArray} along {@code axis}; its
 *     DataType is always {@link DataType#INT64}
 * @see NDArray#argSort(int, boolean)
 */
default NDArray argSort(int axis) {
    final boolean ascending = true;
    return argSort(axis, ascending);
}
/**
 * Returns the indices that would sort this {@code NDArray} along the given axis, in the given
 * order.
 *
 * This performs an indirect sort; the returned {@code NDArray} of indices has the same
 * {@link Shape} as this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 3f, 2f, 2f}, new Shape(2, 2));
 * jshell> array.argSort(0, false);
 * ND: (2, 2) cpu() int64
 * [[ 1, 0],
 * [ 0, 1],
 * ]
 *
 * @param axis the axis to sort along
 * @param ascending {@code true} to sort in ascending order, {@code false} for descending
 * @return an {@code NDArray} of indices into this {@code NDArray} along {@code axis}; its
 *     DataType is always {@link DataType#INT64}
 */
NDArray argSort(int axis, boolean ascending);
/**
 * Sorts this {@code NDArray} along its last axis.
 *
 * Note that the array is not flattened: as the example shows, each row of a 2-D array is
 * sorted independently and the {@link Shape} is preserved.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 4f, 3f, 1f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 4.],
 * [3., 1.],
 * ]
 * jshell> array.sort(); // sort along the last axis
 * ND: (2, 2) cpu() float32
 * [[1., 4.],
 * [1., 3.],
 * ]
 *
 * @return the sorted {@code NDArray}
 * @see NDArray#sort(int)
 */
NDArray sort();
/**
 * Sorts this {@code NDArray} along the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 4f, 3f, 1f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 4.],
 * [3., 1.],
 * ]
 * jshell> array.sort(0); // sort along the first axis
 * ND: (2, 2) cpu() float32
 * [[1., 1.],
 * [3., 4.],
 * ]
 *
 * @param axis the axis to sort along
 * @return the sorted {@code NDArray}, with the same {@link Shape} as this {@code NDArray}
 */
NDArray sort(int axis);
/**
 * Applies the softmax function along the given axis.
 *
 * Each slice along {@code axis} is mapped to {@code exp(x_i) / sum_j exp(x_j)}, so the values
 * in every slice are non-negative and sum to one.
 *
 * @param axis the axis along which to apply
 * @return the result {@code NDArray}
 * @see NDArray#logSoftmax(int)
 */
NDArray softmax(int axis);
/**
 * Applies the softmax function followed by a logarithm along the given axis.
 *
 * This is mathematically equivalent to calling {@link NDArray#softmax(int)} and then taking
 * the log. The single operator is faster than two separate operators and is numerically more
 * stable when computing gradients.
 *
 * @param axis the axis along which to apply
 * @return the result {@code NDArray}
 */
NDArray logSoftmax(int axis);
/**
 * Returns the cumulative sum of the elements in the flattened {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(2, 3));
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[1., 2., 3.],
 * [4., 5., 6.],
 * ]
 * jshell> array.cumSum(); // cumSum on flattened array
 * ND: (6) cpu() float32
 * [ 1., 3., 6., 10., 15., 21.]
 *
 * @return a 1-D {@code NDArray} holding the cumulative sum of the elements in the flattened
 *     {@code NDArray}
 */
NDArray cumSum();
/**
 * Returns the cumulative sum of the elements along a given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new Shape(2, 3));
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[1., 2., 3.],
 * [4., 5., 6.],
 * ]
 * jshell> array.cumSum(0);
 * ND: (2, 3) cpu() float32
 * [[1., 2., 3.],
 * [5., 7., 9.],
 * ]
 * jshell> array.cumSum(1);
 * ND: (2, 3) cpu() float32
 * [[ 1., 3., 6.],
 * [ 4., 9., 15.],
 * ]
 *
 * @param axis the axis along which the cumulative sum is computed
 * @return the cumulative sum along the specified axis, with the same {@link Shape} as this
 *     {@code NDArray}
 */
NDArray cumSum(int axis);
/**
 * Replaces the handle of this {@code NDArray} with the handle of {@code replaced}. The
 * {@code NDArray} used for the replacement will be killed.
 *
 * Use with caution: after this call, the argument {@code replaced} is no longer usable.
 *
 * @param replaced the handle provider; it is killed by this call
 */
void intern(NDArray replaced);
/**
 * Returns a boolean {@code NDArray} that is {@code true} where this {@code NDArray}'s entries
 * are infinite and {@code false} where they are not.
 *
 * @return the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
 *     entries are infinite
 */
NDArray isInfinite();
/**
 * Computes the inverse of this square {@code NDArray}, if it exists.
 *
 * @return the inverse of this square {@code NDArray}
 */
NDArray inverse();
/**
 * Returns a boolean {@code NDArray} that is {@code true} where this {@code NDArray}'s entries
 * are NaN and {@code false} where they are not.
 *
 * Infinite values are not NaN, as the example shows.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {Float.POSITIVE_INFINITY, 0, Float.NaN});
 * jshell> array.isNaN();
 * ND: (3) cpu() boolean
 * [false, false, true]
 *
 * @return the boolean {@code NDArray} with value {@code true} where this {@code NDArray}'s
 *     entries are NaN
 */
NDArray isNaN();
/**
 * Constructs an {@code NDArray} by repeating this whole {@code NDArray} the given number of
 * times.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f});
 * jshell> array.tile(2);
 * ND: (6) cpu() float32
 * [0., 1., 2., 0., 1., 2.]
 *
 * @param repeats the number of times to repeat for each dimension
 * @return an {@code NDArray} that has been tiled
 */
NDArray tile(long repeats);
/**
 * Constructs an {@code NDArray} by repeating this {@code NDArray} the given number of times
 * along the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f});
 * jshell> array.tile(1, 2);
 * ND: (1, 6) cpu() float32
 * [[0., 1., 2., 0., 1., 2.],
 * ]
 *
 * @param axis the axis to repeat along
 * @param repeats the number of times to repeat along {@code axis}
 * @return an {@code NDArray} that has been tiled
 * @throws IllegalArgumentException thrown for invalid axis
 */
NDArray tile(int axis, long repeats);
/**
 * Constructs an {@code NDArray} by repeating this {@code NDArray} the given number of times
 * along each axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f});
 * jshell> array.tile(new long[] {2, 2});
 * ND: (2, 6) cpu() float32
 * [[0., 1., 2., 0., 1., 2.],
 * [0., 1., 2., 0., 1., 2.],
 * ]
 *
 * @param repeats the number of times to repeat along each axis
 * @return an {@code NDArray} that has been tiled
 */
NDArray tile(long[] repeats);
/**
 * Constructs an {@code NDArray} by repeating this {@code NDArray} as many times as needed to
 * match the desired shape.
 *
 * If the desired {@link Shape} has fewer dimensions than this {@code NDArray}, it will tile
 * against the last axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f});
 * jshell> array.tile(new Shape(6));
 * ND: (6) cpu() float32
 * [0., 1., 2., 0., 1., 2.]
 *
 * @param desiredShape the {@link Shape} that should be converted to
 * @return an {@code NDArray} that has been tiled
 */
NDArray tile(Shape desiredShape);
/**
 * Repeats each element of this {@code NDArray} the given number of times.
 *
 * Unlike {@link NDArray#tile(long)}, which repeats the whole array, this repeats every element
 * in place.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f});
 * jshell> array.repeat(2);
 * ND: (6) cpu() float32
 * [0., 0., 1., 1., 2., 2.]
 *
 * @param repeats the number of times to repeat for each axis
 * @return an {@code NDArray} that has been repeated
 */
NDArray repeat(long repeats);
/**
 * Repeats each element of this {@code NDArray} the given number of times along the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f}, new Shape(2, 2));
 * jshell> array.repeat(1, 2);
 * ND: (2, 4) cpu() float32
 * [[0., 0., 1., 1.],
 * [2., 2., 3., 3.],
 * ]
 *
 * @param axis the axis to repeat along
 * @param repeats the number of times to repeat each element along {@code axis}
 * @return an {@code NDArray} that has been repeated
 * @throws IllegalArgumentException thrown for invalid axis
 */
NDArray repeat(int axis, long repeats);
/**
 * Repeats each element of this {@code NDArray} the given number of times along each axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f}, new Shape(2, 2));
 * jshell> array.repeat(new long[] {2, 2});
 * ND: (4, 4) cpu() float32
 * [[0., 0., 1., 1.],
 * [0., 0., 1., 1.],
 * [2., 2., 3., 3.],
 * [2., 2., 3., 3.],
 * ]
 *
 * @param repeats the number of times to repeat each element along each axis
 * @return an {@code NDArray} that has been repeated
 */
NDArray repeat(long[] repeats);
/**
 * Repeats each element of this {@code NDArray} as many times as needed to match the desired
 * shape.
 *
 * If the desired {@link Shape} has fewer dimensions than the array, it will repeat against
 * the last axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f}, new Shape(2, 2));
 * jshell> array.repeat(new Shape(4, 4));
 * ND: (4, 4) cpu() float32
 * [[0., 0., 1., 1.],
 * [0., 0., 1., 1.],
 * [2., 2., 3., 3.],
 * [2., 2., 3., 3.],
 * ]
 *
 * @param desiredShape the {@link Shape} that should be converted to
 * @return an {@code NDArray} that has been repeated
 */
NDArray repeat(Shape desiredShape);
/**
 * Computes the dot product of this {@code NDArray} and the other {@code NDArray}.
 *
 * The result depends on the dimensions of the operands:
 *
 * - If both this {@code NDArray} and the other {@code NDArray} are 1-D, it is the inner
 *   product of vectors (without complex conjugation).
 * - If both this {@code NDArray} and the other {@code NDArray} are 2-D, it is matrix
 *   multiplication.
 * - If either this {@code NDArray} or the other {@code NDArray} is 0-D (a scalar), it is
 *   equivalent to mul.
 * - If this {@code NDArray} is N-D and the other {@code NDArray} is 1-D, it is a sum product
 *   over the last axis of those.
 * - If this {@code NDArray} is N-D and the other {@code NDArray} is M-D (where M is at least
 *   2), it is a sum product over the last axis of this {@code NDArray} and the second-to-last
 *   axis of the other {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new float[] {1f, 2f, 3f});
 * jshell> NDArray array2 = manager.create(new float[] {4f, 5f, 6f});
 * jshell> array1.dot(array2); // inner product
 * ND: () cpu() float32
 * 32.
 * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array2 = manager.create(new float[] {5f, 6f, 7f, 8f}, new Shape(2, 2));
 * jshell> array1.dot(array2); // matrix multiplication
 * ND: (2, 2) cpu() float32
 * [[19., 22.],
 * [43., 50.],
 * ]
 * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array2 = manager.create(5f);
 * jshell> array1.dot(array2); // scalar operand
 * ND: (2, 2) cpu() float32
 * [[ 5., 10.],
 * [15., 20.],
 * ]
 * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell> array2 = manager.create(new float[] {1f, 2f});
 * jshell> array1.dot(array2); // N-D with 1-D
 * ND: (2) cpu() float32
 * [ 5., 11.]
 * jshell> array1 = manager.create(new float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f}, new Shape(2, 2, 2));
 * jshell> array2 = manager.create(new float[] {1f, 2f, 3f ,4f}, new Shape(2, 2));
 * jshell> array1.dot(array2); // N-D with M-D
 * ND: (2, 2, 2) cpu() float32
 * [[[ 7., 10.],
 * [15., 22.],
 * ],
 * [[23., 34.],
 * [31., 46.],
 * ],
 * ]
 *
 * @param other the other {@code NDArray} to perform dot product with
 * @return the result {@code NDArray}
 */
NDArray dot(NDArray other);
/**
 * Computes the matrix product of this {@code NDArray} and the other {@code NDArray}.
 *
 * The behavior depends on the arguments in the following way:
 *
 * - If both this {@code NDArray} and the other {@code NDArray} are 2-D, they are multiplied
 *   like conventional matrices.
 * - If either this {@code NDArray} or the other {@code NDArray} is N-D with N greater than 2,
 *   it is treated as a stack of matrices residing in the last two indexes and broadcast
 *   accordingly.
 * - If this {@code NDArray} is 1-D, it is promoted to a matrix by prepending a 1 to its
 *   dimensions. After matrix multiplication the prepended 1 is removed.
 * - If the other {@code NDArray} is 1-D, it is promoted to a matrix by appending a 1 to its
 *   dimensions. After matrix multiplication the appended 1 is removed.
 *
 * Examples
 *
 * jshell> NDArray array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
 * jshell> NDArray array2 = manager.create(new float[] {4f, 1f, 2f, 2f}, new Shape(2, 2));
 * jshell> array1.matMul(array2); // for 2-D arrays, it is the matrix product
 * ND: (2, 2) cpu() float32
 * [[4., 1.],
 * [2., 2.],
 * ]
 * jshell> array1 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
 * jshell> array2 = manager.create(new float[] {1f, 2f});
 * jshell> array1.matMul(array2); // 1-D right operand
 * ND: (2) cpu() float32
 * [1., 2.]
 * jshell> array1 = manager.create(new float[] {1f, 2f});
 * jshell> array2 = manager.create(new float[] {1f, 0f, 0f, 1f}, new Shape(2, 2));
 * jshell> array1.matMul(array2); // 1-D left operand
 * ND: (2) cpu() float32
 * [1., 2.]
 * jshell> array1 = manager.arange(2f * 2f * 4f).reshape(2, 2, 4);
 * jshell> array2 = manager.arange(2f * 2f * 4f).reshape(2, 4, 2);
 * jshell> array1.matMul(array2).get("0, 1, 1"); // stacks of matrices
 * ND: () cpu() float32
 * 98.
 *
 * @param other the other {@code NDArray} to perform matrix product with
 * @return the result {@code NDArray}
 */
NDArray matMul(NDArray other);
/**
 * Computes the batched matrix product of this {@code NDArray} and the other {@code NDArray}.
 *
 * @param other the other {@code NDArray} to perform matrix product with
 * @return the result {@code NDArray}
 * @see NDArray#matMul(NDArray)
 */
NDArray batchMatMul(NDArray other);
/**
 * Clips (limits) the values in this {@code NDArray}.
 *
 * Given an interval, values outside the interval are clipped to the interval edges. For
 * example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values
 * larger than 1 become 1.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(10f);
 * jshell> array.clip(1, 8);
 * ND: (10) cpu() float32
 * [1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]
 *
 * @param min the minimum value
 * @param max the maximum value
 * @return an {@code NDArray} with the elements of this {@code NDArray}, but where values less
 *     than {@code min} are replaced with {@code min}, and values greater than {@code max} are
 *     replaced with {@code max}
 */
NDArray clip(Number min, Number max);
/**
 * Interchanges two axes of this {@code NDArray}.
 *
 * Negative axes count from the last axis, so {@code swapAxes(-1, -2)} swaps the last two
 * axes.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f ,3f}, new Shape(1, 3));
 * jshell> array;
 * ND: (1, 3) cpu() float32
 * [[1., 2., 3.],
 * ]
 * jshell> array.swapAxes(0, 1);
 * ND: (3, 1) cpu() float32
 * [[1.],
 * [2.],
 * [3.],
 * ]
 *
 * @param axis1 the first axis; a negative value counts from the last axis
 * @param axis2 the second axis; a negative value counts from the last axis
 * @return the swapped axes {@code NDArray}
 * @throws IllegalArgumentException thrown if either axis is out of range for this
 *     {@code NDArray}
 */
default NDArray swapAxes(int axis1, int axis2) {
    int dimension = getShape().dimension();
    // Normalize negative axes so that -1 refers to the last axis, matching other axis-taking
    // methods; without this a negative axis would index the permutation array directly.
    int first = axis1 < 0 ? axis1 + dimension : axis1;
    int second = axis2 < 0 ? axis2 + dimension : axis2;
    if (first < 0 || first >= dimension || second < 0 || second >= dimension) {
        throw new IllegalArgumentException(
                "Invalid axes ("
                        + axis1
                        + ", "
                        + axis2
                        + ") for an NDArray with "
                        + dimension
                        + " dimension(s)");
    }
    // Start from the identity permutation and exchange the two requested positions.
    int[] dims = IntStream.range(0, dimension).toArray();
    int tmp = dims[first];
    dims[first] = dims[second];
    dims[second] = tmp;
    return transpose(dims);
}
/**
 * Reverses the order of elements in this array along the given axes.
 *
 * The shape of the array is preserved, but the elements are reordered.
 *
 * @param axes the axes to flip on
 * @return the newly flipped array
 */
NDArray flip(int... axes);
/**
 * Returns this {@code NDArray} with its axes transposed.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f ,3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 * [3., 4.],
 * ]
 * jshell> array.transpose();
 * ND: (2, 2) cpu() float32
 * [[1., 3.],
 * [2., 4.],
 * ]
 *
 * @return the newly permuted array
 * @see NDArray#transpose(int...)
 */
NDArray transpose();
/**
 * Returns this {@code NDArray} with its axes permuted in the given order.
 *
 * Axis {@code i} of the result is axis {@code axes[i]} of this {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f ,3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 * [3., 4.],
 * ]
 * jshell> array.transpose(1, 0);
 * ND: (2, 2) cpu() float32
 * [[1., 3.],
 * [2., 4.],
 * ]
 * jshell> array = manager.arange(8f).reshape(2, 2, 2);
 * jshell> array;
 * ND: (2, 2, 2) cpu() float32
 * [[[0., 1.],
 * [2., 3.],
 * ],
 * [[4., 5.],
 * [6., 7.],
 * ],
 * ]
 * jshell> array.transpose(1, 0, 2);
 * ND: (2, 2, 2) cpu() float32
 * [[[0., 1.],
 * [4., 5.],
 * ],
 * [[2., 3.],
 * [6., 7.],
 * ],
 * ]
 *
 * @param axes the new order of the axes
 * @return the transposed {@code NDArray}
 * @throws IllegalArgumentException thrown when passing an axis that is greater than the actual
 *     number of dimensions
 */
NDArray transpose(int... axes);
/**
 * Broadcasts this {@code NDArray} to the given shape.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f ,3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 * [3., 4.],
 * ]
 * jshell> array.broadcast(new Shape(2, 2, 2));
 * ND: (2, 2, 2) cpu() float32
 * [[[1., 2.],
 * [3., 4.],
 * ],
 * [[1., 2.],
 * [3., 4.],
 * ],
 * ]
 *
 * @param shape the target {@link Shape}
 * @return the broadcasted {@code NDArray}
 */
NDArray broadcast(Shape shape);
/**
 * Broadcasts this {@code NDArray} to the shape described by the given dimensions.
 *
 * This is a convenience overload of {@link NDArray#broadcast(Shape)}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 2f ,3f, 4f}, new Shape(2, 2));
 * jshell> array;
 * ND: (2, 2) cpu() float32
 * [[1., 2.],
 * [3., 4.],
 * ]
 * jshell> array.broadcast(2, 2, 2);
 * ND: (2, 2, 2) cpu() float32
 * [[[1., 2.],
 * [3., 4.],
 * ],
 * [[1., 2.],
 * [3., 4.],
 * ],
 * ]
 *
 * @param shape the dimensions of the target {@link Shape}
 * @return the broadcasted {@code NDArray}
 */
default NDArray broadcast(long... shape) {
    Shape target = new Shape(shape);
    return broadcast(target);
}
/**
 * Returns the index of the maximum value in the flattened {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 * [3., 4., 5.],
 * ]
 * jshell> array.argMax();
 * ND: () cpu() int64
 * 5
 *
 * @return a scalar {@code NDArray} containing the index into the flattened array
 */
NDArray argMax();
/**
 * Returns the indices of the maximum values along the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 * [3., 4., 5.],
 * ]
 * jshell> array.argMax(0);
 * ND: (3) cpu() int64
 * [1, 1, 1]
 * jshell> array.argMax(1);
 * ND: (2) cpu() int64
 * [2, 2]
 *
 * @param axis the axis along which to find maximum values; it is removed from the result shape
 * @return an {@code NDArray} containing indices
 */
NDArray argMax(int axis);
/**
 * Returns the {@code k} largest values along the given axis, in sorted order, together with
 * their indices.
 *
 * @param k the number of returned values
 * @param axis the axis to sort along, whose size is reduced to {@code k}
 * @return an {@code NDList} containing (values, indices)
 * @see NDArray#topK(int, int, boolean, boolean)
 */
default NDList topK(int k, int axis) {
    final boolean largest = true;
    final boolean sorted = true;
    return topK(k, axis, largest, sorted);
}
/**
 * Returns (values, indices) of the top k values along the given axis.
 *
 * @param k the number of returned values
 * @param axis the axis to sort along, whose size is reduced to {@code k}
 * @param largest {@code true} to return the largest values, {@code false} for the smallest
 * @param sorted whether the returned values are sorted
 * @return an {@code NDList} containing (values, indices)
 */
NDList topK(int k, int axis, boolean largest, boolean sorted);
/**
 * Returns the index of the minimum value in the flattened {@code NDArray}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 * [3., 4., 5.],
 * ]
 * jshell> array.argMin();
 * ND: () cpu() int64
 * 0
 *
 * @return a scalar {@code NDArray} containing the index into the flattened array
 */
NDArray argMin();
/**
 * Returns the indices of the minimum values along the given axis.
 *
 * Examples
 *
 * jshell> NDArray array = manager.arange(6f).reshape(2, 3);
 * jshell> array;
 * ND: (2, 3) cpu() float32
 * [[0., 1., 2.],
 * [3., 4., 5.],
 * ]
 * jshell> array.argMin(0);
 * ND: (3) cpu() int64
 * [0, 0, 0]
 * jshell> array.argMin(1);
 * ND: (2) cpu() int64
 * [0, 0]
 *
 * @param axis the axis along which to find minimum values; it is removed from the result shape
 * @return an {@code NDArray} containing indices
 */
NDArray argMin(int axis);
/**
 * Returns the given percentile of all elements in this {@code NDArray}.
 *
 * @param percentile the target percentile in range of 0..100
 * @return the result {@code NDArray}
 */
NDArray percentile(Number percentile);
/**
 * Returns the given percentile along the given dimension(s).
 *
 * @param percentile the target percentile in range of 0..100
 * @param axes the dimensions to calculate the percentile for
 * @return the result {@code NDArray}
 */
NDArray percentile(Number percentile, int[] axes);
/**
 * Returns the median of all elements in this {@code NDArray}.
 *
 * @return the median {@code NDArray}
 */
NDArray median();
/**
 * Returns the median along the given axes.
 *
 * @param axes the axes along which to perform the median operation
 * @return the median {@code NDArray} along the specified axes
 */
NDArray median(int[] axes);
// ------------ Sparse methods ------------
/**
 * Returns a dense representation of this sparse {@code NDArray}.
 *
 * @return the result {@code NDArray}
 */
NDArray toDense();
/**
 * Returns a sparse representation of this {@code NDArray}.
 *
 * @param fmt the target {@link SparseFormat}
 * @return the result {@code NDArray}
 */
NDArray toSparse(SparseFormat fmt);
/**
 * Returns the indices of elements that are non-zero.
 *
 * Note that the behavior is slightly different from numpy.nonzero. Numpy returns a tuple of
 * arrays, one for each dimension. DJL nonzero returns a single {@code NDArray} of shape
 * (number of non-zero elements, number of dimensions), where each row holds the full index of
 * one non-zero element.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new float[] {1f, 1f, 1f, 0f, 1f});
 * jshell> array.nonzero();
 * ND: (4, 1) cpu() int64
 * [[ 0],
 * [ 1],
 * [ 2],
 * [ 4],
 * ]
 * jshell> array = manager.create(new float[] {3f, 0f, 0f, 0f, 4f, 0f, 5f, 6f, 0f}).reshape(3, 3);
 * jshell> array;
 * ND: (3, 3) cpu() float32
 * [[3., 0., 0.],
 * [0., 4., 0.],
 * [5., 6., 0.],
 * ]
 * jshell> array.nonzero();
 * ND: (4, 2) cpu() int64
 * [[ 0, 0],
 * [ 1, 1],
 * [ 2, 0],
 * [ 2, 1],
 * ]
 *
 * @return the indices of the elements that are non-zero
 */
NDArray nonzero();
/**
 * Returns {@code true} if this {@code NDArray} holds no values, i.e. at least one of its
 * dimensions is zero.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new Shape(2, 0, 1));
 * jshell> array;
 * ND: (2, 0, 1) cpu() float32
 * []
 * jshell> array.isEmpty();
 * true
 *
 * @return {@code true} if this {@code NDArray} is empty
 */
default boolean isEmpty() {
    long elementCount = getShape().size();
    return elementCount == 0;
}
/**
 * Returns {@code true} if every element within this {@code NDArray} is non-zero or
 * {@code true}.
 *
 * Examples
 *
 * jshell> NDArray array = manager.create(new boolean[] {true, false, true, true}, new Shape(2, 2));
 * jshell> array.all();
 * ND: () cpu() boolean
 * false
 * jshell> array = manager.create(new float[] {-1f, 4f, 5f});
 * jshell> array.all(); // all elements are non-zero
 * ND: () cpu() boolean
 * true
 *
 * @return a scalar boolean {@code NDArray} that is {@code true} if all elements within this
 *     {@code NDArray} are non-zero or {@code true}
 */
default NDArray all() {
    // Count the truthy elements (the sum comes back as int64) and compare with the total size.
    NDArray truthyCount = toType(DataType.BOOLEAN, false).sum();
    return truthyCount.eq(size());
}
/**
 * Returns whether any of the elements within this {@code NDArray} are non-zero or {@code true}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new boolean[] {true, false, true, true}, new Shape(2, 2));
 * jshell&gt; array.any();
 * ND: () cpu() boolean
 * true
 * jshell&gt; NDArray array = manager.create(new float[] {-1, 0, 5});
 * jshell&gt; array.any(); // -1 and 5 are non-zero
 * ND: () cpu() boolean
 * true
 * </pre>
 *
 * @return a scalar boolean {@code NDArray} that is {@code true} if any of the elements within
 *     this {@code NDArray} are non-zero or {@code true}
 */
default NDArray any() {
    // Casting to boolean marks each non-zero element as true; a single true element is enough to
    // make the summed count strictly positive.
    return toType(DataType.BOOLEAN, false).sum().gt(0);
}
/**
 * Returns whether none of the elements within this {@code NDArray} are non-zero or {@code
 * true}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new boolean[] {false, false});
 * jshell&gt; array.none();
 * ND: () cpu() boolean
 * true
 * jshell&gt; NDArray array = manager.create(new float[] {-1f, 0f, 5f});
 * jshell&gt; array.none(); // -1 and 5 are non-zero
 * ND: () cpu() boolean
 * false
 * </pre>
 *
 * @return a scalar boolean {@code NDArray} that is {@code true} if none of the elements within
 *     this {@code NDArray} are non-zero or {@code true}
 */
default NDArray none() {
    // Casting to boolean marks each non-zero element as true; "none" holds only when the count
    // of true elements is zero.
    return toType(DataType.BOOLEAN, false).sum().eq(0);
}
/**
 * Counts how many values in this {@code NDArray} are non-zero, across all of its elements.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {0f, 0f, 1f, 2f, 7f, 0f}, new Shape(2, 3));
 * jshell&gt; array.countNonzero()
 * ND: () cpu() int64
 * 3
 * </pre>
 *
 * @return the number of non-zero values in this {@code NDArray}
 */
default NDArray countNonzero() {
    NDArray nonzeroMask = toType(DataType.BOOLEAN, false);
    return nonzeroMask.sum();
}
/**
 * Counts how many values in this {@code NDArray} are non-zero, reducing along a single axis.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {0f, 0f, 1f, 2f, 7f, 0f}, new Shape(2, 3));
 * jshell&gt; array;
 * ND: (2, 3) cpu() float32
 * [[0., 0., 1.],
 * [2., 7., 0.],
 * ]
 * jshell&gt; array.countNonzero(0);
 * ND: (3) cpu() int64
 * [ 1, 1, 1]
 * jshell&gt; array.countNonzero(1);
 * ND: (2) cpu() int64
 * [ 1, 2]
 * </pre>
 *
 * @param axis the axis to operate on
 * @return the number of non-zero values in this {@code NDArray} along a given axis
 */
default NDArray countNonzero(int axis) {
    NDArray nonzeroMask = toType(DataType.BOOLEAN, false);
    int[] reducedAxes = {axis};
    return nonzeroMask.sum(reducedAxes);
}
/**
 * Returns element-wise inverse gauss error function of the {@code NDArray}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {0f, 0.5f, -1f});
 * jshell&gt; array.erfinv();
 * ND: (3) cpu() float32
 * [0., 0.4769, -inf]
 * </pre>
 *
 * @return the inverse of the gauss error function of the {@code NDArray}, element-wise
 */
NDArray erfinv();
/**
 * Returns element-wise gauss error function of the {@code NDArray}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
 * jshell&gt; array.erf();
 * ND: (3) cpu() float32
 * [0., 0.5, -1.]
 * </pre>
 *
 * @return the gauss error function of the {@code NDArray}, element-wise
 */
NDArray erf();
/**
 * {@inheritDoc}
 *
 * <p>An {@code NDArray} is its own single resource, so this returns an immutable one-element
 * list that contains only this array.
 */
@Override
default List<NDArray> getResourceNDArrays() {
    // Parameterized return type instead of the raw List: no unchecked/raw-type override.
    return Collections.singletonList(this);
}
/**
 * Returns an internal representative of Native {@code NDArray}.
 *
 * <p>This method should only be used by Engine provider.
 *
 * @return an internal representative of Native {@code NDArray}
 */
NDArrayEx getNDArrayInternal();
/**
 * Returns {@code true} if this NDArray has been released.
 *
 * <p>The {@code toDebugString} helpers of this interface check this flag first and return a
 * fixed "already closed" message, without reading any data, when it is {@code true}.
 *
 * @return {@code true} if this NDArray has been released
 */
boolean isReleased();
/**
 * Returns the debug string representation of this {@code NDArray}.
 *
 * <p>A released array yields a fixed "already closed" message. A {@link DataType#STRING} array
 * is rendered as its UTF-8 decoded strings. Any other array is formatted through {@code
 * NDFormat} with the limits 100, 10, 10 and 20 (presumably max size, depth, rows and columns,
 * matching the parameter order of the five-argument {@code toDebugString}).
 *
 * @return the debug string representation of this {@code NDArray}
 */
default String toDebugString() {
    if (isReleased()) {
        return "This array is already closed";
    }
    boolean holdsStrings = getDataType() == DataType.STRING;
    return holdsStrings
            ? Arrays.toString(toStringArray(StandardCharsets.UTF_8))
            : NDFormat.format(this, 100, 10, 10, 20);
}
/**
 * Returns the debug string representation of this {@code NDArray}, optionally with its content.
 *
 * <p>Delegates to {@link #toDebugString(int, int, int, int, boolean)} with at most 1000 elements,
 * depth 10, 10 rows and 20 columns.
 *
 * @param withContent true to show the content of NDArray
 * @return the debug string representation of this {@code NDArray}
 */
default String toDebugString(boolean withContent) {
    final int maxSize = 1000;
    final int maxDepth = 10;
    final int maxRows = 10;
    final int maxColumns = 20;
    return toDebugString(maxSize, maxDepth, maxRows, maxColumns, withContent);
}
/**
 * Returns the debug string representation of this {@code NDArray} using explicit print limits.
 *
 * <p>A released array yields a fixed "already closed" message, and a {@link DataType#STRING}
 * array is rendered as its UTF-8 decoded strings regardless of the limits. Every other array is
 * formatted through {@code NDFormat} with the given limits.
 *
 * @param maxSize the maximum elements to print out
 * @param maxDepth the maximum depth to print out
 * @param maxRows the maximum rows to print out
 * @param maxColumns the maximum columns to print out
 * @param withContent true to show the content of NDArray
 * @return the debug string representation of this {@code NDArray}
 */
default String toDebugString(
        int maxSize, int maxDepth, int maxRows, int maxColumns, boolean withContent) {
    String text;
    if (isReleased()) {
        text = "This array is already closed";
    } else if (getDataType() == DataType.STRING) {
        text = Arrays.toString(toStringArray(StandardCharsets.UTF_8));
    } else {
        text = NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns, withContent);
    }
    return text;
}
/**
 * {@inheritDoc}
 *
 * <p>After closing, {@link #isReleased()} is presumably expected to report {@code true}: the
 * {@code toDebugString} helpers print "This array is already closed" for a released array.
 * TODO confirm against the engine implementations.
 */
@Override
void close();
/**
 * Returns the norm of this {@code NDArray}, taken over all of its elements.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {-3f, -4f});
 * jshell&gt; array.norm();
 * ND: () cpu() float32
 * 5.
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm();
 * ND: () cpu() float32
 * 5.4772
 * </pre>
 *
 * @return the norm of this {@code NDArray}
 */
default NDArray norm() {
    // Reduced dimensions are dropped, so the result is a scalar.
    return norm(false);
}
/**
 * Returns the norm of this {@code NDArray} along the given axes, dropping the reduced axes.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {-3f, -4f});
 * jshell&gt; array.norm(new int[] {0});
 * ND: () cpu() float32
 * 5.
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm(new int[] {0});
 * ND: (2) cpu() float32
 * [3.1623, 4.4721]
 * </pre>
 *
 * @param axes if {@code axes} contains one integer, it specifies the axis of x along which to
 *     compute the vector norms; if {@code axes} contains 2 integers, it specifies the axes that
 *     hold 2-D matrices, and the matrix norms of these matrices are computed
 * @return the norm of this {@code NDArray}
 */
default NDArray norm(int[] axes) {
    final boolean keepDims = false;
    return norm(axes, keepDims);
}
/**
 * Returns the norm of this {@code NDArray}, taken over all of its elements.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {-3f, -4f});
 * jshell&gt; array.norm(true);
 * ND: (1) cpu() float32
 * [5.]
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm(true);
 * ND: (1, 1) cpu() float32
 * [[5.4772],
 * ]
 * </pre>
 *
 * @param keepDims if this is set to {@code true}, the axes which are normed over are left in the
 *     result as dimensions with size one. With this option the result will broadcast correctly
 *     against the original x.
 * @return the norm of this {@code NDArray}
 */
NDArray norm(boolean keepDims);
/**
 * Returns the order-2 norm of this {@code NDArray} along the given axes.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm(new int[] {0}, true);
 * ND: (1, 2) cpu() float32
 * [[3.1623, 4.4721],
 * ]
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm(new int[] {0}, false);
 * ND: (2) cpu() float32
 * [3.1623, 4.4721]
 * </pre>
 *
 * @param axes if {@code axes} contains one integer, it specifies the axis of x along which to
 *     compute the vector norms; if {@code axes} contains 2 integers, it specifies the axes that
 *     hold 2-D matrices, and the matrix norms of these matrices are computed
 * @param keepDims if this is set to {@code true}, the axes which are normed over are left in the
 *     result as dimensions with size one. With this option the result will broadcast correctly
 *     against the original x.
 * @return the norm of this {@code NDArray}
 */
default NDArray norm(int[] axes, boolean keepDims) {
    final int ord = 2;
    return norm(ord, axes, keepDims);
}
/**
 * Returns the norm of the given order of this {@code NDArray}.
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm(2, new int[] {0}, true);
 * ND: (1, 2) cpu() float32
 * [[3.1623, 4.4721],
 * ]
 * jshell&gt; NDArray array = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2));
 * jshell&gt; array.norm(2, new int[] {0}, false);
 * ND: (2) cpu() float32
 * [3.1623, 4.4721]
 * </pre>
 *
 * @param ord order of the norm; the {@code norm(int[], boolean)} overload passes 2
 * @param axes if {@code axes} contains one integer, it specifies the axis of x along which to
 *     compute the vector norms; if {@code axes} contains 2 integers, it specifies the axes that
 *     hold 2-D matrices, and the matrix norms of these matrices are computed
 * @param keepDims if this is set to {@code true}, the axes which are normed over are left in the
 *     result as dimensions with size one. With this option the result will broadcast correctly
 *     against the original x.
 * @return the norm of this {@code NDArray}
 */
NDArray norm(int ord, int[] axes, boolean keepDims);
/**
 * Returns a one-hot {@code NDArray} of type float32, using 1 for "hot" and 0 for "cold" entries.
 *
 * <ul>
 *   <li>The locations represented by indices take value 1, while all other locations take value
 *       0.
 *   <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
 *       appended at the end.
 *   <li>If {@code NDArray} is a scalar the output shape will be a vector of length depth.
 *   <li>If {@code NDArray} is a vector of length features, the output shape will be features x
 *       depth.
 *   <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
 *       batch x features x depth.
 * </ul>
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new int[] {1, 0, 2, 0});
 * jshell&gt; array.oneHot(3);
 * ND: (4, 3) cpu() float32
 * [[0., 1., 0.],
 * [1., 0., 0.],
 * [0., 0., 1.],
 * [1., 0., 0.],
 * ]
 * jshell&gt; NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
 * jshell&gt; array.oneHot(3);
 * ND: (3, 2, 3) cpu() float32
 * [[[0., 1., 0.],
 * [1., 0., 0.],
 * ],
 * [[0., 1., 0.],
 * [1., 0., 0.],
 * ],
 * [[0., 0., 1.],
 * [1., 0., 0.],
 * ],
 * ]
 * </pre>
 *
 * @param depth Depth of the one hot dimension.
 * @return one-hot encoding of this {@code NDArray}
 * @see Classification-problems
 */
default NDArray oneHot(int depth) {
    final float onValue = 1f;
    final float offValue = 0f;
    return oneHot(depth, onValue, offValue, DataType.FLOAT32);
}
/**
 * Returns a one-hot {@code NDArray} of the requested data type, using 1 for "hot" and 0 for
 * "cold" entries.
 *
 * <ul>
 *   <li>The locations represented by indices take value 1, while all other locations take value
 *       0.
 *   <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
 *       appended at the end.
 *   <li>If {@code NDArray} is a scalar the output shape will be a vector of length depth.
 *   <li>If {@code NDArray} is a vector of length features, the output shape will be features x
 *       depth.
 *   <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
 *       batch x features x depth.
 * </ul>
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new int[] {1, 0, 2, 0});
 * jshell&gt; array.oneHot(3);
 * ND: (4, 3) cpu() float32
 * [[0., 1., 0.],
 * [1., 0., 0.],
 * [0., 0., 1.],
 * [1., 0., 0.],
 * ]
 * jshell&gt; NDArray array = manager.create(new int[][] {{1, 0}, {1, 0}, {2, 0}});
 * jshell&gt; array.oneHot(3);
 * ND: (3, 2, 3) cpu() float32
 * [[[0., 1., 0.],
 * [1., 0., 0.],
 * ],
 * [[0., 1., 0.],
 * [1., 0., 0.],
 * ],
 * [[0., 0., 1.],
 * [1., 0., 0.],
 * ],
 * ]
 * </pre>
 *
 * @param depth Depth of the one hot dimension.
 * @param dataType dataType of the output.
 * @return one-hot encoding of this {@code NDArray}
 * @see Classification-problems
 */
default NDArray oneHot(int depth, DataType dataType) {
    final float hot = 1f;
    final float cold = 0f;
    return oneHot(depth, hot, cold, dataType);
}
/**
 * Returns a one-hot {@code NDArray} with caller-chosen on/off values and output data type.
 *
 * <ul>
 *   <li>The locations represented by indices take value onValue, while all other locations take
 *       value offValue.
 *   <li>If the input {@code NDArray} is rank N, the output will have rank N+1. The new axis is
 *       appended at the end.
 *   <li>If {@code NDArray} is a scalar the output shape will be a vector of length depth.
 *   <li>If {@code NDArray} is a vector of length features, the output shape will be features x
 *       depth.
 *   <li>If {@code NDArray} is a matrix with shape [batch, features], the output shape will be
 *       batch x features x depth.
 * </ul>
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array = manager.create(new int[] {1, 0, 2, 0});
 * jshell&gt; array.oneHot(3, 8f, 1f, array.getDataType());
 * ND: (4, 3) cpu() int32
 * [[ 1, 8, 1],
 * [ 8, 1, 1],
 * [ 1, 1, 8],
 * [ 8, 1, 1],
 * ]
 * </pre>
 *
 * @param depth Depth of the one hot dimension.
 * @param onValue The value assigned to the locations represented by indices.
 * @param offValue The value assigned to the locations not represented by indices.
 * @param dataType dataType of the output.
 * @return one-hot encoding of this {@code NDArray}
 * @see Classification-problems
 */
NDArray oneHot(int depth, float onValue, float offValue, DataType dataType);
/**
 * Batchwise product of this {@code NDArray} and the other {@code NDArray}.
 *
 * <ul>
 *   <li>batchDot is used to compute dot product of x and y when x and y are data in batch,
 *       namely N-D (N greater or equal to 3) arrays in shape of (B_0, …, B_i, :, :). For
 *       example, given x with shape (B_0, …, B_i, N, M) and y with shape (B_0, …, B_i, M, K),
 *       the result array will have shape (B_0, …, B_i, N, K), which is computed by: {@code
 *       batch_dot(x,y)[b_0, ..., b_i, :, :] = dot(x[b_0, ..., b_i, :, :], y[b_0, ..., b_i, :, :])}
 * </ul>
 *
 * <p>Examples
 *
 * <pre>
 * jshell&gt; NDArray array1 = manager.ones(new Shape(2, 1, 4));
 * jshell&gt; NDArray array2 = manager.ones(new Shape(2, 4, 6));
 * jshell&gt; array1.batchDot(array2);
 * ND: (2, 1, 6) cpu() float32
 * [[[4., 4., 4., 4., 4., 4.],
 * ],
 * [[4., 4., 4., 4., 4., 4.],
 * ],
 * ]
 * </pre>
 *
 * @param other the other {@code NDArray} to perform batch dot product with
 * @return the result {@code NDArray}
 */
NDArray batchDot(NDArray other);
/**
 * Converts a general NDArray to its complex math format.
 *
 * <p>For example, {@code [10f, 12f]} in float32 becomes {@code [10+12j]} in complex64.
 *
 * @return the complex NDArray
 */
NDArray complex();
/**
 * Converts a complex NDArray to its real math format.
 *
 * <p>For example, {@code [10+12j]} in complex64 becomes {@code [10f, 12f]} in float32.
 *
 * @return the real NDArray
 */
NDArray real();
/**
 * Returns the complex conjugate of this array.
 *
 * @return a view of the input with a flipped conjugate bit; if the input has a non-complex type,
 *     this function just returns the input
 */
NDArray conj();
}