All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.engine.rust.RsNDArray Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.engine.rust;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.NativeResource;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.stream.IntStream;

/** {@code RsNDArray} is the Rust implementation of {@link NDArray}. */
@SuppressWarnings("try")
public class RsNDArray extends NativeResource implements NDArray {

    private String name;
    private Device device;
    private DataType dataType;
    private Shape shape;
    private RsNDManager manager;
    private RsNDArrayEx ndArrayEx;

    // Keep a reference to the direct buffer so the GC does not release its memory.
    @SuppressWarnings("PMD.UnusedPrivateField")
    private ByteBuffer dataRef;

    /**
     * Constructs a Rust {@code NDArray} from a native handle (internal; use {@link NDManager}
     * instead).
     *
     * <p>Delegates to the four-argument constructor with no data type and no Java-side buffer.
     *
     * @param manager the manager to attach the new array to
     * @param handle the pointer to the native Rust memory
     */
    @SuppressWarnings("this-escape")
    public RsNDArray(RsNDManager manager, long handle) {
        this(manager, handle, null, null);
    }

    // Package-private convenience constructor: same as the four-argument constructor, but
    // without a Java-side data buffer.
    @SuppressWarnings("this-escape")
    RsNDArray(RsNDManager manager, long handle, DataType dataType) {
        this(manager, handle, dataType, null);
    }

    /**
     * Constructs a Rust {@code NDArray} from a native handle (internal; use {@link NDManager}
     * instead) together with the data buffer that is held on the Java side.
     *
     * <p>The new array attaches itself to {@code manager} and registers itself with {@link
     * NDScope}.
     *
     * @param manager the manager to attach the new array to
     * @param handle the pointer to the native Rust memory
     * @param dataType the {@link DataType} to be set; may be {@code null}, in which case {@link
     *     #getDataType()} queries the native library on first use
     * @param data the direct buffer of the data; a reference is kept in this array
     */
    @SuppressWarnings("this-escape")
    public RsNDArray(RsNDManager manager, long handle, DataType dataType, ByteBuffer data) {
        super(handle);
        this.dataType = dataType;
        this.manager = manager;
        this.ndArrayEx = new RsNDArrayEx(this);
        dataRef = data;
        // Both calls below publish "this" to other objects before the constructor returns.
        manager.attachInternal(getUid(), this);
        NDScope.register(this);
    }

    /**
     * {@inheritDoc}
     *
     * @return the {@link RsNDManager} this array is currently attached to
     */
    @Override
    public RsNDManager getManager() {
        return this.manager;
    }

    /**
     * {@inheritDoc}
     *
     * @return the name previously stored with {@link #setName(String)}, or {@code null} if none
     */
    @Override
    public String getName() {
        return this.name;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The name is only stored on the Java object; no native call is made.
     */
    @Override
    public void setName(String name) {
        this.name = name;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The value is cached: the native library is queried only while the field is still {@code
     * null}, and the returned integer is used as an index into {@link DataType#values()}.
     */
    @Override
    public DataType getDataType() {
        if (dataType != null) {
            return dataType;
        }
        int typeIndex = RustLibrary.getDataType(getHandle());
        dataType = DataType.values()[typeIndex];
        return dataType;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The device is resolved lazily from the native library and then cached. The native call
     * returns a pair: element 0 is a device type code (0 = CPU, 1 = GPU, 2 = "mps") and element 1
     * is passed to {@link Device#of(String, int)} as the device id.
     *
     * @throws EngineException if the native library reports an unknown device type code
     */
    @Override
    public Device getDevice() {
        if (device != null) {
            return device;
        }
        int[] dev = RustLibrary.getDevice(getHandle());
        int typeCode = dev[0];
        String deviceType;
        if (typeCode == 0) {
            deviceType = Device.Type.CPU;
        } else if (typeCode == 1) {
            deviceType = Device.Type.GPU;
        } else if (typeCode == 2) {
            deviceType = "mps";
        } else {
            throw new EngineException("Unknown device type: " + dev[0]);
        }
        device = Device.of(deviceType, dev[1]);
        return device;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The shape is fetched from the native library on first access and cached afterwards.
     */
    @Override
    public Shape getShape() {
        if (shape != null) {
            return shape;
        }
        long[] dims = RustLibrary.getShape(getHandle());
        shape = new Shape(dims);
        return shape;
    }

    /**
     * {@inheritDoc}
     *
     * @return always {@link SparseFormat#DENSE}
     */
    @Override
    public SparseFormat getSparseFormat() {
        return SparseFormat.DENSE;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns this very instance when the array already lives on {@code device} and no copy
     * was requested; otherwise a new array is created from the handle returned by the native
     * {@code toDevice} call.
     */
    @Override
    public RsNDArray toDevice(Device device, boolean copy) {
        boolean alreadyThere = device.equals(getDevice());
        if (alreadyThere && !copy) {
            return this;
        }
        long movedHandle =
                RustLibrary.toDevice(getHandle(), device.getDeviceType(), device.getDeviceId());
        return toArray(movedHandle, null, false, true);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns this instance when the type already matches and no copy was requested.
     * Conversion to {@link DataType#BOOLEAN} uses the dedicated native {@code toBoolean} call;
     * every other target type is translated with {@code manager.toRustDataType} first.
     *
     * @throws UnsupportedOperationException for an INT64 array converted to FLOAT16 on a GPU
     */
    @Override
    public RsNDArray toType(DataType dataType, boolean copy) {
        boolean sameType = dataType.equals(getDataType());
        if (sameType && !copy) {
            return this;
        }
        if (dataType == DataType.BOOLEAN) {
            return toArray(RustLibrary.toBoolean(getHandle()), dataType, false, true);
        }
        // NOTE(review): the guard matches an INT64 source and a FLOAT16 target, while the
        // message reads "FP16 to I64" - confirm which direction is actually unsupported.
        boolean unsupportedOnGpu =
                this.dataType == DataType.INT64
                        && dataType == DataType.FLOAT16
                        && getDevice().isGpu();
        if (unsupportedOnGpu) {
            // TODO:
            throw new UnsupportedOperationException("FP16 to I64 is not supported on GPU.");
        }
        int rustType = manager.toRustDataType(dataType);
        long convertedHandle = RustLibrary.toDataType(getHandle(), rustType);
        return toArray(convertedHandle, dataType, false, true);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public void setRequiresGradient(boolean requiresGrad) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public RsNDArray getGradient() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @return always {@code false}; the gradient accessors of this class are not implemented
     */
    @Override
    public boolean hasGradient() {
        return false;
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray stopGradient() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>The result is a heap buffer wrapping the byte array returned by the native library, set
     * to the platform's native byte order. The {@code tryDirect} flag is not consulted.
     */
    @Override
    public ByteBuffer toByteBuffer(boolean tryDirect) {
        byte[] raw = RustLibrary.toByteArray(getHandle());
        return ByteBuffer.wrap(raw).order(ByteOrder.nativeOrder());
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public String[] toStringArray(Charset charset) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>The buffer is validated against this array's data type and element count. A direct
     * {@link ByteBuffer} is used as-is; any other buffer is first copied into a freshly allocated
     * direct buffer. A replacement array is then created from that data, moved to this array's
     * device and interned into this instance.
     *
     * <p>The Java-side buffer reference is only swapped after the replacement has been interned,
     * so a failure while building the replacement leaves the previous reference untouched.
     *
     * @param buffer the new content for this array
     */
    @Override
    public void set(Buffer buffer) {
        int size = Math.toIntExact(size());
        DataType type = getDataType();
        BaseNDManager.validateBuffer(buffer, type, size);

        Device target = getDevice();
        ByteBuffer data;
        if (buffer.isDirect() && buffer instanceof ByteBuffer) {
            data = (ByteBuffer) buffer;
        } else {
            // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
            data = manager.allocateDirect(size * type.getNumOfBytes());
            BaseNDManager.copyBuffer(buffer, data);
        }

        intern(manager.create(data, getShape(), type).toDevice(target, false));

        // If NDArray is on the GPU, it is native code responsibility to control the data life
        // cycle; otherwise keep the buffer reachable for as long as this array uses it.
        dataRef = target.isGpu() ? null : data;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Not implemented yet; a draft of the native call is kept below, commented out.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray gather(NDArray index, int axis) {
        //        try (NDScope ignore = new NDScope()) {
        //            long indexHandle = manager.from(index).getHandle();
        //            return toArray(RustLibrary.gather(getHandle(), indexHandle, axis), true);
        //        }
        // TODO:
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray gatherNd(NDArray index) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>The index is converted with this array's own manager, while the result is attached to
     * the {@code manager} argument, which is cast to {@link RsNDManager}. The result is
     * unregistered from the local {@link NDScope} before it is returned.
     */
    @Override
    public NDArray take(NDManager manager, NDArray index) {
        try (NDScope ignore = new NDScope()) {
            long idx = this.manager.from(index).getHandle();
            long takenHandle = RustLibrary.take(getHandle(), idx);
            RsNDArray result = new RsNDArray((RsNDManager) manager, takenHandle);
            // Take the result out of the scope opened above.
            NDScope.unregister(result);
            return result;
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Both arguments are converted through this array's manager inside a temporary {@link
     * NDScope} before the native {@code put} call.
     */
    @Override
    public NDArray put(NDArray index, NDArray value) {
        try (NDScope ignore = new NDScope()) {
            long idx = manager.from(index).getHandle();
            long val = manager.from(value).getHandle();
            long resultHandle = RustLibrary.put(getHandle(), idx, val);
            return toArray(resultHandle, true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Not implemented yet; a draft of the native call is kept below, commented out.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray scatter(NDArray index, NDArray value, int axis) {
        //        try (NDScope ignore = new NDScope()) {
        //            long indexHandle = manager.from(index).getHandle();
        //            long valueHandle = manager.from(value).getHandle();
        //            return toArray(RustLibrary.scatter(getHandle(), indexHandle, valueHandle,
        // axis), true);
        //        }
        // TODO:
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>The array is detached from its current manager first; {@code manager} must be an {@link
     * RsNDManager}.
     */
    @Override
    public void attach(NDManager manager) {
        detach();
        RsNDManager target = (RsNDManager) manager;
        this.manager = target;
        target.attachInternal(getUid(), this);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Same sequence as {@link #attach(NDManager)}, except that the array is handed to the
     * manager through {@code attachUncappedInternal}.
     */
    @Override
    public void returnResource(NDManager manager) {
        detach();
        RsNDManager target = (RsNDManager) manager;
        this.manager = target;
        target.attachUncappedInternal(getUid(), this);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The manager that owned the array before the call is passed along to {@code
     * tempAttachInternal}.
     */
    @Override
    public void tempAttach(NDManager manager) {
        NDManager previous = this.manager;
        detach();
        RsNDManager target = (RsNDManager) manager;
        this.manager = target;
        target.tempAttachInternal(previous, getUid(), this);
    }

    /**
     * {@inheritDoc}
     *
     * <p>After removing itself from the current manager, the array points at the system manager
     * returned by {@code RsNDManager.getSystemManager()}.
     */
    @Override
    public void detach() {
        this.manager.detachInternal(getUid());
        this.manager = RsNDManager.getSystemManager();
    }

    /**
     * {@inheritDoc}
     *
     * <p>The copy is given this array's currently cached data type field, which may still be
     * {@code null} if the type has not been queried yet.
     */
    @Override
    public NDArray duplicate() {
        long copyHandle = RustLibrary.duplicate(getHandle());
        return toArray(copyHandle, dataType, false, true);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public RsNDArray booleanMask(NDArray index, int axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray sequenceMask(NDArray sequenceLength, float value) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray sequenceMask(NDArray sequenceLength) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * {@inheritDoc}
     *
     * <p>The number is wrapped in a temporary scalar array, which is closed again as soon as the
     * comparison has finished (matching the other {@code Number} overloads of this class).
     */
    @Override
    public boolean contentEquals(Number number) {
        try (NDArray scalar = manager.create(number)) {
            return contentEquals(scalar);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>A {@code null} argument, a different shape or a different data type all yield {@code
     * false} without calling into the native library.
     */
    @Override
    public boolean contentEquals(NDArray other) {
        if (other == null) {
            return false;
        }
        if (!shapeEquals(other)) {
            return false;
        }
        if (getDataType() != other.getDataType()) {
            return false;
        }
        long self = getHandle();
        long rhs = manager.from(other).getHandle();
        return RustLibrary.contentEqual(self, rhs);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray eq(Number n) {
        // The scalar operand is a temporary array, closed when the comparison returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = eq(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray eq(NDArray other) {
        // Operand conversion happens inside a temporary scope; the result is typed BOOLEAN.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.eq(self, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray neq(Number n) {
        // The scalar operand is a temporary array, closed when the comparison returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = neq(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray neq(NDArray other) {
        // Operand conversion happens inside a temporary scope; the result is typed BOOLEAN.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.neq(self, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray gt(Number n) {
        // The scalar operand is a temporary array, closed when the comparison returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = gt(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray gt(NDArray other) {
        // Operand conversion happens inside a temporary scope; the result is typed BOOLEAN.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.gt(self, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray gte(Number n) {
        // The scalar operand is a temporary array, closed when the comparison returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = gte(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray gte(NDArray other) {
        // Operand conversion happens inside a temporary scope; the result is typed BOOLEAN.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.gte(self, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray lt(Number n) {
        // The scalar operand is a temporary array, closed when the comparison returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = lt(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray lt(NDArray other) {
        // Operand conversion happens inside a temporary scope; the result is typed BOOLEAN.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.lt(self, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray lte(Number n) {
        // The scalar operand is a temporary array, closed when the comparison returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = lte(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray lte(NDArray other) {
        // Operand conversion happens inside a temporary scope; the result is typed BOOLEAN.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.lte(self, rhs), DataType.BOOLEAN, true, false);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray add(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = add(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray add(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.add(self, rhs), true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sub(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = sub(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sub(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.sub(self, rhs), true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray mul(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = mul(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray mul(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.mul(self, rhs), true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray div(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = div(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray div(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.div(self, rhs), true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray mod(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = mod(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray mod(NDArray other) {
        // Implemented with the native "remainder" operation.
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long resultHandle = RustLibrary.remainder(getHandle(), rhs);
            return toArray(resultHandle, true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray pow(Number n) {
        // The scalar exponent is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = pow(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray pow(NDArray other) {
        // The exponent is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.pow(self, rhs), true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if this array or {@code other} is a scalar
     */
    @Override
    public NDArray xlogy(NDArray other) {
        boolean hasScalarOperand = isScalar() || other.isScalar();
        if (hasScalarOperand) {
            throw new IllegalArgumentException("scalar is not allowed for xlogy()");
        }
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.xlogy(self, rhs), true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray addi(Number n) {
        // The scalar operand is a temporary array, closed when the in-place update returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray self = addi(scalar);
            return self;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray addi(NDArray other) {
        // "In place" here means: compute a new array, then intern it into this instance.
        RsNDArray sum = add(other);
        intern(sum);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray subi(Number n) {
        // The scalar operand is a temporary array, closed when the in-place update returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray self = subi(scalar);
            return self;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray subi(NDArray other) {
        // "In place" here means: compute a new array, then intern it into this instance.
        RsNDArray difference = sub(other);
        intern(difference);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray muli(Number n) {
        // The scalar operand is a temporary array, closed when the in-place update returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray self = muli(scalar);
            return self;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray muli(NDArray other) {
        // "In place" here means: compute a new array, then intern it into this instance.
        RsNDArray product = mul(other);
        intern(product);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray divi(Number n) {
        // The scalar operand is a temporary array, closed when the in-place update returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray self = divi(scalar);
            return self;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray divi(NDArray other) {
        // "In place" here means: compute a new array, then intern it into this instance.
        RsNDArray quotient = div(other);
        intern(quotient);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray modi(Number n) {
        // The scalar operand is a temporary array, closed when the in-place update returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray self = modi(scalar);
            return self;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray modi(NDArray other) {
        // "In place" here means: compute a new array, then intern it into this instance.
        RsNDArray remainder = mod(other);
        intern(remainder);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray powi(Number n) {
        // The scalar exponent is a temporary array, closed when the in-place update returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray self = powi(scalar);
            return self;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray powi(NDArray other) {
        // "In place" here means: compute a new array, then intern it into this instance.
        RsNDArray power = pow(other);
        intern(power);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray signi() {
        // Compute the sign into a new array, then intern it into this instance.
        RsNDArray signs = sign();
        intern(signs);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray negi() {
        // Compute the negation into a new array, then intern it into this instance.
        RsNDArray negated = neg();
        intern(negated);
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sign() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.sign(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray maximum(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = maximum(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray maximum(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.maximum(self, rhs), true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray minimum(Number n) {
        // The scalar operand is a temporary array, closed when the operation returns.
        try (NDArray scalar = manager.create(n)) {
            RsNDArray result = minimum(scalar);
            return result;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray minimum(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long self = getHandle();
            long rhs = manager.from(other).getHandle();
            return toArray(RustLibrary.minimum(self, rhs), true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Implemented by comparing the non-zero element count with {@link #size()}. The temporary
     * count array is closed even if reading it or creating the result fails.
     */
    @Override
    public RsNDArray all() {
        try (NDArray nonZero = countNonzero()) {
            return (RsNDArray) manager.create(nonZero.getLong() == size());
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Implemented by checking that the non-zero element count is positive. The temporary
     * count array is closed even if reading it or creating the result fails.
     */
    @Override
    public RsNDArray any() {
        try (NDArray nonZero = countNonzero()) {
            return (RsNDArray) manager.create(nonZero.getLong() > 0);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Implemented by checking that the non-zero element count is zero. The temporary count
     * array is closed even if reading it or creating the result fails.
     */
    @Override
    public RsNDArray none() {
        try (NDArray nonZero = countNonzero()) {
            return (RsNDArray) manager.create(nonZero.getLong() == 0);
        }
    }

    /** {@inheritDoc} */
    @Override
    public NDArray countNonzero() {
        // Counts over the whole array; see the int overload for a single axis.
        try (NDScope ignore = new NDScope()) {
            long countHandle = RustLibrary.countNonzero(getHandle());
            return toArray(countHandle, true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public NDArray countNonzero(int axis) {
        // The axis is forwarded unchanged to the native call.
        try (NDScope ignore = new NDScope()) {
            long countHandle = RustLibrary.countNonzeroWithAxis(getHandle(), axis);
            return toArray(countHandle, true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray neg() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.neg(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray abs() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.abs(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray square() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.square(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sqrt() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.sqrt(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed as this array raised to the power {@code 1.0 / 3} via the native {@code pow}
     * call; the temporary exponent array is closed afterwards.
     */
    @Override
    public RsNDArray cbrt() {
        try (RsNDArray exponent = (RsNDArray) manager.create(1.0 / 3)) {
            long resultHandle = RustLibrary.pow(getHandle(), exponent.getHandle());
            return toArray(resultHandle, true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray floor() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.floor(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray ceil() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.ceil(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray round() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.round(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray trunc() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.trunc(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray exp() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.exp(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray gammaln() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray log() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.log(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray log10() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.log10(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray log2() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.log2(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sin() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.sin(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray cos() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.cos(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray tan() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.tan(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray asin() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.asin(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray acos() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.acos(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray atan() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.atan(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray atan2(NDArray other) {
        // The operand is converted through this array's manager inside a temporary scope.
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long resultHandle = RustLibrary.atan2(getHandle(), rhs);
            return toArray(resultHandle, true);
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sinh() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.sinh(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray cosh() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.cosh(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray tanh() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.tanh(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public RsNDArray asinh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public RsNDArray acosh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public RsNDArray atanh() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed as {@code this * 180.0 / Math.PI} in two steps; the intermediate product is
     * closed once the division has produced the result.
     */
    @Override
    public RsNDArray toDegrees() {
        try (RsNDArray scaled = mul(180.0)) {
            return scaled.div(Math.PI);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Computed as {@code this * Math.PI / 180.0} in two steps; the intermediate product is
     * closed once the division has produced the result.
     */
    @Override
    public RsNDArray toRadians() {
        try (RsNDArray scaled = mul(Math.PI)) {
            return scaled.div(180.0);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar array this very instance is returned rather than a new array.
     */
    @Override
    public RsNDArray max() {
        if (!isScalar()) {
            long resultHandle = RustLibrary.max(getHandle());
            return toArray(resultHandle);
        }
        return this;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only {@code axes[0]} is used; it is read without checking that {@code axes} is
     * non-empty.
     *
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public RsNDArray max(int[] axes, boolean keepDims) {
        boolean multipleAxes = axes.length > 1;
        if (multipleAxes) {
            // TODO fix this
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        long resultHandle = RustLibrary.maxWithAxis(getHandle(), axes[0], keepDims);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar array this very instance is returned rather than a new array.
     */
    @Override
    public RsNDArray min() {
        if (!isScalar()) {
            long resultHandle = RustLibrary.min(getHandle());
            return toArray(resultHandle);
        }
        return this;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only {@code axes[0]} is used; it is read without checking that {@code axes} is
     * non-empty.
     *
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public RsNDArray min(int[] axes, boolean keepDims) {
        boolean multipleAxes = axes.length > 1;
        if (multipleAxes) {
            // TODO fix this
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        long resultHandle = RustLibrary.minWithAxis(getHandle(), axes[0], keepDims);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar array this very instance is returned rather than a new array.
     */
    @Override
    public RsNDArray sum() {
        if (!isScalar()) {
            long resultHandle = RustLibrary.sum(getHandle());
            return toArray(resultHandle);
        }
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sum(int[] axes, boolean keepDims) {
        // Unlike max/min, the whole axes array is handed to the native call.
        long resultHandle = RustLibrary.sumWithAxis(getHandle(), axes, keepDims);
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray cumProd(int axis) {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.cumProd(getHandle(), axis);
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray cumProd(int axis, DataType dataType) {
        // NOTE(review): the type is passed as DataType#ordinal(), whereas toType() translates it
        // with manager.toRustDataType() first - confirm which encoding the native side expects.
        int typeCode = dataType.ordinal();
        long resultHandle = RustLibrary.cumProdWithType(getHandle(), axis, typeCode);
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray prod() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.prod(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only {@code axes[0]} is used; it is read without checking that {@code axes} is
     * non-empty.
     *
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public RsNDArray prod(int[] axes, boolean keepDims) {
        boolean multipleAxes = axes.length > 1;
        if (multipleAxes) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        // NOTE(review): this reduction goes through the native function named
        // "cumProdWithAxis" - confirm that it computes a plain product along the axis.
        long resultHandle = RustLibrary.cumProdWithAxis(getHandle(), axes[0], keepDims);
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray mean() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.mean(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray mean(int[] axes, boolean keepDims) {
        // The whole axes array is handed to the native call.
        long resultHandle = RustLibrary.meanWithAxis(getHandle(), axes, keepDims);
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray normalize(double p, long dim, double eps) {
        // All three parameters are forwarded unchanged to the native call.
        long resultHandle = RustLibrary.normalize(getHandle(), p, dim, eps);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if {@code axes} does not contain exactly two entries
     */
    @Override
    public RsNDArray rotate90(int times, int[] axes) {
        boolean isPlane = axes.length == 2;
        if (!isPlane) {
            throw new IllegalArgumentException("Axes must be 2");
        }
        long resultHandle = RustLibrary.rot90(getHandle(), times, axes);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public RsNDArray trace(int offset, int axis1, int axis2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>With no indices, the result is a single-element list holding this array. Otherwise, if
     * the last index is not already the length of {@code axis}, that length is appended so the
     * native call receives the end of the final section as well.
     */
    @Override
    public NDList split(long[] indices, int axis) {
        int count = indices.length;
        if (count == 0) {
            return new NDList(this);
        }
        long axisLength = getShape().get(axis);
        long[] boundaries = indices;
        if (indices[count - 1] != axisLength) {
            boundaries = Arrays.copyOf(indices, count + 1);
            boundaries[count] = axisLength;
        }
        return toList(RustLibrary.split(getHandle(), boundaries, axis));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray flatten() {
        // Wrap the handle produced by the native call in a new array.
        long resultHandle = RustLibrary.flatten(getHandle());
        return toArray(resultHandle);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray flatten(int startDim, int endDim) {
        // Both bounds are forwarded unchanged to the native call.
        long resultHandle = RustLibrary.flattenWithDims(getHandle(), startDim, endDim);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray fft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray rfft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray ifft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray irfft(long length, long axis) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray stft(
            long nFft,
            long hopLength,
            boolean center,
            NDArray window,
            boolean normalize,
            boolean returnComplex) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray fft2(long[] sizes, long[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray pad(Shape padding, double value) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException always; not implemented by this class
     */
    @Override
    public NDArray ifft2(long[] sizes, long[] axes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>At most one dimension of {@code shape} may be negative; it is replaced by the value
     * that makes the element count match this array's size. The computation is done on a private
     * copy of the dimensions, so the caller's {@code shape} is never modified.
     *
     * @throws IllegalArgumentException if more than one dimension is negative, or if the
     *     negative dimension cannot be inferred (the product of the known dimensions is zero or
     *     does not divide this array's size)
     */
    @Override
    public RsNDArray reshape(Shape shape) {
        // Copy first: writing the inferred dimension into the array handed out by
        // shape.getShape() would alter the caller's Shape whenever that array is shared.
        long[] dims = shape.getShape().clone();
        long knownProduct = 1;
        int inferredAxis = -1;
        for (int i = 0; i < dims.length; ++i) {
            if (dims[i] < 0) {
                if (inferredAxis != -1) {
                    throw new IllegalArgumentException("only 1 negative axis is allowed");
                }
                inferredAxis = i;
            } else {
                knownProduct *= dims[i];
            }
        }
        if (inferredAxis != -1) {
            long total = getShape().size();
            // A zero product would otherwise surface as an ArithmeticException from "%".
            if (knownProduct == 0 || total % knownProduct != 0) {
                throw new IllegalArgumentException("unsupported dimensions");
            }
            dims[inferredAxis] = total / knownProduct;
        }
        return toArray(RustLibrary.reshape(getHandle(), dims));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code axis} unchanged to the native {@code expandDims} operation and wraps the
     * returned handle.
     */
    @Override
    public RsNDArray expandDims(int axis) {
        long expandedHandle = RustLibrary.expandDims(getHandle(), axis);
        return toArray(expandedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code axes} unchanged to the native {@code squeeze} operation and wraps the
     * returned handle.
     */
    @Override
    public RsNDArray squeeze(int[] axes) {
        long squeezedHandle = RustLibrary.squeeze(getHandle(), axes);
        return toArray(squeezedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>This Rust-backed array does not provide the operation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDList unique(Integer dim, boolean sorted, boolean returnInverse, boolean returnCounts) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>The handle for {@code other} is obtained through {@code manager.from(other)} inside a
     * temporary {@link NDScope}; the result is unregistered from that scope before it is returned.
     */
    @Override
    public RsNDArray logicalAnd(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long resultHandle = RustLibrary.logicalAnd(getHandle(), rhs);
            // unregister = true: the result must not stay tied to the temporary scope
            return toArray(resultHandle, true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>The handle for {@code other} is obtained through {@code manager.from(other)} inside a
     * temporary {@link NDScope}; the result is unregistered from that scope before it is returned.
     */
    @Override
    public RsNDArray logicalOr(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long resultHandle = RustLibrary.logicalOr(getHandle(), rhs);
            // unregister = true: the result must not stay tied to the temporary scope
            return toArray(resultHandle, true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>The handle for {@code other} is obtained through {@code manager.from(other)} inside a
     * temporary {@link NDScope}; the result is unregistered from that scope before it is returned.
     */
    @Override
    public RsNDArray logicalXor(NDArray other) {
        try (NDScope ignore = new NDScope()) {
            long rhs = manager.from(other).getHandle();
            long resultHandle = RustLibrary.logicalXor(getHandle(), rhs);
            // unregister = true: the result must not stay tied to the temporary scope
            return toArray(resultHandle, true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code logicalNot} operation on this array's handle and wraps the
     * returned handle.
     */
    @Override
    public RsNDArray logicalNot() {
        long negatedHandle = RustLibrary.logicalNot(getHandle());
        return toArray(negatedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code axis} and {@code ascending} unchanged to the native {@code argSort}
     * operation and wraps the returned handle.
     */
    @Override
    public RsNDArray argSort(int axis, boolean ascending) {
        long indicesHandle = RustLibrary.argSort(getHandle(), axis, ascending);
        return toArray(indicesHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Equivalent to {@code sort(-1)}.
     */
    @Override
    public RsNDArray sort() {
        int axis = -1;
        return sort(axis);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Calls the native {@code sort} operation with {@code axis} and a fixed {@code false} flag,
     * then wraps the returned handle.
     */
    @Override
    public RsNDArray sort(int axis) {
        // NOTE(review): the meaning of the third native argument is not visible here;
        // presumably it selects descending order - confirm against RustLibrary.sort.
        long sortedHandle = RustLibrary.sort(getHandle(), axis, false);
        return toArray(sortedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>A scalar or zero-size array is returned as a duplicate without calling the native {@code
     * softmax} operation.
     */
    @Override
    public RsNDArray softmax(int axis) {
        boolean nothingToNormalize = getShape().isScalar() || shape.size() == 0;
        if (nothingToNormalize) {
            return (RsNDArray) duplicate();
        }
        long resultHandle = RustLibrary.softmax(getHandle(), axis);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code axis} unchanged to the native {@code logSoftmax} operation and wraps the
     * returned handle.
     */
    @Override
    public RsNDArray logSoftmax(int axis) {
        long resultHandle = RustLibrary.logSoftmax(getHandle(), axis);
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>A scalar is returned reshaped to {@code (1)}, an empty array reshaped to {@code (0)};
     * every other array is summed along axis 0.
     */
    @Override
    public RsNDArray cumSum() {
        // TODO: change default behavior on cumSum
        if (isScalar()) {
            NDArray single = reshape(1);
            return (RsNDArray) single;
        }
        if (isEmpty()) {
            NDArray none = reshape(0);
            return (RsNDArray) none;
        }
        return cumSum(0);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only arrays with at most three dimensions are accepted.
     *
     * @throws UnsupportedOperationException if this array has more than three dimensions
     */
    @Override
    public RsNDArray cumSum(int axis) {
        int rank = getShape().dimension();
        if (rank > 3) {
            throw new UnsupportedOperationException("Only 3 dimensions or less is supported");
        }
        long sumHandle = RustLibrary.cumSum(getHandle(), axis);
        return toArray(sumHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Takes over the native handle of {@code replaced}, deletes the tensor this array held
     * before (if any) and closes {@code replaced}, whose handle has been cleared by then.
     * Interning an array into itself is a no-op.
     *
     * @param replaced the array whose native tensor this array adopts; must be an {@code
     *     RsNDArray}
     */
    @Override
    public void intern(NDArray replaced) {
        if (replaced == this) {
            // Swapping the handle with itself would leave nothing to delete and the close()
            // below would then release this very array.
            return;
        }
        RsNDArray arr = (RsNDArray) replaced;
        Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null));
        // Same guard as close(): this array may already have been released (null handle),
        // and -1 is never passed to deleteTensor there either.
        if (oldHandle != null && oldHandle != -1) {
            RustLibrary.deleteTensor(oldHandle);
        }
        // The locally stored metadata described the old tensor; take the values that belong to
        // the adopted one. dataRef has to move as well because arr.close() clears it.
        // NOTE(review): assumes shape/dataType are caches resolved on demand - confirm.
        shape = arr.shape;
        dataType = arr.dataType;
        dataRef = arr.dataRef;
        // dereference old ndarray
        arr.close();
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code isInf} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public RsNDArray isInfinite() {
        long maskHandle = RustLibrary.isInf(getHandle());
        return toArray(maskHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code isNaN} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public RsNDArray isNaN() {
        long maskHandle = RustLibrary.isNaN(getHandle());
        return toArray(maskHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>An empty array is returned as a duplicate. Otherwise {@code repeats} is applied to every
     * dimension; a scalar is treated as having one dimension.
     */
    @Override
    public RsNDArray tile(long repeats) {
        if (isEmpty()) {
            // nothing to tile
            return (RsNDArray) duplicate();
        }
        int rank;
        if (isScalar()) {
            rank = 1;
        } else {
            rank = getShape().dimension();
        }
        long[] perAxis = new long[rank];
        Arrays.fill(perAxis, repeats);
        return tile(perAxis);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code axis} and {@code repeat} unchanged to the native {@code tileWithAxis}
     * operation and wraps the returned handle.
     */
    @Override
    public RsNDArray tile(int axis, long repeat) {
        long tiledHandle = RustLibrary.tileWithAxis(getHandle(), axis, repeat);
        return toArray(tiledHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code repeats} unchanged to the native {@code tile} operation and wraps the
     * returned handle.
     */
    @Override
    public RsNDArray tile(long[] repeats) {
        long tiledHandle = RustLibrary.tile(getHandle(), repeats);
        return toArray(tiledHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Passes the dimensions of {@code desiredShape} to the native {@code tileWithShape}
     * operation and wraps the returned handle.
     */
    @Override
    public RsNDArray tile(Shape desiredShape) {
        long[] targetDims = desiredShape.getShape();
        long tiledHandle = RustLibrary.tileWithShape(getHandle(), targetDims);
        return toArray(tiledHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>An empty array is returned as a duplicate. Otherwise {@code repeats} is applied to every
     * dimension; a scalar is treated as having one dimension.
     */
    @Override
    public RsNDArray repeat(long repeats) {
        if (isEmpty()) {
            // nothing to repeat
            return (RsNDArray) duplicate();
        }
        int rank;
        if (isScalar()) {
            rank = 1;
        } else {
            rank = getShape().dimension();
        }
        long[] perAxis = new long[rank];
        Arrays.fill(perAxis, repeats);
        return repeat(perAxis);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Calls the native {@code repeat} operation, which takes the repeat count before the axis,
     * and wraps the returned handle.
     */
    @Override
    public RsNDArray repeat(int axis, long repeat) {
        long repeatedHandle = RustLibrary.repeat(getHandle(), repeat, axis);
        return toArray(repeatedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Applies {@code repeats[i]} along axis {@code i}, one axis after another. Every
     * intermediate array is closed as soon as its successor exists; this array itself is never
     * closed. With an empty {@code repeats} this array is returned as is.
     */
    @Override
    public RsNDArray repeat(long[] repeats) {
        RsNDArray current = this;
        int axis = 0;
        for (long count : repeats) {
            RsNDArray previous = current;
            current = previous.repeat(axis, count);
            if (previous != this) {
                // intermediate result, no longer needed
                previous.close();
            }
            axis++;
        }
        return current;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Derives the per-axis repeat counts from {@code desiredShape} and delegates to {@link
     * #repeat(long[])}.
     */
    @Override
    public RsNDArray repeat(Shape desiredShape) {
        long[] perAxis = repeatsToMatchShape(desiredShape);
        return repeat(perAxis);
    }

    /**
     * Computes how often each axis of this array must be repeated to reach {@code desiredShape}.
     *
     * <p>If {@code desiredShape} has fewer dimensions than this array, it is prefixed with this
     * array's leading dimensions so both have the same rank.
     *
     * @param desiredShape the shape to reach
     * @return one repeat count per dimension of this array
     * @throws IllegalArgumentException if {@code desiredShape} has more dimensions than this
     *     array, or if a dimension of this array is zero or does not evenly divide the desired
     *     one
     */
    private long[] repeatsToMatchShape(Shape desiredShape) {
        Shape curShape = getShape();
        int dimension = curShape.dimension();
        if (desiredShape.dimension() > dimension) {
            throw new IllegalArgumentException("The desired shape has too many dimensions");
        }
        if (desiredShape.dimension() < dimension) {
            int additionalDimensions = dimension - desiredShape.dimension();
            desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape);
        }
        long[] repeats = new long[dimension];
        for (int i = 0; i < dimension; i++) {
            long current = curShape.get(i);
            long desired = desiredShape.get(i);
            if (current == 0 || desired % current != 0) {
                throw new IllegalArgumentException(
                        "The desired shape is not a multiple of the original shape");
            }
            // Divisibility was just checked, so integer division is exact; going through
            // double would lose precision for very large dimensions.
            repeats[i] = desired / current;
        }
        return repeats;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Both arrays must have the same number of dimensions, and no more than two.
     *
     * @throws UnsupportedOperationException if the dimension counts differ or exceed two
     */
    @Override
    public RsNDArray dot(NDArray other) {
        int selfDim = this.getShape().dimension();
        int otherDim = other.getShape().dimension();
        boolean supported = selfDim == otherDim && selfDim <= 2;
        if (!supported) {
            throw new UnsupportedOperationException(
                    "Dimension mismatch or dimension is greater than 2.  Dot product is only"
                            + " applied on two 1D vectors. For high dimensions, please use .matMul"
                            + " instead.");
        }
        try (NDScope ignore = new NDScope()) {
            long lhs = getHandle();
            long rhs = manager.from(other).getHandle();
            long productHandle = RustLibrary.dot(lhs, rhs);
            // unregister = true: the result must not stay tied to the temporary scope
            return toArray(productHandle, true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Both this array and {@code other} must have at least two dimensions. The handle for
     * {@code other} is obtained through {@code manager.from(other)} inside a temporary {@link
     * NDScope}; the result is unregistered from that scope before it is returned.
     *
     * @throws IllegalArgumentException if either operand has fewer than two dimensions
     */
    @Override
    public NDArray matMul(NDArray other) {
        // Validate both operands; previously this array was checked twice and other never.
        if (getShape().dimension() < 2 || other.getShape().dimension() < 2) {
            throw new IllegalArgumentException("only 2d tensors are supported for matMul()");
        }
        try (NDScope ignore = new NDScope()) {
            long otherHandle = manager.from(other).getHandle();
            return toArray(RustLibrary.matmul(getHandle(), otherHandle), true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Both this array and {@code other} must have exactly three dimensions. The handle for
     * {@code other} is obtained through {@code manager.from(other)} inside a temporary {@link
     * NDScope}; the result is unregistered from that scope before it is returned.
     *
     * @throws IllegalArgumentException if either operand is not three-dimensional
     */
    @Override
    public NDArray batchMatMul(NDArray other) {
        // Validate both operands; previously this array was checked twice and other never.
        if (getShape().dimension() != 3 || other.getShape().dimension() != 3) {
            throw new IllegalArgumentException("only 3d tensors are allowed for batchMatMul()");
        }
        try (NDScope ignore = new NDScope()) {
            long otherHandle = manager.from(other).getHandle();
            return toArray(RustLibrary.batchMatMul(getHandle(), otherHandle), true);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Both bounds are converted with {@link Number#doubleValue()} before they are handed to the
     * native {@code clip} operation.
     */
    @Override
    public RsNDArray clip(Number min, Number max) {
        long self = getHandle();
        double lower = min.doubleValue();
        double upper = max.doubleValue();
        return toArray(RustLibrary.clip(self, lower, upper));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Implemented with the native {@code transpose} operation, which receives the two axes
     * unchanged.
     */
    @Override
    public RsNDArray swapAxes(int axis1, int axis2) {
        long swappedHandle = RustLibrary.transpose(getHandle(), axis1, axis2);
        return toArray(swappedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code axes} unchanged to the native {@code flip} operation and wraps the
     * returned handle.
     */
    @Override
    public NDArray flip(int... axes) {
        long flippedHandle = RustLibrary.flip(getHandle(), axes);
        return toArray(flippedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Builds the axis order {@code rank-1, ..., 1, 0} and delegates to {@link
     * #transpose(int...)}.
     */
    @Override
    public RsNDArray transpose() {
        int rank = getShape().dimension();
        // i runs 1..rank, so rank - i yields rank-1 down to 0
        int[] reversedAxes = IntStream.rangeClosed(1, rank).map(i -> rank - i).toArray();
        return transpose(reversedAxes);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Implemented with the native {@code permute} operation.
     *
     * @throws IllegalArgumentException if this array is a scalar and {@code axes} is not empty
     */
    @Override
    public RsNDArray transpose(int... axes) {
        boolean axesOnScalar = isScalar() && axes.length > 0;
        if (axesOnScalar) {
            throw new IllegalArgumentException("axes don't match NDArray");
        }
        long permutedHandle = RustLibrary.permute(getHandle(), axes);
        return toArray(permutedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Passes the dimensions of {@code shape} to the native {@code broadcast} operation and
     * wraps the returned handle.
     */
    @Override
    public RsNDArray broadcast(Shape shape) {
        long[] targetDims = shape.getShape();
        long broadcastHandle = RustLibrary.broadcast(getHandle(), targetDims);
        return toArray(broadcastHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar the result is a newly created {@code 0L} array; the native {@code argMax}
     * operation is only used for non-scalar input.
     *
     * @throws IllegalArgumentException if this array is empty
     */
    @Override
    public RsNDArray argMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        if (!isScalar()) {
            long indexHandle = RustLibrary.argMax(getHandle());
            return toArray(indexHandle);
        }
        return (RsNDArray) manager.create(0L);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar the result is a newly created {@code 0L} array and {@code axis} is ignored.
     * Otherwise the native {@code argMaxWithAxis} operation is called with a fixed {@code false}
     * flag.
     */
    @Override
    public RsNDArray argMax(int axis) {
        if (!isScalar()) {
            // NOTE(review): the trailing false is presumably keepDims - confirm in RustLibrary
            long indexHandle = RustLibrary.argMaxWithAxis(getHandle(), axis, false);
            return toArray(indexHandle);
        }
        return (RsNDArray) manager.create(0L);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards all arguments unchanged to the native {@code topK} operation and wraps every
     * returned handle in an {@code RsNDArray}.
     */
    @Override
    public NDList topK(int k, int axis, boolean largest, boolean sorted) {
        long[] resultHandles = RustLibrary.topK(getHandle(), k, axis, largest, sorted);
        return toList(resultHandles);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar the result is a newly created {@code 0L} array; the native {@code argMin}
     * operation is only used for non-scalar input.
     *
     * @throws IllegalArgumentException if this array is empty
     */
    @Override
    public RsNDArray argMin() {
        if (isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        if (!isScalar()) {
            long indexHandle = RustLibrary.argMin(getHandle());
            return toArray(indexHandle);
        }
        return (RsNDArray) manager.create(0L);
    }

    /**
     * {@inheritDoc}
     *
     * <p>For a scalar the result is a newly created {@code 0L} array and {@code axis} is ignored.
     * Otherwise the native {@code argMinWithAxis} operation is called with a fixed {@code false}
     * flag.
     */
    @Override
    public RsNDArray argMin(int axis) {
        if (!isScalar()) {
            // NOTE(review): the trailing false is presumably keepDims - confirm in RustLibrary
            long indexHandle = RustLibrary.argMinWithAxis(getHandle(), axis, false);
            return toArray(indexHandle);
        }
        return (RsNDArray) manager.create(0L);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Calls the native {@code percentile} operation with this array's handle only.
     */
    @Override
    public RsNDArray percentile(Number percentile) {
        // NOTE(review): the percentile argument is never handed to the native call, unlike
        // percentile(Number, int[]) - verify against RustLibrary.percentile.
        long resultHandle = RustLibrary.percentile(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>{@code percentile} is converted with {@link Number#doubleValue()} and passed together
     * with {@code axes} to the native {@code percentileWithAxes} operation.
     */
    @Override
    public RsNDArray percentile(Number percentile, int[] axes) {
        long self = getHandle();
        double fraction = percentile.doubleValue();
        return toArray(RustLibrary.percentileWithAxes(self, fraction, axes));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Equivalent to {@code median(new int[] {-1})}.
     */
    @Override
    public RsNDArray median() {
        int[] axes = {-1};
        return median(axes);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Exactly one axis is accepted. The native {@code median} operation yields two arrays; the
     * second one is closed right away and the first one is returned.
     *
     * @throws UnsupportedOperationException if {@code axes} does not contain exactly one axis
     */
    @Override
    public RsNDArray median(int[] axes) {
        boolean singleAxis = axes.length == 1;
        if (!singleAxis) {
            throw new UnsupportedOperationException(
                    "Not supporting zero or multi-dimension median");
        }
        long[] handles = RustLibrary.median(getHandle(), axes[0], false);
        NDList outputs = toList(handles);
        // presumably (values, indices); only the first entry is handed back - TODO confirm
        outputs.get(1).close();
        NDArray first = outputs.get(0);
        return (RsNDArray) first;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns the result of {@code duplicate()}.
     */
    @Override
    public RsNDArray toDense() {
        NDArray copy = duplicate();
        return (RsNDArray) copy;
    }

    /**
     * {@inheritDoc}
     *
     * <p>This Rust-backed array does not provide the conversion; {@code fmt} is not inspected.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RsNDArray toSparse(SparseFormat fmt) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code nonZero} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public RsNDArray nonzero() {
        long indicesHandle = RustLibrary.nonZero(getHandle());
        return toArray(indicesHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code erfinv} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public RsNDArray erfinv() {
        long resultHandle = RustLibrary.erfinv(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code erf} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public RsNDArray erf() {
        long resultHandle = RustLibrary.erf(getHandle());
        return toArray(resultHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code inverse} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public RsNDArray inverse() {
        long invertedHandle = RustLibrary.inverse(getHandle());
        return toArray(invertedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Calls the native {@code norm} operation with order {@code 2} and an empty axes array.
     */
    @Override
    public NDArray norm(boolean keepDims) {
        // presumably an empty axes array means "reduce over everything" - TODO confirm
        int[] noAxes = new int[0];
        long normHandle = RustLibrary.norm(getHandle(), 2, noAxes, keepDims);
        return toArray(normHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Forwards {@code order}, {@code axes} and {@code keepDims} unchanged to the native {@code
     * norm} operation and wraps the returned handle.
     */
    @Override
    public NDArray norm(int order, int[] axes, boolean keepDims) {
        long normHandle = RustLibrary.norm(getHandle(), order, axes, keepDims);
        return toArray(normHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The requested {@code dataType} is handed to the native {@code oneHot} operation as its
     * enum ordinal.
     */
    @Override
    public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) {
        long self = getHandle();
        int typeOrdinal = dataType.ordinal();
        long encodedHandle = RustLibrary.oneHot(self, depth, onValue, offValue, typeOrdinal);
        return toArray(encodedHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>This Rust-backed array does not provide the operation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray batchDot(NDArray other) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code complex} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public NDArray complex() {
        long complexHandle = RustLibrary.complex(getHandle());
        return toArray(complexHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Runs the native {@code real} operation on this array's handle and wraps the returned
     * handle.
     */
    @Override
    public NDArray real() {
        long realHandle = RustLibrary.real(getHandle());
        return toArray(realHandle);
    }

    /**
     * {@inheritDoc}
     *
     * <p>This Rust-backed array does not provide the operation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray conj() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException if this array has no {@code RsNDArrayEx} attached
     */
    @Override
    public RsNDArrayEx getNDArrayInternal() {
        if (ndArrayEx != null) {
            return ndArrayEx;
        }
        throw new UnsupportedOperationException(
                "NDArray operation is not supported for String tensor");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Yields a fixed notice for a released array and the debug string otherwise.
     */
    @Override
    public String toString() {
        return isReleased() ? "This array is already closed" : toDebugString();
    }

    /**
     * {@inheritDoc}
     *
     * <p>Anything that is not an {@link NDArray} is unequal; other arrays are compared with {@code
     * contentEquals}.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof NDArray)) {
            return false;
        }
        NDArray that = (NDArray) obj;
        return contentEquals(that);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Always returns {@code 0}. Since {@link #equals(Object)} compares array contents, a
     * constant value keeps equal arrays on equal hash codes, at the price of placing every array
     * in the same hash bucket.
     */
    @Override
    public int hashCode() {
        return 0;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Clears the handle, deletes the native tensor when the previous handle was neither {@code
     * null} nor {@code -1}, detaches this array from its manager and drops the reference to the
     * backing buffer.
     */
    @Override
    public void close() {
        onClose();
        // After getAndSet(null) the handle is empty, so a later close() finds null and skips
        // the native delete.
        Long pointer = handle.getAndSet(null);
        // NOTE(review): -1 presumably marks a handle without a native tensor - confirm.
        if (pointer != null && pointer != -1) {
            RustLibrary.deleteTensor(pointer);
        }
        manager.detachInternal(getUid());
        // the direct buffer only had to stay reachable while the tensor existed
        dataRef = null;
    }

    /**
     * Wraps a native handle in a new {@code RsNDArray} on this array's manager, without a data
     * type, without unregistering it from the current {@link NDScope} and without copying the
     * name.
     *
     * @param newHandle the native tensor handle to wrap
     * @return the new array
     */
    private RsNDArray toArray(long newHandle) {
        return toArray(newHandle, null, false, false);
    }

    /**
     * Wraps a native handle in a new {@code RsNDArray} on this array's manager, without a data
     * type and without copying the name.
     *
     * @param newHandle the native tensor handle to wrap
     * @param unregister whether to unregister the new array from the current {@link NDScope}
     * @return the new array
     */
    private RsNDArray toArray(long newHandle, boolean unregister) {
        final DataType unspecifiedType = null;
        final boolean withName = false;
        return toArray(newHandle, unspecifiedType, unregister, withName);
    }

    /**
     * Wraps a native handle in a new {@code RsNDArray} attached to this array's manager.
     *
     * @param newHandle the native tensor handle to wrap
     * @param dataType the data type passed to the constructor, may be {@code null}
     * @param unregister whether to unregister the new array from the current {@link NDScope}
     * @param withName whether to copy this array's name to the new array
     * @return the new array
     */
    private RsNDArray toArray(
            long newHandle, DataType dataType, boolean unregister, boolean withName) {
        RsNDArray wrapped = new RsNDArray(manager, newHandle, dataType);
        if (withName) {
            String inheritedName = getName();
            wrapped.setName(inheritedName);
        }
        if (unregister) {
            NDScope.unregister(wrapped);
        }
        return wrapped;
    }

    /**
     * Wraps each native handle in a new {@code RsNDArray} attached to this array's manager.
     *
     * @param handles the native tensor handles to wrap
     * @return a list holding one array per handle, in the same order
     */
    private NDList toList(long[] handles) {
        int count = handles.length;
        NDList arrays = new NDList(count);
        for (int i = 0; i < count; i++) {
            arrays.add(new RsNDArray(manager, handles[i]));
        }
        return arrays;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy