All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.pytorch.jni.JniUtils Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.pytorch.jni;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNull;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.engine.PtDeviceType;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;

/**
 * A class containing utilities to interact with the PyTorch Engine's Java Native Interface (JNI)
 * layer.
 */
@SuppressWarnings("MissingJavadocMethod")
public final class JniUtils {

    private static final Logger logger = LoggerFactory.getLogger(JniUtils.class);

    private static Set configs;

    private static final int NULL_PTR = 0;

    private static final int BYTE_LENGTH = 4194304;

    private JniUtils() {}

    private static int layoutMapper(SparseFormat fmt, Device device) {
        if (fmt == SparseFormat.DENSE) {
            // Enable MKLDNN with environment variable
            // Using MKLDNN with GPU would throw exception on libtorch
            if (Boolean.getBoolean("ai.djl.pytorch.use_mkldnn") && !device.equals(Device.gpu())) {
                return 2;
            }
            return 0;
        } else if (fmt == SparseFormat.COO) {
            return 1;
        } else {
            throw new IllegalArgumentException(
                    "Current PyTorch only support SparseFormat.DENSE and SparseFormat.COO");
        }
    }

    public static boolean isGradMode() {
        return PyTorchLibrary.LIB.torchIsGradMode();
    }

    public static void setGradMode(boolean enable) {
        PyTorchLibrary.LIB.torchSetGradMode(enable);
    }

    public static int getNumInteropThreads() {
        return PyTorchLibrary.LIB.torchGetNumInteropThreads();
    }

    public static int getNumThreads() {
        return PyTorchLibrary.LIB.torchGetNumThreads();
    }

    public static void setNumInteropThreads(int threads) {
        PyTorchLibrary.LIB.torchSetNumInteropThreads(threads);
    }

    public static void setNumThreads(int threads) {
        PyTorchLibrary.LIB.torchSetNumThreads(threads);
    }

    public static void setBenchmarkCuDNN(boolean enable) {
        PyTorchLibrary.LIB.torchSetBenchmarkCuDNN(enable);
    }

    public static synchronized Set getFeatures() {
        if (configs != null) {
            return configs;
        }
        Set features = new HashSet<>();
        PyTorchLibrary.LIB.torchShowConfig(features);
        configs = features;
        return configs;
    }

    public static void setSeed(long seed) {
        PyTorchLibrary.LIB.torchManualSeed(seed);
    }

    /**
     * Calls this method to start profile the area you are interested in.
     *
     * 

Example usage * *

     *      JniUtils.startProfile(false, true, true);
     *      Predictor.predict(img);
     *      JniUtils.stopProfile(outputFile)
     * 
* * @param useCuda Enables timing of CUDA events as well using the cudaEvent API. * @param recordShape If shapes recording is set, information about input dimensions will be * collected * @param profileMemory Whether to report memory usage */ public static synchronized void startProfile( boolean useCuda, boolean recordShape, boolean profileMemory) { PyTorchLibrary.LIB.torchStartProfile(useCuda, recordShape, profileMemory); } public static synchronized void stopProfile(String outputFile) { PyTorchLibrary.LIB.torchStopProfile(outputFile); } // TODO: Unchecked Datatype and device mapping public static PtNDArray createNdFromByteBuffer( PtNDManager manager, ByteBuffer data, Shape shape, DataType dType, SparseFormat fmt, Device device) { int layout = layoutMapper(fmt, device); long handle = PyTorchLibrary.LIB.torchFromBlob( data, shape.getShape(), dType.ordinal(), layout, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false); if (layout == 1 || layout == 2 || device.isGpu()) { // MKLDNN & COO & GPU device will explicitly make a copy in native code // so we don't want to hold a reference on Java side return new PtNDArray(manager, handle); } return new PtNDArray(manager, handle, data); } public static void emptyCudaCache() { PyTorchLibrary.LIB.torchCudaEmptyCache(); } public static PtNDArray createEmptyNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( manager, PyTorchLibrary.LIB.torchEmpty( shape.getShape(), dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray createZerosNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( manager, PyTorchLibrary.LIB.torchZeros( shape.getShape(), dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray createOnesNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( manager, PyTorchLibrary.LIB.torchOnes( shape.getShape(), dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray full( PtNDManager manager, Shape shape, double fillValue, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( manager, PyTorchLibrary.LIB.torchFull( shape.getShape(), fillValue, dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray zerosLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( array.getManager(), PyTorchLibrary.LIB.torchZerosLike( array.getHandle(), dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray onesLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( array.getManager(), PyTorchLibrary.LIB.torchOnesLike( array.getHandle(), dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray arange( PtNDManager manager, float start, float stop, float step, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( manager, PyTorchLibrary.LIB.torchArange( start, stop, step, dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray linspace( PtNDManager manager, float start, float stop, int step, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return new PtNDArray( manager, PyTorchLibrary.LIB.torchLinspace( start, stop, step, dType.ordinal(), layoutVal, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) { return new PtNDArray( values.getManager(), PyTorchLibrary.LIB.torchSparseCoo( shape.getShape(), indices.getHandle(), values.getHandle(), false)); } public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) { PtNDManager manager = ndArray.getManager(); // the device of the manager should always match the one in NDArray which the manager attach // to if (!device.equals(manager.getDevice())) { manager = manager.newSubManager(device); } return new PtNDArray( manager, PyTorchLibrary.LIB.torchTo( ndArray.getHandle(), dataType.ordinal(), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()})); } public static PtNDArray toSparse(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToSparse(ndArray.getHandle())); } public static PtNDArray toDense(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToDense(ndArray.getHandle())); } public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExpand(ndArray.getHandle(), shape.getShape())); } public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSlice(ndArray.getHandle(), dim, start, stop, step)); } public static PtNDArray index( PtNDArray ndArray, long[] minIndices, long[] maxIndices, long[] stepIndices, PtNDManager manager) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchIndex( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); } @SuppressWarnings("OptionalGetWithoutIsPresent") public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager manager) { if (ndArray == null) { return ndArray; } List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); try { // Index aggregation ListIterator it = indices.listIterator(); while (it.hasNext()) { if (it.nextIndex() == index.getEllipsisIndex()) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } NDIndexElement elem = it.next(); if (elem instanceof NDIndexNull) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); } else if (elem instanceof NDIndexSlice) { Long min = ((NDIndexSlice) elem).getMin(); Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); // nullSliceBinary encodes whether the slice end {min, max} is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... // If {min, max} is null, then its value is ineffective, thus set to -1. PyTorchLibrary.LIB.torchIndexAppendSlice( torchIndexHandle, min == null ? -1 : min, max == null ? -1 : max, step == null ? 1 : step, nullSliceBinary); } else if (elem instanceof NDIndexAll) { PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); } else if (elem instanceof NDIndexFixed) { PyTorchLibrary.LIB.torchIndexAppendFixed( torchIndexHandle, ((NDIndexFixed) elem).getIndex()); } else if (elem instanceof NDIndexBooleans) { PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); PyTorchLibrary.LIB.torchIndexAppendArray( torchIndexHandle, indexArr.getHandle()); } else if (elem instanceof NDIndexTake) { PtNDArray indexArr = manager.from(((NDIndexTake) elem).getIndex()); if (indexArr.getDataType() != DataType.INT64) { indexArr = indexArr.toType(DataType.INT64, true); } PyTorchLibrary.LIB.torchIndexAppendArray( torchIndexHandle, indexArr.getHandle()); } else if (elem instanceof NDIndexPick) { // Backward compatible NDIndexFullPick fullPick = NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); return pick(ndArray, manager.from(fullPick.getIndices()), fullPick.getAxis()); } } if (indices.size() == index.getEllipsisIndex()) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } long ret = PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle); return new PtNDArray(manager, ret); } finally { PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } } @SuppressWarnings("OptionalGetWithoutIsPresent") public static void indexAdvPut(PtNDArray ndArray, NDIndex index, PtNDArray data) { if (ndArray == null) { return; } List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); try { // Index aggregation ListIterator it = indices.listIterator(); while (it.hasNext()) { if (it.nextIndex() == index.getEllipsisIndex()) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } NDIndexElement elem = it.next(); if (elem instanceof NDIndexNull) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); } else if (elem instanceof NDIndexSlice) { Long min = ((NDIndexSlice) elem).getMin(); Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); int nullSliceBinary = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); // nullSliceBinary encodes whether the slice end {min, max} is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... // If {min, max} is null, then its value is ineffective, thus set to -1. PyTorchLibrary.LIB.torchIndexAppendSlice( torchIndexHandle, min == null ? -1 : min, max == null ? -1 : max, step == null ? 1 : step, nullSliceBinary); } else if (elem instanceof NDIndexAll) { PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, -1, -1, 1, 3); } else if (elem instanceof NDIndexFixed) { PyTorchLibrary.LIB.torchIndexAppendFixed( torchIndexHandle, ((NDIndexFixed) elem).getIndex()); } else if (elem instanceof NDIndexBooleans) { PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); PyTorchLibrary.LIB.torchIndexAppendArray( torchIndexHandle, indexArr.getHandle()); } else if (elem instanceof NDIndexTake) { PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex(); if (indexArr.getDataType() != DataType.INT64) { indexArr = indexArr.toType(DataType.INT64, true); } PyTorchLibrary.LIB.torchIndexAppendArray( torchIndexHandle, indexArr.getHandle()); } else if (elem instanceof NDIndexPick) { // Backward compatible NDIndexFullPick fullPick = NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); pick( ndArray, ndArray.getManager().from(fullPick.getIndices()), fullPick.getAxis()); return; } } if (indices.size() == index.getEllipsisIndex()) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } PyTorchLibrary.LIB.torchIndexAdvPut( ndArray.getHandle(), torchIndexHandle, data.getHandle()); } finally { PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } } public static void indexSet( PtNDArray ndArray, PtNDArray value, long[] minIndices, long[] maxIndices, long[] stepIndices) { PyTorchLibrary.LIB.torchIndexPut( ndArray.getHandle(), value.getHandle(), minIndices, maxIndices, stepIndices); } public static void set(PtNDArray self, ByteBuffer data) { // Note the ByteBuffer here is directByteBuffer PyTorchLibrary.LIB.torchSet(self.getHandle(), data); } public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager manager) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } return new PtNDArray( manager, PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); } public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray value) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPut( ndArray.getHandle(), index.getHandle(), value.getHandle())); } public static PtNDArray scatter(PtNDArray ndArray, PtNDArray index, PtNDArray value, int axis) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchScatter( ndArray.getHandle(), index.getHandle(), value.getHandle(), axis)); } public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { Shape indexShape = index.getShape(); Shape ndShape = ndArray.getShape(); int shapeDims = indexShape.dimension(); int ndDims = ndShape.dimension(); if (shapeDims != ndDims) { for (int i = 0; i < ndDims - shapeDims; ++i) { if (indexShape.equals(ndShape.slice(i, shapeDims))) { long[] shapes = indexShape.getShape(); long[] newShape = new long[ndDims]; Arrays.fill(newShape, 0, i, 1L); Arrays.fill(newShape, i, i + shapes.length, shapes[i]); Arrays.fill(newShape, i + shapes.length, ndDims, 1L); indexShape = new Shape(newShape); break; } } if (indexShape.equals(index.getShape())) { throw new IllegalArgumentException( "expand shape failed! Cannot expand from " + indexShape + "to " + ndShape); } index = index.reshape(indexShape); } if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchWhere( condition.getHandle(), self.getHandle(), other.getHandle())); } public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMaskedSelect(ndArray.getHandle(), indicesNd.getHandle())); } public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray indicesNd) { PyTorchLibrary.LIB.torchMaskedPut( ndArray.getHandle(), value.getHandle(), indicesNd.getHandle()); } public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager manager) { // use a specialized API here // due to significant performance gain // for commonly used data loading call if (indices.length == 1) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0])); } return new PtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); } public static PtNDArray clone(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.tensorClone(ndArray.getHandle())); } public static PtNDArray pad(PtNDArray ndArray, long[] shape, double value) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPad(ndArray.getHandle(), shape, value)); } public static PtNDArray reshape(PtNDArray ndArray, long[] shape) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchReshape(ndArray.getHandle(), shape)); } public static PtNDArray stack(PtNDArray[] arrays, int dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); } public static PtNDArray cat(PtNDArray[] arrays, long dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); } public static PtNDArray tile(PtNDArray ndArray, long[] repeats) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeat(ndArray.getHandle(), repeats)); } public static PtNDArray repeat(PtNDArray ndArray, long repeat, long dim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeatInterleave(ndArray.getHandle(), repeat, dim)); } public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dTpe) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dTpe) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray argMax(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle())); } public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle(), dim, keepDim)); } public static NDList topK( PtNDArray ndArray, long k, long axis, boolean largest, boolean sorted) { long[] handles = PyTorchLibrary.LIB.torchTopK(ndArray.getHandle(), k, axis, largest, sorted); NDList list = new NDList(handles.length); for (long handle : handles) { PtNDArray array = new PtNDArray(ndArray.getManager(), handle); list.add(array); } return list; } public static PtNDArray argMin(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle())); } public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgSort(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSort(ndArray.getHandle(), dim, descending)); } public static PtNDArray permute(PtNDArray ndArray, long[] dims) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPermute(ndArray.getHandle(), dims)); } public static PtNDArray flip(PtNDArray ndArray, long[] dims) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlip(ndArray.getHandle(), dims)); } public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTranspose(ndArray.getHandle(), dim1, dim2)); } public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2) { return PyTorchLibrary.LIB.contentEqual(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchAdd(ndArray1.getHandle(), ndArray2.getHandle())); } public static void addi(PtNDArray ndArray1, PtNDArray ndArray2) { PyTorchLibrary.LIB.torchAddi(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchSub(ndArray1.getHandle(), ndArray2.getHandle())); } public static void subi(PtNDArray ndArray1, PtNDArray ndArray2) { PyTorchLibrary.LIB.torchSubi(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMul(ndArray1.getHandle(), ndArray2.getHandle())); } public static void muli(PtNDArray ndArray1, PtNDArray ndArray2) { PyTorchLibrary.LIB.torchMuli(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchTrueDivide(ndArray1.getHandle(), ndArray2.getHandle())); } public static void divi(PtNDArray ndArray1, PtNDArray ndArray2) { PyTorchLibrary.LIB.torchTrueDividei(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchRemainder(ndArray1.getHandle(), ndArray2.getHandle())); } public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2) { PyTorchLibrary.LIB.torchRemainderi(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchPow(ndArray1.getHandle(), ndArray2.getHandle())); } public static void powi(PtNDArray ndArray1, PtNDArray ndArray2) { PyTorchLibrary.LIB.torchPowi(ndArray1.getHandle(), ndArray2.getHandle()); } public static PtNDArray sign(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSign(ndArray.getHandle())); } public static void signi(PtNDArray ndArray) { PyTorchLibrary.LIB.torchSigni(ndArray.getHandle()); } public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalAnd(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalOr(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalXor(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalNot(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogicalNot(ndArray.getHandle())); } public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray bmm(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchBmm(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray xlogy(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchXLogY(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) { if (ndArray1.getShape().dimension() == 1) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchDot(ndArray1.getHandle(), ndArray2.getHandle())); } return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMaximum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); } public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) { return new PtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMinimum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray min(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); } public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim)); } public static NDList median(PtNDArray ndArray, long dim, boolean keepDim) { long[] handles = PyTorchLibrary.LIB.torchMedian(ndArray.getHandle(), dim, keepDim); return new NDList( new PtNDArray(ndArray.getManager(), handles[0]), new PtNDArray(ndArray.getManager(), handles[1])); } public static PtNDArray mean(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle())); } public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) { long[] longaxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRot90(ndArray.getHandle(), times, longaxes)); } public static PtNDArray sum(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); } public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle(), dims, keepDim)); } public static PtNDArray cumProd(PtNDArray ndArray, long dim, DataType dataType) { int dtPosition = -1; if (dataType != null) { dtPosition = dataType.ordinal(); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumProd(ndArray.getHandle(), dim, dtPosition)); } public static PtNDArray prod(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle())); } public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray cumSum(PtNDArray ndArray, long dim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNOneHot( ndArray.toType(DataType.INT64, false).getHandle(), depth)) .toType(dataType, false); } public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { list.add(new PtNDArray(ndArray.getManager(), ptr)); } return list; } public static NDList split(PtNDArray ndArray, long[] indices, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { list.add(new PtNDArray(ndArray.getManager(), ptr)); } return list; } public static PtNDArray squeeze(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle())); } public static PtNDArray squeeze(PtNDArray ndArray, long dim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle(), dim)); } public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchUnsqueeze(ndArray.getHandle(), dim)); } public static NDList unique( PtNDArray ndArray, Integer dim, boolean sorted, boolean returnInverse, boolean returnCounts) { long[] handles; if (dim == null) { // In this case the output will be flattened. handles = PyTorchLibrary.LIB.torchUnique( ndArray.getHandle(), -1, sorted, returnInverse, returnCounts); } else { // Dimension wrap dim = Math.floorMod(dim, ndArray.getShape().dimension()); handles = PyTorchLibrary.LIB.torchUnique( ndArray.getHandle(), dim, sorted, returnInverse, returnCounts); } NDList list = new NDList(handles.length); for (long handle : handles) { PtNDArray array = new PtNDArray(ndArray.getManager(), handle); list.add(array); } return list; } public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlatten(ndArray.getHandle(), startDim, endDim)); } public static PtNDArray fft(PtNDArray ndArray, long length, long axis) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFft(ndArray.getHandle(), length, axis)); } public static PtNDArray ifft(PtNDArray ndArray, long length, long axis) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIfft(ndArray.getHandle(), length, axis)); } public static PtNDArray rfft(PtNDArray ndArray, long length, long axis) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRfft(ndArray.getHandle(), length, axis)); } public static PtNDArray irfft(PtNDArray ndArray, long length, long axis) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIrfft(ndArray.getHandle(), length, axis)); } public static PtNDArray stft( PtNDArray ndArray, long nFft, long hopLength, PtNDArray window, boolean center, boolean normalize, boolean returnComplex) { long handle = PyTorchLibrary.LIB.torchStft( ndArray.getHandle(), nFft, hopLength, window.getHandle(), center, normalize, returnComplex); if (handle == -1) { throw new UnsupportedOperationException("real() is not supported."); } return new PtNDArray(ndArray.getManager(), handle); } public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFft2(ndArray.getHandle(), sizes, axes)); } public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIfft2(ndArray.getHandle(), sizes, axes)); } public static PtNDArray real(PtNDArray ndArray) { long handle = PyTorchLibrary.LIB.torchViewAsReal(ndArray.getHandle()); if (handle == -1) { throw new UnsupportedOperationException("real() is not supported."); } return new PtNDArray(ndArray.getManager(), handle); } public static PtNDArray complex(PtNDArray ndArray) { long handle = PyTorchLibrary.LIB.torchViewAsComplex(ndArray.getHandle()); if (handle == -1) { throw new UnsupportedOperationException("complex() is not supported."); } return new PtNDArray(ndArray.getManager(), handle); } public static PtNDArray conj(PtNDArray ndArray) { return new PtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.conj(ndArray.getHandle())); } public static PtNDArray abs(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); } public static PtNDArray square(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSquare(ndArray.getHandle())); } public static PtNDArray floor(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFloor(ndArray.getHandle())); } public static PtNDArray ceil(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCeil(ndArray.getHandle())); } public static PtNDArray round(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRound(ndArray.getHandle())); } public static PtNDArray trunc(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTrunc(ndArray.getHandle())); } public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) { PtNDArray minNd = (PtNDArray) ndArray.getManager().create(min); PtNDArray maxNd = (PtNDArray) ndArray.getManager().create(max); return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchClamp( ndArray.getHandle(), minNd.getHandle(), maxNd.getHandle())); } public static PtNDArray exp(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); } public static PtNDArray gammaln(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLgamma(ndArray.getHandle())); } public static PtNDArray log(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); } public static PtNDArray log10(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog10(ndArray.getHandle())); } public static PtNDArray log2(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog2(ndArray.getHandle())); } public static PtNDArray sin(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); } public static PtNDArray cos(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); } public static PtNDArray tan(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); } public static PtNDArray asin(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchASin(ndArray.getHandle())); } public static PtNDArray acos(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAcos(ndArray.getHandle())); } public static PtNDArray atan(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } public static PtNDArray atan2(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle())); } public static PtNDArray sqrt(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); } public static PtNDArray sinh(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSinh(ndArray.getHandle())); } public static PtNDArray cosh(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCosh(ndArray.getHandle())); } public static PtNDArray tanh(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTanh(ndArray.getHandle())); } public static PtNDArray sigmoid(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSigmoid(ndArray.getHandle())); } public static PtNDArray all(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); } public static PtNDArray any(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); } public static PtNDArray none(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNone(ndArray.getHandle())); } public static PtNDArray eq(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchEq(self.getHandle(), other.getHandle())); } public static PtNDArray neq(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchNeq(self.getHandle(), other.getHandle())); } public static PtNDArray gt(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGt(self.getHandle(), other.getHandle())); } public static PtNDArray gte(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGte(self.getHandle(), other.getHandle())); } public static PtNDArray lt(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLt(self.getHandle(), other.getHandle())); } public static PtNDArray lte(PtNDArray self, PtNDArray other) { return new PtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLte(self.getHandle(), other.getHandle())); } public static PtNDArray neg(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); } public static void negi(PtNDArray ndArray) { PyTorchLibrary.LIB.torchNegi(ndArray.getHandle()); } public static PtNDArray isNaN(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsNaN(ndArray.getHandle())); } public static PtNDArray isInf(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsInf(ndArray.getHandle())); } public static PtNDArray randint( PtNDManager manager, long low, long high, Shape size, DataType dataType, Device device) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchRandint( low, high, size.getShape(), dataType.ordinal(), layoutMapper(SparseFormat.DENSE, device), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray randperm( PtNDManager manager, long n, DataType dataType, Device device) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchRandPerm( n, dataType.ordinal(), layoutMapper(SparseFormat.DENSE, device), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray normal( PtNDManager manager, double mean, double std, Shape size, DataType dataType, Device device) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchNormal( mean, std, size.getShape(), dataType.ordinal(), layoutMapper(SparseFormat.DENSE, device), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray uniform( PtNDManager manager, double low, double high, Shape size, DataType dataType, Device device) { return new PtNDArray( manager, PyTorchLibrary.LIB.tensorUniform( low, high, size.getShape(), dataType.ordinal(), layoutMapper(SparseFormat.DENSE, device), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray eye( PtNDManager manager, int n, int m, DataType dataType, Device device, SparseFormat fmt) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchEye( n, m, dataType.ordinal(), layoutMapper(fmt, device), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, false)); } public static PtNDArray hannWindow( PtNDManager manager, long numPoints, boolean periodic, Device device) { return new PtNDArray( manager, PyTorchLibrary.LIB.torchHannWindow( numPoints, periodic, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()})); } public static PtNDArray erfinv(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } public static PtNDArray erf(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle())); } public static PtNDArray inverse(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); } public static PtNDArray interpolate( PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNInterpolate( ndArray.getHandle(), size, mode, alignCorners)); } public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) { return new PtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNLinear( input.getHandle(), weight.getHandle(), bias == null ? NULL_PTR : bias.getHandle())); } public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) { return new PtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNEmbedding(input.getHandle(), weight.getHandle(), sparse)); } public static PtNDArray relu(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNRelu(ndArray.getHandle())); } public static PtNDArray softPlus(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftPlus(ndArray.getHandle())); } public static PtNDArray softSign(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftSign(ndArray.getHandle())); } public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLeakyRelu(ndArray.getHandle(), negativeSlope)); } public static PtNDArray elu(PtNDArray ndArray, double alpha) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNElu(ndArray.getHandle(), alpha)); } public static PtNDArray selu(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSelu(ndArray.getHandle())); } public static PtNDArray gelu(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNGelu(ndArray.getHandle())); } public static PtNDArray convolution( PtNDArray ndArray, PtNDArray weight, PtNDArray bias, Shape stride, Shape padding, Shape dilation, int groups) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNConvNd( ndArray.getHandle(), weight.getHandle(), (bias != null) ? bias.getHandle() : NULL_PTR, stride.getShape(), padding.getShape(), dilation.getShape(), groups)); } public static PtNDArray batchNorm( PtNDArray ndArray, PtNDArray gamma, PtNDArray beta, PtNDArray runningMean, PtNDArray runningVar, boolean isTraining, double momentum, double eps) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNBatchNorm( ndArray.getHandle(), gamma.getHandle(), beta.getHandle(), runningMean.getHandle(), runningVar.getHandle(), isTraining, momentum, eps)); } public static PtNDArray layerNorm( PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLayerNorm( ndArray.getHandle(), normalizedShape.getShape(), gamma.getHandle(), beta.getHandle(), eps)); } public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps)); } public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNDropout(ndArray.getHandle(), prob, training)); } public static NDList rnn( PtNDArray input, PtNDArray hx, NDList params, boolean hasBiases, int numLayers, RNN.Activation activation, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) { PtNDManager manager = input.getManager(); long[] paramHandles = params.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); long[] outputs = PyTorchLibrary.LIB.torchNNRnn( input.getHandle(), hx.getHandle(), paramHandles, hasBiases, numLayers, activation.ordinal(), dropRate, training, bidirectional, batchFirst); NDList res = new NDList(); for (long output : outputs) { res.add(new PtNDArray(manager, output)); } return res; } public static NDList gru( PtNDArray input, PtNDArray hx, NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) { PtNDManager manager = input.getManager(); long[] paramHandles = params.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); long[] outputs = PyTorchLibrary.LIB.torchNNGru( input.getHandle(), hx.getHandle(), paramHandles, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst); NDList res = new NDList(); for (long output : outputs) { res.add(new PtNDArray(manager, output)); } return res; } public static NDList lstm( PtNDArray input, NDList hx, NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst) { PtNDManager manager = input.getManager(); long[] hxHandles = hx.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); long[] paramHandles = params.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); long[] outputs = PyTorchLibrary.LIB.torchNNLstm( input.getHandle(), hxHandles, paramHandles, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst); NDList res = new NDList(); for (long output : outputs) { res.add(new PtNDArray(manager, output)); } return res; } public static PtNDArray avgPool( PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAvgPool( ndArray.getHandle(), kernelSize.getShape(), stride.getShape(), padding.getShape(), ceilMode, countIncludePad)); } public static PtNDArray maxPool( PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNMaxPool( ndArray.getHandle(), kernelSize.getShape(), stride.getShape(), padding.getShape(), ceilMode)); } public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, Shape outputSize) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveMaxPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveAvgPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray lpPool( PtNDArray ndArray, double normType, Shape kernelSize, Shape stride, boolean ceilMode) { if (ndArray.getShape().dimension() - 2 == 3) { throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine"); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLpPool( ndArray.getHandle(), normType, kernelSize.getShape(), stride.getShape(), ceilMode)); } public static DataType getDataType(PtNDArray ndArray) { int dataType = PyTorchLibrary.LIB.torchDType(ndArray.getHandle()); return DataType.values()[dataType]; } public static Device getDevice(PtNDArray ndArray) { int[] device = PyTorchLibrary.LIB.torchDevice(ndArray.getHandle()); String deviceType = PtDeviceType.fromDeviceType(device[0]); return Device.of(deviceType, device[1]); } public static SparseFormat getSparseFormat(PtNDArray ndArray) { int layout = PyTorchLibrary.LIB.torchLayout(ndArray.getHandle()); if (layout == 0) { return SparseFormat.DENSE; } else if (layout == 1) { return SparseFormat.COO; } else if (layout == 2) { logger.debug("MKLDNN layout is used!"); return SparseFormat.DENSE; } throw new UnsupportedOperationException("Unsupported data format"); } public static Shape getShape(PtNDArray ndArray) { return new Shape(PyTorchLibrary.LIB.torchSizes(ndArray.getHandle())); } public static ByteBuffer getByteBuffer(PtNDArray ndArray, boolean tryDirect) { // Operation is CPU only if (!ndArray.getDevice().equals(Device.cpu())) { ndArray = ndArray.toDevice(Device.cpu(), false); } if (tryDirect) { if (ndArray.isSparse() || getLayout(ndArray) == 2 || !PyTorchLibrary.LIB.torchIsContiguous(ndArray.getHandle())) { // keep the same lifecycle as origin NDArray ndArray = new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToContiguous(ndArray.getHandle())); } return PyTorchLibrary.LIB .torchDirectByteBuffer(ndArray.getHandle()) .order(ByteOrder.nativeOrder()); } return ByteBuffer.wrap(PyTorchLibrary.LIB.torchDataPtr(ndArray.getHandle())) .order(ByteOrder.nativeOrder()); } public static void deleteNDArray(long handle) { PyTorchLibrary.LIB.torchDeleteTensor(handle); } public static boolean requiresGrad(PtNDArray ndArray) { return PyTorchLibrary.LIB.torchRequiresGrad(ndArray.getHandle()); } public static String getGradientFunctionNames(PtNDArray ndArray) { return PyTorchLibrary.LIB.torchGradFnName(ndArray.getHandle()); } public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) { PyTorchLibrary.LIB.torchAttachGrad(ndArray.getHandle(), requiresGrad); } public static PtNDArray detachGradient(PtNDArray ndArray) { // TODO: detached ndarray may not use the same manager for the attached one return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchDetachGrad(ndArray.getHandle())); } public static PtNDArray getGradient(PtNDArray ndArray) { long pointer = PyTorchLibrary.LIB.torchGrad(ndArray.getHandle()); if (pointer == NULL_PTR) { return null; } return new PtNDArray(ndArray.getManager(), pointer); } public static void backward( PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) { PyTorchLibrary.LIB.torchBackward( ndArray.getHandle(), gradNd.getHandle(), keepGraph, createGraph); } public static void deleteModule(long pointer) { PyTorchLibrary.LIB.torchDeleteModule(pointer); } public static void setGraphExecutorOptimize(boolean enabled) { PyTorchLibrary.LIB.setGraphExecutorOptimize(enabled); } public static PtSymbolBlock loadModule( PtNDManager manager, Path path, boolean mapLocation, String[] extraFileKeys, String[] extraFileValues, boolean trainParam) { Device device = manager.getDevice(); // MPS doesn't support mapLocation if ("mps".equals(device.getDeviceType())) { mapLocation = false; } logger.debug("mapLocation: {}", mapLocation); logger.debug("extraFileKeys: {}", Arrays.toString(extraFileKeys)); long handle = PyTorchLibrary.LIB.moduleLoad( path.toString(), new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, mapLocation, extraFileKeys, extraFileValues, trainParam); return new PtSymbolBlock(manager, handle); } public static PtSymbolBlock loadModule( PtNDManager manager, InputStream is, boolean mapLocation, boolean hasSize) throws IOException { long handle = loadModuleHandle(is, manager.getDevice(), mapLocation, hasSize); return new PtSymbolBlock(manager, handle); } public static long loadModuleHandle( InputStream is, Device device, boolean mapLocation, boolean hasSize) throws IOException { byte[] buf = new byte[BYTE_LENGTH]; long size = -1; if (hasSize) { size = new DataInputStream(is).readLong(); } // MPS doesn't support mapLocation if ("mps".equals(device.getDeviceType())) { mapLocation = false; } logger.debug("mapLocation: {}", mapLocation); return PyTorchLibrary.LIB.moduleLoad( is, new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, mapLocation, buf, size); } public static void writeModule(PtSymbolBlock block, OutputStream os, boolean writeSize) { byte[] buf = new byte[BYTE_LENGTH]; PyTorchLibrary.LIB.moduleWrite(block.getHandle(), os, buf, writeSize); } public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) { long[] handles = PyTorchLibrary.LIB.moduleGetParams(block.getHandle()); String[] names = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle()); NDList list = new NDList(handles.length); for (int i = 0; i < handles.length; i++) { PtNDArray array = new PtNDArray(manager, handles[i]); array.setName(names[i]); list.add(array); } return list; } public static String[] getMethodNames(PtSymbolBlock block) { return PyTorchLibrary.LIB.moduleGetMethodNames(block.getHandle()); } public static void enableInferenceMode(PtSymbolBlock block) { PyTorchLibrary.LIB.moduleEval(block.getHandle()); } public static void enableTrainingMode(PtSymbolBlock block) { PyTorchLibrary.LIB.moduleTrain(block.getHandle()); } public static void zeroGrad(PtNDArray weight) { PyTorchLibrary.LIB.zeroGrad(weight.getHandle()); } public static void adamUpdate( PtNDArray weight, PtNDArray grad, PtNDArray mean, PtNDArray variance, float lr, float learningRateBiasCorrection, float wd, float rescaleGrad, float clipGrad, float beta1, float beta2, float eps, boolean adamw) { PyTorchLibrary.LIB.adamUpdate( weight.getHandle(), grad.getHandle(), mean.getHandle(), variance.getHandle(), lr, learningRateBiasCorrection, wd, rescaleGrad, clipGrad, beta1, beta2, eps, adamw); } public static void sgdUpdate( PtNDArray weight, PtNDArray grad, PtNDArray state, float lr, float wd, float rescaleGrad, float clipGrad, float momentum) { PyTorchLibrary.LIB.sgdUpdate( weight.getHandle(), grad.getHandle(), (state == null) ? NULL_PTR : state.getHandle(), lr, wd, rescaleGrad, clipGrad, momentum); } // Internal use only public static int getLayout(PtNDArray array) { return PyTorchLibrary.LIB.torchLayout(array.getHandle()); } public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) { long[] longAxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNorm(ndArray.getHandle(), ord, longAxes, keepDims)); } public static PtNDArray nonZeros(PtNDArray ndArray) { if (ndArray.isScalar()) { ndArray = (PtNDArray) ndArray.reshape(-1); } return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNonZeros(ndArray.getHandle())); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy