All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.engine.rust.RsNDArrayEx Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.engine.rust;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;

import java.util.List;

/**
 * {@code RsNDArrayEx} is the Rust implementation of the {@link NDArrayEx}.
 *
 * <p>Operations that the Rust engine does not provide yet throw {@link
 * UnsupportedOperationException}.
 */
@SuppressWarnings("try")
public class RsNDArrayEx implements NDArrayEx {

    /** The array that this extension operates on. */
    private final RsNDArray array;

    /**
     * Constructs an {@code RsNDArrayEx} given a {@link NDArray}.
     *
     * @param parent the {@link NDArray} to extend
     */
    RsNDArrayEx(RsNDArray parent) {
        this.array = parent;
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rdiv(Number n) {
        return rdiv(array.getManager().create(n));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rdiv(NDArray b) {
        // reverse division: the argument is the dividend, this array is the divisor
        return (RsNDArray) b.div(array);
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rdivi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rdivi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rsub(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rsub(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rsubi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rsubi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rmod(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rmod(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rmodi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rmodi(NDArray b) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rpow(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray rpowi(Number n) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray relu() {
        return new RsNDArray(array.getManager(), RustLibrary.relu(array.getHandle()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray sigmoid() {
        return new RsNDArray(array.getManager(), RustLibrary.sigmoid(array.getHandle()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray tanh() {
        return array.tanh();
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray softPlus() {
        return new RsNDArray(array.getManager(), RustLibrary.softPlus(array.getHandle()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray softSign() {
        return new RsNDArray(array.getManager(), RustLibrary.softSign(array.getHandle()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray leakyRelu(float alpha) {
        return new RsNDArray(array.getManager(), RustLibrary.leakyRelu(array.getHandle(), alpha));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray elu(float alpha) {
        return new RsNDArray(array.getManager(), RustLibrary.elu(array.getHandle(), alpha));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray selu() {
        return new RsNDArray(array.getManager(), RustLibrary.selu(array.getHandle()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray gelu() {
        return new RsNDArray(array.getManager(), RustLibrary.gelu(array.getHandle()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        return new RsNDArray(
                array.getManager(),
                RustLibrary.maxPool(
                        array.getHandle(),
                        kernelShape.getShape(),
                        stride.getShape(),
                        padding.getShape(),
                        ceilMode));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray globalMaxPool() {
        // an all-ones shape (one entry per axis after the first two) is passed to the native pool
        Shape shape = getPoolShape(array);
        long newHandle = RustLibrary.adaptiveMaxPool(array.getHandle(), shape.getShape());
        // the pooled intermediate is closed once it has been reshaped to the first two axes
        try (NDArray temp = new RsNDArray(array.getManager(), newHandle)) {
            return (RsNDArray) temp.reshape(array.getShape().slice(0, 2));
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only 2-D kernels are supported. NOTE: {@code padding}, {@code ceilMode} and {@code
     * countIncludePad} are not forwarded to the native call.
     *
     * @throws UnsupportedOperationException if {@code kernelShape} does not have exactly two
     *     dimensions
     */
    @Override
    public RsNDArray avgPool(
            Shape kernelShape,
            Shape stride,
            Shape padding,
            boolean ceilMode,
            boolean countIncludePad) {
        // The rank of the kernel decides whether this is a 2-D pool. Shape.size() is the product
        // of the dimensions, so checking it would reject ordinary kernels such as (2, 2).
        if (kernelShape.dimension() != 2) {
            throw new UnsupportedOperationException("Only avgPool2d is supported");
        }
        return new RsNDArray(
                array.getManager(),
                RustLibrary.avgPool2d(
                        array.getHandle(), kernelShape.getShape(), stride.getShape()));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray globalAvgPool() {
        // an all-ones shape (one entry per axis after the first two) is passed to the native pool
        Shape shape = getPoolShape(array);
        long newHandle = RustLibrary.adaptiveAvgPool(array.getHandle(), shape.getShape());
        // the pooled intermediate is closed once it has been reshaped to the first two axes
        try (NDArray temp = new RsNDArray(array.getManager(), newHandle)) {
            return (RsNDArray) temp.reshape(array.getShape().slice(0, 2));
        }
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if any entry of {@code padding} is non-zero
     */
    @Override
    public RsNDArray lpPool(
            float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        // the native call takes no padding argument, so refuse any request for it
        if (hasPadding(padding)) {
            throw new IllegalArgumentException("padding is not supported for Rust engine");
        }
        return new RsNDArray(
                array.getManager(),
                RustLibrary.lpPool(
                        array.getHandle(),
                        normType,
                        kernelShape.getShape(),
                        stride.getShape(),
                        ceilMode));
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray globalLpPool(float normType) {
        // the kernel spans every axis after the first two; the stride is 1 on each of them
        long[] kernelShape = array.getShape().slice(2).getShape();
        long[] stride = getPoolShape(array).getShape();
        long newHandle =
                RustLibrary.lpPool(array.getHandle(), normType, kernelShape, stride, false);
        try (NDArray temp = new RsNDArray(array.getManager(), newHandle)) {
            return (RsNDArray) temp.reshape(array.getShape().slice(0, 2));
        }
    }

    /** {@inheritDoc} */
    @Override
    public void adadeltaUpdate(
            NDList inputs,
            NDList weights,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float rho,
            float epsilon) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public void adagradUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float epsilon) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public void adamUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float learningRateBiasCorrection,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float beta1,
            float beta2,
            float epsilon,
            boolean lazyUpdate,
            boolean adamw) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public void nagUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public void rmspropUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float rho,
            float momentum,
            float epsilon,
            boolean centered) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public void sgdUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum,
            boolean lazyUpdate) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList convolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape dilation,
            int groups) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList deconvolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding,
            Shape dilation,
            int groups) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList embedding(NDArray input, NDArray weight, SparseFormat sparseFormat) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList prelu(NDArray input, NDArray alpha) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList dropout(NDArray input, float rate, boolean training) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList layerNorm(
            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList batchNorm(
            NDArray input,
            NDArray runningMean,
            NDArray runningVar,
            NDArray gamma,
            NDArray beta,
            int axis,
            float momentum,
            float eps,
            boolean training) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList rnn(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            RNN.Activation activation,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList gru(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList lstm(
            NDArray input,
            NDList states,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray interpolation(long[] size, int mode, boolean alignCorners) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Only the no-op case is handled: when the first two axes already equal {@code height} and
     * {@code width}, the array is converted to {@link DataType#FLOAT32} and returned.
     *
     * @throws UnsupportedOperationException if an actual resize would be required
     */
    @Override
    public RsNDArray resize(int width, int height, int interpolation) {
        long[] shape = array.getShape().getShape();
        // guard the rank so that scalar/1-D input reports "Not implemented" instead of failing
        // with an ArrayIndexOutOfBoundsException
        if (shape.length >= 2 && shape[0] == height && shape[1] == width) {
            return array.toType(DataType.FLOAT32, false);
        }
        // TODO: implement real resizing
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomFlipLeftRight() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomFlipTopBottom() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomBrightness(float brightness) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomHue(float hue) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomColorJitter(
            float brightness, float contrast, float saturation, float hue) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDArrayIndexer getIndexer(NDManager manager) {
        return new RsNDArrayIndexer((RsNDManager) manager);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException if {@code condition} and this array differ in shape
     */
    @Override
    public RsNDArray where(NDArray condition, NDArray other) {
        // broadcasting is not available: the condition must match this array's shape exactly
        if (!condition.getShape().equals(array.getShape())) {
            throw new UnsupportedOperationException(
                    "condition and self shape mismatch, broadcast is not supported");
        }
        RsNDManager manager = array.getManager();
        // the scope cleans up temporaries made by manager.from(...); the result is unregistered
        // so that it outlives the scope
        try (NDScope ignore = new NDScope()) {
            long conditionHandle = manager.from(condition).getHandle();
            long otherHandle = manager.from(other).getHandle();
            RsNDArray ret =
                    new RsNDArray(
                            manager,
                            RustLibrary.where(conditionHandle, array.getHandle(), otherHandle));
            NDScope.unregister(ret);
            return ret;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray stack(NDList arrays, int axis) {
        // this array goes first, followed by the arguments in iteration order
        long[] srcArray = new long[arrays.size() + 1];
        srcArray[0] = array.getHandle();
        RsNDManager manager = array.getManager();

        // the scope cleans up temporaries made by manager.from(...); the result is unregistered
        // so that it outlives the scope
        try (NDScope ignore = new NDScope()) {
            int i = 1;
            for (NDArray arr : arrays) {
                srcArray[i++] = manager.from(arr).getHandle();
            }
            RsNDArray ret = new RsNDArray(manager, RustLibrary.stack(srcArray, axis));
            NDScope.unregister(ret);
            return ret;
        }
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray concat(NDList list, int axis) {
        NDUtils.checkConcatInput(list);

        // this array goes first, followed by the arguments in iteration order
        long[] srcArray = new long[list.size() + 1];
        srcArray[0] = array.getHandle();
        RsNDManager manager = array.getManager();
        // the scope cleans up temporaries made by manager.from(...); the result is unregistered
        // so that it outlives the scope
        try (NDScope ignore = new NDScope()) {
            int i = 1;
            for (NDArray arr : list) {
                srcArray[i++] = manager.from(arr).getHandle();
            }
            RsNDArray ret = new RsNDArray(manager, RustLibrary.concat(srcArray, axis));
            NDScope.unregister(ret);
            return ret;
        }
    }

    /** {@inheritDoc} */
    @Override
    public NDList multiBoxTarget(
            NDList inputs,
            float iouThreshold,
            float ignoreLabel,
            float negativeMiningRatio,
            float negativeMiningThreshold,
            int minNegativeSamples) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList multiBoxPrior(
            List<Float> sizes,
            List<Float> ratios,
            List<Float> steps,
            List<Float> offsets,
            boolean clip) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public NDList multiBoxDetection(
            NDList inputs,
            boolean clip,
            float threshold,
            int backgroundId,
            float nmsThreshold,
            boolean forceSuppress,
            int nmsTopK) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} */
    @Override
    public RsNDArray getArray() {
        return array;
    }

    /**
     * Builds an all-ones {@link Shape} with one entry per axis of {@code input} after the first
     * two.
     *
     * @param input the array whose rank selects the shape
     * @return a shape of 1, 2 or 3 ones
     * @throws IllegalArgumentException if {@code input} does not have 3, 4 or 5 dimensions
     */
    private static Shape getPoolShape(NDArray input) {
        switch (input.getShape().dimension() - 2) {
            case 1:
                return new Shape(1);
            case 2:
                return new Shape(1, 1);
            case 3:
                return new Shape(1, 1, 1);
            default:
                throw new IllegalArgumentException("the input dimension should be in [3, 5]");
        }
    }

    /**
     * Checks whether a padding shape requests any padding at all.
     *
     * @param padding the per-axis padding
     * @return {@code true} if at least one entry is non-zero
     */
    private static boolean hasPadding(Shape padding) {
        for (long pad : padding.getShape()) {
            if (pad != 0) {
                return true;
            }
        }
        return false;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy