All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.mxnet.jna.JnaUtils Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.mxnet.jna;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.mxnet.engine.CachedOp;
import ai.djl.mxnet.engine.MxDeviceType;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.engine.MxSymbolBlock;
import ai.djl.mxnet.engine.Symbol;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.Parameter;
import ai.djl.util.PairList;
import ai.djl.util.Utils;

import com.sun.jna.Native;
import com.sun.jna.Pointer;
import com.sun.jna.ptr.PointerByReference;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A class containing utilities to interact with the MXNet Engine's Java Native Access (JNA) layer.
 */
@SuppressWarnings("MissingJavadocMethod")
public final class JnaUtils {

    /**
     * Shared pool of reusable {@link PointerByReference} out-parameters for native calls.
     *
     * <p>New entries are created with {@code PointerByReference::new}; the second callback clears
     * the referenced pointer to {@code null} (presumably applied when an entry is recycled —
     * confirm against {@code ObjectPool}).
     */
    public static final ObjectPool REFS =
            new ObjectPool<>(PointerByReference::new, r -> r.setValue(null));

    /**
     * An enum that enumerates the statuses of numpy mode.
     *
     * <p>{@code setNumpyMode} passes {@code ordinal()} directly to the native
     * {@code MXSetIsNumpyShape} call, so the declaration order of these constants is part of the
     * native contract and must not be changed.
     */
    public enum NumpyMode {
        OFF,
        THREAD_LOCAL_ON,
        GLOBAL_ON
    }

    /**
     * Operator-name prefixes. NOTE(review): not referenced in the visible part of this file;
     * presumably consumed by {@code getOpNamePrefix} when deriving function names — confirm.
     */
    private static final String[] OP_NAME_PREFIX = {
        "_contrib_", "_linalg_", "_sparse_", "_image_", "_random_"
    };

    /** The loaded native MXNet library; all native calls below go through this handle. */
    private static final MxnetLibrary LIB = LibUtils.loadLibrary();

    // Static initializers run in textual order. OPS and FEATURES are computed eagerly by methods
    // that use both REFS and LIB, so these two fields must stay declared after REFS and LIB.
    /** Operator metadata keyed by function name, built once when this class is initialized. */
    private static final Map OPS = getNdArrayFunctions();
    /** Feature names the native library reports as enabled, computed once at initialization. */
    private static final Set FEATURES = getFeaturesInternal();

    /** Not instantiable: this class only exposes static helpers. */
    private JnaUtils() {}

    /////////////////////////////////
    // MXNet information
    /////////////////////////////////

    /**
     * Returns the version number reported by the native MXNet library.
     *
     * @return the integer version written by {@code MXGetVersion}
     */
    public static int getVersion() {
        IntBuffer out = IntBuffer.allocate(1);
        int status = LIB.MXGetVersion(out);
        checkCall(status);
        return out.get(0);
    }

    /**
     * Lists every operator name registered in the native MXNet library.
     *
     * @return a new mutable set holding the operator names, decoded as UTF-8
     */
    public static Set getAllOpNames() {
        IntBuffer countOut = IntBuffer.allocate(1);
        PointerByReference namesOut = REFS.acquire();

        checkCall(LIB.MXListAllOpNames(countOut, namesOut));

        Pointer[] namePointers = namesOut.getValue().getPointerArray(0, countOut.get(0));
        Set<String> names = new HashSet<>();
        for (int i = 0; i < namePointers.length; ++i) {
            names.add(namePointers[i].getString(0, StandardCharsets.UTF_8.name()));
        }
        REFS.recycle(namesOut);
        return names;
    }

    /**
     * Builds the operator registry: one {@code FunctionInfo} per native operator, keyed by the name
     * returned from {@code getOpNamePrefix}.
     *
     * @return a concurrent map from function name to operator metadata
     */
    public static Map getNdArrayFunctions() {
        Set<String> opNames = JnaUtils.getAllOpNames();
        Map<String, FunctionInfo> functions = new ConcurrentHashMap<>();

        PointerByReference handleOut = REFS.acquire();
        for (String opName : opNames) {
            checkCall(LIB.NNGetOpHandle(opName, handleOut));
            String key = getOpNamePrefix(opName);
            Pointer opHandle = handleOut.getValue();
            functions.put(key, getFunctionByName(opName, key, opHandle));
            // Clear the slot so the next lookup starts from a null handle.
            handleOut.setValue(null);
        }
        REFS.recycle(handleOut);
        return functions;
    }

    /**
     * Looks up the metadata of a registered operator.
     *
     * @param opName the registry key of the operator
     * @return the operator metadata
     * @throws IllegalArgumentException if no operator is registered under {@code opName}
     */
    public static FunctionInfo op(String opName) {
        // The registry is a ConcurrentHashMap, which never stores null values, so a null result
        // means the key is absent.
        FunctionInfo info = OPS.get(opName);
        if (info == null) {
            throw new IllegalArgumentException("Unknown operator: " + opName);
        }
        return info;
    }

    /**
     * Describes one native operator by querying its argument names and types.
     *
     * @param name the native operator name
     * @param functionName the key under which the operator is registered in this class
     * @param handle the native operator handle obtained from {@code NNGetOpHandle}
     * @return the operator handle, registry name and (argument name, argument type) pairs
     */
    private static FunctionInfo getFunctionByName(
            String name, String functionName, Pointer handle) {
        PointerByReference argNamesOut = REFS.acquire();
        PointerByReference argTypesOut = REFS.acquire();
        PointerByReference argDescriptionsOut = REFS.acquire();
        IntBuffer argCountOut = IntBuffer.allocate(1);
        String[] nameOut = {name};
        String[] descriptionOut = new String[1];
        String[] keyVarArgsOut = new String[1];
        String[] returnTypeOut = new String[1];

        checkCall(
                LIB.MXSymbolGetAtomicSymbolInfo(
                        handle,
                        nameOut,
                        descriptionOut,
                        argCountOut,
                        argNamesOut,
                        argTypesOut,
                        argDescriptionsOut,
                        keyVarArgsOut,
                        returnTypeOut));

        PairList<String, String> arguments = new PairList<>();
        int argCount = argCountOut.get(0);
        if (argCount != 0) {
            String encoding = StandardCharsets.UTF_8.name();
            String[] argNames = argNamesOut.getValue().getStringArray(0, argCount, encoding);
            String[] argTypes = argTypesOut.getValue().getStringArray(0, argCount, encoding);
            for (int i = 0; i < argNames.length; ++i) {
                arguments.add(argNames[i], argTypes[i]);
            }
        }

        REFS.recycle(argNamesOut);
        REFS.recycle(argTypesOut);
        REFS.recycle(argDescriptionsOut);

        return new FunctionInfo(handle, functionName, arguments);
    }

    /*
    int MXFuncGetInfo(Pointer fun, String name[], String description[], IntBuffer num_args,
                      PointerByReference arg_names, PointerByReference arg_type_infos,
                      PointerByReference arg_descriptions, String return_type[]);

    int MXFuncDescribe(Pointer fun, IntBuffer num_use_vars, IntBuffer num_scalars,
                       IntBuffer num_mutate_vars, IntBuffer type_mask);

    int MXFuncInvoke(Pointer fun, PointerByReference use_vars, FloatBuffer scalar_args,
                     PointerByReference mutate_vars);

    int MXFuncInvokeEx(Pointer fun, PointerByReference use_vars, FloatBuffer scalar_args,
                       PointerByReference mutate_vars, int num_params,
                       PointerByReference param_keys, PointerByReference param_vals);
    */

    /////////////////////////////////
    // System information
    /////////////////////////////////

    /**
     * Returns the number of GPUs visible to the native MXNet library.
     *
     * @return the GPU count written by {@code MXGetGPUCount}
     */
    public static int getGpuCount() {
        IntBuffer gpus = IntBuffer.allocate(1);
        int status = LIB.MXGetGPUCount(gpus);
        checkCall(status);
        return gpus.get(0);
    }

    /**
     * Queries free and total memory of a GPU.
     *
     * @param device the GPU to query
     * @return a two-element array: {@code [0]} free memory, {@code [1]} total memory, as reported
     *     by {@code MXGetGPUMemoryInformation64}
     * @throws IllegalArgumentException if {@code device} is not a GPU
     */
    public static long[] getGpuMemory(Device device) {
        if (!device.isGpu()) {
            throw new IllegalArgumentException("Only GPU device is allowed.");
        }

        int deviceId = device.getDeviceId();

        // Use two independent single-element buffers instead of two views over one shared array.
        // Shared views only work if the native bridge honors each heap buffer's position and does
        // not hand native code separate copies of the same backing array. Independent buffers
        // start at position 0 and do not alias, so neither assumption is needed.
        LongBuffer freeMem = LongBuffer.allocate(1);
        LongBuffer totalMem = LongBuffer.allocate(1);

        checkCall(LIB.MXGetGPUMemoryInformation64(deviceId, freeMem, totalMem));

        return new long[] {freeMem.get(0), totalMem.get(0)};
    }

    /* Need tests
    public static void setOmpThreads(int threads) {
        checkCall(LIB.MXSetNumOMPThreads(threads));
    }

    public static int setBulkSize(int bulkSize) {
        IntBuffer prevBulkSize = IntBuffer.allocate(1);
        checkCall(LIB.MXEngineSetBulkSize(bulkSize, prevBulkSize));

        return prevBulkSize.get();
    }
    */

    /////////////////////////////////
    // Utilities
    /////////////////////////////////

    /**
     * Returns the names of the features the native library reported as enabled.
     *
     * <p>The set is computed once during class initialization; the same instance is returned on
     * every call.
     *
     * @return the enabled feature names
     */
    public static Set getFeatures() {
        return FEATURES;
    }

    /**
     * Reads the feature list reported by {@code MXLibInfoFeatures} and keeps the enabled entries.
     *
     * <p>Both branches return an unmodifiable set. The result backs the shared {@code FEATURES}
     * constant, which {@code getFeatures()} hands to every caller. A mutable set there would let one
     * caller change what all other callers see. The empty branch was already immutable, so this
     * also makes the two branches consistent.
     *
     * @return an unmodifiable set of enabled feature names; empty if the library reports none
     */
    private static Set getFeaturesInternal() {
        PointerByReference ref = REFS.acquire();
        NativeSizeByReference outSize = new NativeSizeByReference();
        checkCall(LIB.MXLibInfoFeatures(ref, outSize));

        int size = outSize.getValue().intValue();
        if (size == 0) {
            REFS.recycle(ref);
            return Collections.emptySet();
        }

        // Read the first native struct, then view the contiguous block as `size` structs.
        LibFeature pointer = new LibFeature(ref.getValue());
        pointer.read();

        LibFeature[] features = (LibFeature[]) pointer.toArray(size);

        Set<String> set = new HashSet<>();
        for (LibFeature feature : features) {
            if (feature.getEnabled() == 1) {
                set.add(feature.getName());
            }
        }
        REFS.recycle(ref);
        return Collections.unmodifiableSet(set);
    }

    /**
     * Calls {@code MXRandomSeed} with the given seed.
     *
     * <p>Unlike most wrappers in this class, the native status code is returned to the caller
     * instead of being passed through {@code checkCall}, so a native failure is not raised here.
     *
     * @param seed the seed value
     * @return the raw status code returned by {@code MXRandomSeed}
     */
    public static int randomSeed(int seed) {
        return LIB.MXRandomSeed(seed);
    }

    /* Need tests

    public static int randomSeed(int seed, Device device) {
        int deviceType = DeviceType.toDeviceType(device);
        return LIB.MXRandomSeedContext(seed, deviceType, device.getDeviceId());
    }

    public static void notifyShutdown() {
        checkCall(LIB.MXNotifyShutdown());
    }
    */

    /////////////////////////////////
    // Profiler information
    /////////////////////////////////

    /*
    public static int setProcessProfilerConfig(int numParams, String keys[], String vals[],
                                               Pointer kvstoreHandle) {

    }

    int MXSetProfilerConfig(int num_params, String keys[], String vals[]);

    int MXSetProcessProfilerState(int state, int profile_process, Pointer kvStoreHandle);

    int MXSetProfilerState(int state);

    int MXDumpProcessProfile(int finished, int profile_process, Pointer kvStoreHandle);

    int MXDumpProfile(int finished);

    int MXAggregateProfileStatsPrint(String out_str[], int reset);

    int MXProcessProfilePause(int paused, int profile_process, Pointer kvStoreHandle);

    int MXProfilePause(int paused);

    int MXProfileCreateDomain(String domain, PointerByReference out);

    int MXProfileCreateTask(Pointer domain, Pointer task_name, PointerByReference out);

    int MXProfileCreateTask(Pointer domain, String task_name, PointerByReference out);

    int MXProfileCreateFrame(Pointer domain, String frame_name, PointerByReference out);

    int MXProfileCreateEvent(String event_name, PointerByReference out);

    int MXProfileCreateCounter(Pointer domain, String counter_name, PointerByReference out);

    int MXProfileDestroyHandle(Pointer frame_handle);

    int MXProfileDurationStart(Pointer duration_handle);

    int MXProfileDurationStop(Pointer duration_handle);

    int MXProfileSetCounter(Pointer counter_handle, long value);

    int MXProfileAdjustCounter(Pointer counter_handle, long value);

    int MXProfileSetMarker(Pointer domain, String instant_marker_name, String scope);
    */

    /////////////////////////////////
    // NDArray
    /////////////////////////////////

    /* Need tests
    public static Pointer createNdArray() {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXNDArrayCreateNone(ref));

        return ref.getValue();
    }
     */

    /**
     * Allocates a dense native NDArray.
     *
     * @param device the device to allocate on
     * @param shape the array shape
     * @param dtype the data type; its ordinal is passed as the MXNet dtype code
     * @param size passed in the native {@code ndim} slot, presumably {@code shape.dimension()} —
     *     confirm with callers
     * @param delayedAlloc whether MXNet may defer the actual allocation
     * @return the native handle of the new array
     */
    public static Pointer createNdArray(
            Device device, Shape shape, DataType dtype, int size, boolean delayedAlloc) {
        int deviceType = MxDeviceType.toDeviceType(device);
        int deviceId = device.getDeviceId();

        PointerByReference out = REFS.acquire();
        checkCall(
                LIB.MXNDArrayCreateEx64(
                        shape.getShape(),
                        size,
                        deviceType,
                        deviceId,
                        delayedAlloc ? 1 : 0,
                        dtype.ordinal(),
                        out));

        Pointer handle = out.getValue();
        REFS.recycle(out);
        return handle;
    }

    /**
     * Allocates a sparse native NDArray.
     *
     * @param fmt the sparse storage format
     * @param device the device to allocate on
     * @param shape the logical shape of the array
     * @param dtype the data type; its ordinal is passed as the MXNet dtype code
     * @param auxDTypes the data types of the auxiliary arrays
     * @param auxShapes the shapes of the auxiliary arrays, in the same order as {@code auxDTypes}
     * @param delayedAlloc whether MXNet may defer the actual allocation
     * @return the native handle of the new array
     */
    public static Pointer createSparseNdArray(
            SparseFormat fmt,
            Device device,
            Shape shape,
            DataType dtype,
            DataType[] auxDTypes,
            Shape[] auxShapes,
            boolean delayedAlloc) {
        long[] shapeArray = shape.getShape();
        int deviceType = MxDeviceType.toDeviceType(device);
        int deviceId = device.getDeviceId();
        int delay = delayedAlloc ? 1 : 0;
        PointerByReference ref = REFS.acquire();
        IntBuffer auxDTypesInt =
                IntBuffer.wrap(Arrays.stream(auxDTypes).mapToInt(DataType::ordinal).toArray());
        IntBuffer auxNDims =
                IntBuffer.wrap(Arrays.stream(auxShapes).mapToInt(Shape::dimension).toArray());
        // MXNet reads aux_shape as the concatenation of all aux shapes, consuming aux_ndims[i]
        // entries for aux array i. Flatten every dimension so this buffer always agrees with
        // auxNDims. Passing only head() is correct only when every aux shape is 1-D; otherwise
        // native code reads past the end of the buffer.
        long[] auxShapesInt =
                Arrays.stream(auxShapes).flatMapToLong(s -> Arrays.stream(s.getShape())).toArray();
        checkCall(
                LIB.MXNDArrayCreateSparseEx64(
                        fmt.getValue(),
                        shapeArray,
                        shapeArray.length,
                        deviceType,
                        deviceId,
                        delay,
                        dtype.ordinal(),
                        auxDTypes.length,
                        auxDTypesInt,
                        auxNDims,
                        auxShapesInt,
                        ref));
        Pointer pointer = ref.getValue();
        REFS.recycle(ref);
        return pointer;
    }

    /**
     * Synchronously copies the contents of one native NDArray into another.
     *
     * @param dest the array to copy into
     * @param src the array to copy from
     * @param location forwarded unchanged to {@code MXNDArraySyncCopyFromNDArray}; presumably
     *     selects the data or an auxiliary array of {@code dest} — confirm against the MXNet API
     */
    public static void ndArraySyncCopyFromNdArray(MxNDArray dest, MxNDArray src, int location) {
        Pointer destHandle = dest.getHandle();
        Pointer srcHandle = src.getHandle();
        int status = LIB.MXNDArraySyncCopyFromNDArray(destHandle, srcHandle, location);
        checkCall(status);
    }

    /* Need tests
    public static Pointer loadFromBytes(byte[] buf, int offset, int size) {
        Memory memory = new Memory(size);
        memory.write(0, buf, offset, size);

        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXNDArrayLoadFromRawBytes(memory, new NativeSize(size), ref));

        return ref.getValue();
    }

    public static void saveNdArray(String file, Pointer[] ndArrays, String[] keys) {
        PointerArray array = new PointerArray(ndArrays);
        checkCall(LIB.MXNDArraySave(file, ndArrays.length, array, keys));
    }
     */

    /**
     * Loads every NDArray stored in an MXNet NDArray file and places it on {@code device}.
     *
     * <p>MXNet always loads the arrays onto the CPU. When {@code device} is not the CPU, the arrays
     * are copied to {@code device}. The CPU copies are then closed, including when the copy fails.
     *
     * @param manager the manager that creates and owns the loaded arrays
     * @param path the file to load
     * @param device the device the returned arrays should live on
     * @return the loaded arrays, named when the file stores names; empty if the file holds none
     * @throws IllegalStateException if the file stores names but their count differs from the
     *     number of arrays
     */
    public static NDList loadNdArray(MxNDManager manager, Path path, Device device) {
        IntBuffer handlesSize = IntBuffer.allocate(1);
        IntBuffer namesSize = IntBuffer.allocate(1);
        PointerByReference handlesRef = REFS.acquire();
        PointerByReference namesRef = REFS.acquire();
        NDList ndList = new NDList();
        try {
            checkCall(
                    LIB.MXNDArrayLoad(
                            path.toString(), handlesSize, handlesRef, namesSize, namesRef));
            int ndArrayCount = handlesSize.get(0);
            int nameCount = namesSize.get(0);
            if (nameCount > 0 && ndArrayCount != nameCount) {
                throw new IllegalStateException(
                        "Mismatch between names and arrays in checkpoint file: " + path.toString());
            }
            // The native out-pointer may be null when the file holds no arrays, so only
            // dereference it when there is something to read.
            Pointer[] handles =
                    ndArrayCount == 0
                            ? new Pointer[0]
                            : handlesRef.getValue().getPointerArray(0, ndArrayCount);
            String[] names =
                    nameCount == 0 ? null : namesRef.getValue().getStringArray(0, nameCount);
            for (int i = 0; i < handles.length; i++) {
                NDArray array = manager.create(handles[i]);
                if (names != null) {
                    array.setName(names[i]);
                }
                ndList.add(array);
            }
        } finally {
            // Return the pooled references even when loading fails.
            REFS.recycle(namesRef);
            REFS.recycle(handlesRef);
        }

        // MXNet always load NDArray on CPU
        if (Device.cpu().equals(device)) {
            return ndList;
        }

        try {
            return ndList.toDevice(device, true);
        } finally {
            // Release the CPU copies whether or not the transfer succeeded.
            ndList.close();
        }
    }

    /* Need tests
    public static ByteBuffer readBytes(Pointer ndArray) {
        NativeSizeByReference size = new NativeSizeByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXNDArraySaveRawBytes(ndArray, size, ref));

        return ref.getValue().getByteBuffer(0, size.getValue().longValue());
    }
     */

    /**
     * Releases a native NDArray through {@code MXNDArrayFree}.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     */
    public static void freeNdArray(Pointer ndArray) {
        checkNDArray(ndArray, "free");
        int status = LIB.MXNDArrayFree(ndArray);
        checkCall(status);
    }

    /**
     * Blocks until the native NDArray can be read, via {@code MXNDArrayWaitToRead}.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     */
    public static void waitToRead(Pointer ndArray) {
        checkNDArray(ndArray, "wait to read");
        int status = LIB.MXNDArrayWaitToRead(ndArray);
        checkCall(status);
    }

    /**
     * Blocks until the native NDArray can be written, via {@code MXNDArrayWaitToWrite}.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     */
    public static void waitToWrite(Pointer ndArray) {
        checkNDArray(ndArray, "wait to write");
        int status = LIB.MXNDArrayWaitToWrite(ndArray);
        checkCall(status);
    }

    /** Blocks until MXNet reports all pending NDArray work complete ({@code MXNDArrayWaitAll}). */
    public static void waitAll() {
        int status = LIB.MXNDArrayWaitAll();
        checkCall(status);
    }

    /**
     * Synchronously copies a native NDArray's contents into host memory.
     *
     * @param ndArray the source native handle
     * @param data the destination host memory
     * @param len forwarded as the native size argument of {@code MXNDArraySyncCopyToCPU}
     */
    public static void syncCopyToCPU(Pointer ndArray, Pointer data, int len) {
        NativeSize byteCount = new NativeSize(len);
        checkNDArray(ndArray, "copy from");
        checkNDArray(data, "copy to");
        int status = LIB.MXNDArraySyncCopyToCPU(ndArray, data, byteCount);
        checkCall(status);
    }

    /**
     * Synchronously copies host data into a native NDArray.
     *
     * <p>The destination handle is validated first, as the other handle-taking wrappers here do
     * (for example {@code syncCopyToCPU}). This keeps a closed (null) handle from being passed to
     * native code.
     *
     * @param ndArray the destination native handle
     * @param data the source data; its native address is obtained with
     *     {@code Native.getDirectBufferPointer}, so a direct buffer is expected
     * @param len forwarded as the native size argument of {@code MXNDArraySyncCopyFromCPU}
     */
    public static void syncCopyFromCPU(Pointer ndArray, Buffer data, int len) {
        NativeSize size = new NativeSize(len);
        checkNDArray(ndArray, "copy to");
        Pointer pointer = Native.getDirectBufferPointer(data);
        checkCall(LIB.MXNDArraySyncCopyFromCPU(ndArray, pointer, size));
    }

    /**
     * Invokes a native operator imperatively through {@code MXImperativeInvokeEx}.
     *
     * @param function the native operator handle
     * @param src the input arrays
     * @param dest caller-supplied output arrays; may be {@code null}
     * @param params operator parameters whose values are passed as their {@code toString()} form;
     *     may be {@code null} for none
     * @return one entry per output: the output handle paired with its storage format
     */
    public static PairList imperativeInvoke(
            Pointer function, NDArray[] src, NDArray[] dest, PairList params) {
        String[] keys;
        String[] values;
        if (params == null) {
            keys = Utils.EMPTY_ARRAY;
            values = Utils.EMPTY_ARRAY;
        } else {
            keys = params.keyArray(Utils.EMPTY_ARRAY);
            values = params.values().stream().map(Object::toString).toArray(String[]::new);
        }
        StringArray keyArray = StringArray.of(keys);
        StringArray valueArray = StringArray.of(values);
        PointerArray srcArray = toPointerArray(src);
        PointerArray destArray = toPointerArray(dest);
        PointerByReference destRef = REFS.acquire();
        destRef.setValue(destArray);
        PointerByReference destSType = REFS.acquire();
        IntBuffer numOutputs = IntBuffer.allocate(1);
        // When outputs are supplied, MXNet reads num_outputs as the number of supplied arrays and
        // checks it against the operator's output count, so it must equal dest.length. Without
        // supplied outputs MXNet only writes the count back, and the placeholder 1 is unused.
        // An empty dest keeps the placeholder 1.
        numOutputs.put(0, dest == null || dest.length == 0 ? 1 : dest.length);

        checkCall(
                LIB.MXImperativeInvokeEx(
                        function,
                        src.length,
                        srcArray,
                        numOutputs,
                        destRef,
                        keys.length,
                        keyArray,
                        valueArray,
                        destSType));
        int numOfOutputs = numOutputs.get(0);
        Pointer[] ptrArray = destRef.getValue().getPointerArray(0, numOfOutputs);
        int[] sTypes = destSType.getValue().getIntArray(0, numOfOutputs);
        PairList<Pointer, SparseFormat> pairList = new PairList<>();
        for (int i = 0; i < numOfOutputs; i++) {
            pairList.add(ptrArray[i], SparseFormat.fromValue(sTypes[i]));
        }
        REFS.recycle(destRef);
        REFS.recycle(destSType);
        srcArray.recycle();
        keyArray.recycle();
        valueArray.recycle();

        if (destArray != null) {
            destArray.recycle();
        }
        return pairList;
    }

    /**
     * Returns the storage format of a native NDArray.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     * @return the format mapped from the code reported by {@code MXNDArrayGetStorageType}
     */
    public static SparseFormat getStorageType(Pointer ndArray) {
        IntBuffer storageOut = IntBuffer.allocate(1);
        checkNDArray(ndArray, "get the storage type of");
        int status = LIB.MXNDArrayGetStorageType(ndArray, storageOut);
        checkCall(status);
        int code = storageOut.get(0);
        return SparseFormat.fromValue(code);
    }

    /**
     * Returns the device a native NDArray lives on.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     * @return the device built from the type and id reported by {@code MXNDArrayGetContext}
     */
    public static Device getDevice(Pointer ndArray) {
        IntBuffer typeOut = IntBuffer.allocate(1);
        IntBuffer idOut = IntBuffer.allocate(1);
        checkNDArray(ndArray, "get the device of");
        checkCall(LIB.MXNDArrayGetContext(ndArray, typeOut, idOut));
        // The reported id is forwarded as-is for every device type, CPU included.
        return Device.of(MxDeviceType.fromDeviceType(typeOut.get(0)), idOut.get(0));
    }

    /**
     * Returns the shape of a native NDArray.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     * @return the shape; a zero-dimension result yields {@code new Shape()}
     */
    public static Shape getShape(Pointer ndArray) {
        IntBuffer ndimOut = IntBuffer.allocate(1);
        PointerByReference dimsOut = REFS.acquire();
        checkNDArray(ndArray, "get the shape of");
        checkCall(LIB.MXNDArrayGetShapeEx64(ndArray, ndimOut, dimsOut));
        int ndim = ndimOut.get(0);
        Shape result =
                ndim == 0 ? new Shape() : new Shape(dimsOut.getValue().getLongArray(0, ndim));
        REFS.recycle(dimsOut);
        return result;
    }

    /**
     * Returns the data type of a native NDArray.
     *
     * @param ndArray the native handle; validated with {@code checkNDArray} first
     * @return the {@link DataType} whose ordinal equals the MXNet dtype code
     * @throws IllegalStateException if MXNet reports a dtype code with no matching constant
     */
    public static DataType getDataType(Pointer ndArray) {
        IntBuffer dataType = IntBuffer.allocate(1);
        checkNDArray(ndArray, "get the data type of");
        checkCall(LIB.MXNDArrayGetDType(ndArray, dataType));
        int code = dataType.get(0);
        // DataType constants are matched to MXNet dtype codes by ordinal; createNdArray passes
        // dtype.ordinal() in the other direction. An unmapped code would otherwise surface as a
        // bare ArrayIndexOutOfBoundsException with no hint of the cause.
        DataType[] types = DataType.values();
        if (code < 0 || code >= types.length) {
            throw new IllegalStateException("Unsupported MXNet data type code: " + code);
        }
        return types[code];
    }

    /* Need tests
    public static DataType getAuxType(Pointer ndArray, int index) {
        IntBuffer dataType = IntBuffer.allocate(1);
        checkCall(LIB.MXNDArrayGetAuxType(ndArray, index, dataType));
        return DataType.values()[dataType.get()];
    }

    public static Pointer getAuxNdArray(Pointer ndArray, int index) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXNDArrayGetAuxNDArray(ndArray, index, ref));
        return ref.getValue();
    }

    public static Pointer getDataNdArray(Pointer ndArray) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXNDArrayGetDataNDArray(ndArray, ref));
        return ref.getValue();
    }

    public static Pointer getGrad(Pointer ndArray) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXNDArrayGetGrad(ndArray, ref));
        return ref.getValue();
    }

    public static Pointer reshape(Pointer ndArray, long[] dims, boolean reverse) {
        PointerByReference ref = new PointerByReference();
        byte reverseByte = reverse ? (byte) 1 : 0;
        checkCall(
                LIB.MXNDArrayReshape64(
                        ndArray, dims.length, LongBuffer.wrap(dims), reverseByte, ref));
        return ref.getValue();
    } */

    /////////////////////////////////
    // MxGradientCollector
    /////////////////////////////////
    /**
     * Turns autograd recording on or off.
     *
     * @param isRecording the new recording state
     * @return {@code true} if the previous state reported by MXNet was 1
     */
    public static boolean autogradSetIsRecording(boolean isRecording) {
        int flag = isRecording ? 1 : 0;
        IntBuffer previous = IntBuffer.allocate(1);
        int status = LIB.MXAutogradSetIsRecording(flag, previous);
        checkCall(status);
        return previous.get(0) == 1;
    }

    /**
     * Switches autograd between training and inference mode.
     *
     * @param isTraining the new training state
     * @return {@code true} if the previous state reported by MXNet was 1
     */
    public static boolean autogradSetTraining(boolean isTraining) {
        int flag = isTraining ? 1 : 0;
        IntBuffer previous = IntBuffer.allocate(1);
        int status = LIB.MXAutogradSetIsTraining(flag, previous);
        checkCall(status);
        return previous.get(0) == 1;
    }

    /**
     * Reports whether autograd is currently recording.
     *
     * @return {@code true} if the byte written by {@code MXAutogradIsRecording} is 1
     */
    public static boolean autogradIsRecording() {
        ByteBuffer out = ByteBuffer.allocate(1);
        int status = LIB.MXAutogradIsRecording(out);
        checkCall(status);
        byte flag = out.get(0);
        return flag == 1;
    }

    /**
     * Reports whether autograd is currently in training mode.
     *
     * @return {@code true} if the byte written by {@code MXAutogradIsTraining} is 1
     */
    public static boolean autogradIsTraining() {
        ByteBuffer out = ByteBuffer.allocate(1);
        int status = LIB.MXAutogradIsTraining(out);
        checkCall(status);
        byte flag = out.get(0);
        return flag == 1;
    }

    /**
     * Marks arrays as autograd variables and attaches their gradient buffers.
     *
     * @param numVar the number of variables
     * @param varHandles native pointer to the variable handles
     * @param reqsArray the gradient request per variable
     * @param gradHandles native pointer to the gradient handles
     */
    public static void autogradMarkVariables(
            int numVar, Pointer varHandles, IntBuffer reqsArray, Pointer gradHandles) {
        PointerByReference variables = REFS.acquire();
        variables.setValue(varHandles);
        PointerByReference gradients = REFS.acquire();
        gradients.setValue(gradHandles);

        int status = LIB.MXAutogradMarkVariables(numVar, variables, reqsArray, gradients);
        checkCall(status);

        REFS.recycle(variables);
        REFS.recycle(gradients);
    }

    /**
     * Runs autograd backward from the given outputs, using default head gradients.
     *
     * @param array the output arrays to differentiate
     * @param retainGraph forwarded to MXNet; non-zero keeps the recorded graph
     */
    public static void autogradBackward(NDList array, int retainGraph) {
        PointerArray pa = toPointerArray(array);
        // A null ograd_handles pointer tells MXNet to use the default (ones) head gradient for
        // every output. A single-slot reference only covers the first output; with more than one
        // output, native code would read past that one-element buffer.
        PointerByReference noOutputGradients = null;
        checkCall(LIB.MXAutogradBackward(array.size(), pa, noOutputGradients, retainGraph));
        pa.recycle();
    }

    /**
     * Runs autograd backward through {@code MXAutogradBackwardEx}.
     *
     * <p>NOTE(review): {@code outgrad} is not forwarded. An empty head-gradient array is always
     * passed, exactly as before; confirm this is intended.
     *
     * @param numOutput the number of outputs
     * @param array the output arrays
     * @param outgrad currently unused
     * @param numVariables the number of variables
     * @param varHandles native pointer to the variable handles
     * @param retainGraph forwarded to MXNet
     * @param createGraph forwarded to MXNet
     * @param isTrain forwarded to MXNet
     * @param gradHandles native pointer to the gradient handles
     * @param gradSparseFormat native pointer to the gradient storage types
     */
    public static void autogradBackwardExecute(
            int numOutput,
            NDList array,
            NDArray outgrad,
            int numVariables,
            Pointer varHandles,
            int retainGraph,
            int createGraph,
            int isTrain,
            Pointer gradHandles,
            Pointer gradSparseFormat) {
        PointerByReference variablesRef = REFS.acquire();
        variablesRef.setValue(varHandles);
        PointerByReference gradientsRef = REFS.acquire();
        gradientsRef.setValue(gradHandles);
        PointerByReference formatsRef = REFS.acquire();
        formatsRef.setValue(gradSparseFormat);

        PointerArray outputHandles = toPointerArray(array);
        PointerArray headGradients = PointerArray.of();

        int status =
                LIB.MXAutogradBackwardEx(
                        numOutput,
                        outputHandles,
                        headGradients,
                        numVariables,
                        variablesRef,
                        retainGraph,
                        createGraph,
                        isTrain,
                        gradientsRef,
                        formatsRef);
        checkCall(status);

        REFS.recycle(variablesRef);
        REFS.recycle(gradientsRef);
        REFS.recycle(formatsRef);
        outputHandles.recycle();
        headGradients.recycle();
    }

    /**
     * Returns the symbol autograd recorded for an array.
     *
     * @param array the array; must be an {@code MxNDArray}
     * @return the native symbol handle from {@code MXAutogradGetSymbol}
     */
    public static Pointer autogradGetSymbol(NDArray array) {
        MxNDArray mxArray = (MxNDArray) array;
        Pointer arrayHandle = mxArray.getHandle();
        PointerByReference symbolOut = REFS.acquire();
        checkCall(LIB.MXAutogradGetSymbol(arrayHandle, symbolOut));
        Pointer symbol = symbolOut.getValue();
        REFS.recycle(symbolOut);
        return symbol;
    }

    /**
     * Returns the numpy-shape mode reported by {@code MXIsNumpyShape}.
     *
     * @return the raw mode value
     */
    public static int isNumpyMode() {
        IntBuffer modeOut = IntBuffer.allocate(1);
        int status = LIB.MXIsNumpyShape(modeOut);
        checkCall(status);
        return modeOut.get(0);
    }

    /**
     * Sets the numpy-shape mode; the mode's ordinal is passed to {@code MXSetIsNumpyShape}.
     *
     * @param mode the new mode
     */
    public static void setNumpyMode(NumpyMode mode) {
        // Output slot required by the native call (presumably the previous mode); unused here.
        IntBuffer previous = IntBuffer.allocate(1);
        int status = LIB.MXSetIsNumpyShape(mode.ordinal(), previous);
        checkCall(status);
    }

    /**
     * Returns the gradient array attached to a native NDArray.
     *
     * @param handle the native handle; validated with {@code checkNDArray} first
     * @return the gradient handle from {@code MXNDArrayGetGrad}
     */
    public static Pointer getGradient(Pointer handle) {
        PointerByReference gradOut = REFS.acquire();
        checkNDArray(handle, "get the gradient for");
        checkCall(LIB.MXNDArrayGetGrad(handle, gradOut));
        Pointer gradient = gradOut.getValue();
        REFS.recycle(gradOut);
        return gradient;
    }

    /**
     * Creates a native key-value (parameter) store.
     *
     * @param type the store type string passed to {@code MXKVStoreCreate}
     * @return the native store handle
     */
    public static Pointer parameterStoreCreate(String type) {
        PointerByReference storeOut = REFS.acquire();
        checkCall(LIB.MXKVStoreCreate(type, storeOut));
        Pointer store = storeOut.getValue();
        REFS.recycle(storeOut);
        return store;
    }

    /**
     * Frees a native key-value store via {@code MXKVStoreFree}.
     *
     * @param handle the native store handle
     */
    public static void parameterStoreClose(Pointer handle) {
        int status = LIB.MXKVStoreFree(handle);
        checkCall(status);
    }

    /**
     * Initializes keys of a native key-value store with the given arrays.
     *
     * @param handle the native store handle; validated with {@code checkNDArray} first
     * @param num the number of key/value pairs
     * @param keys the keys
     * @param vals the arrays, in the same order as {@code keys}
     */
    public static void parameterStoreInit(Pointer handle, int num, String[] keys, NDList vals) {
        checkNDArray(handle, "initialize the parameter store with");
        PointerArray valueHandles = toPointerArray(vals);
        int status = LIB.MXKVStoreInitEx(handle, num, keys, valueHandles);
        checkCall(status);
        valueHandles.recycle();
    }

    /**
     * Pushes arrays into a native key-value store.
     *
     * @param handle the native store handle; validated with {@code checkNDArray} first
     * @param num the number of key/value pairs
     * @param keys the keys
     * @param vals the arrays, in the same order as {@code keys}
     * @param priority forwarded to {@code MXKVStorePushEx}
     */
    public static void parameterStorePush(
            Pointer handle, int num, String[] keys, NDList vals, int priority) {
        checkNDArray(handle, "push to the parameter store with");
        PointerArray valueHandles = toPointerArray(vals);
        int status = LIB.MXKVStorePushEx(handle, num, keys, valueHandles, priority);
        checkCall(status);
        valueHandles.recycle();
    }

    /**
     * Pulls values for integer keys from a native key-value store into the given arrays.
     *
     * @param handle the native store handle; validated with {@code checkNDArray} first
     * @param num the number of key/value pairs
     * @param keys the integer keys
     * @param vals the destination arrays, in the same order as {@code keys}
     * @param priority forwarded to {@code MXKVStorePull}
     */
    public static void parameterStorePull(
            Pointer handle, int num, int[] keys, NDList vals, int priority) {
        checkNDArray(handle, "pull from the parameter store with");
        PointerArray valueHandles = toPointerArray(vals);
        int status = LIB.MXKVStorePull(handle, num, keys, valueHandles, priority);
        checkCall(status);
        valueHandles.recycle();
    }

    /**
     * Pulls values for string keys from a native key-value store into the given arrays.
     *
     * @param handle the native store handle; validated with {@code checkNDArray} first
     * @param num the number of key/value pairs
     * @param keys the string keys
     * @param vals the destination arrays, in the same order as {@code keys}
     * @param priority forwarded to {@code MXKVStorePullEx}
     */
    public static void parameterStorePull(
            Pointer handle, int num, String[] keys, NDList vals, int priority) {
        checkNDArray(handle, "pull from the parameter store with");
        PointerArray valueHandles = toPointerArray(vals);
        int status = LIB.MXKVStorePullEx(handle, num, keys, valueHandles, priority);
        checkCall(status);
        valueHandles.recycle();
    }

    /**
     * Pushes inputs to and pulls outputs from a native key-value store in one native call.
     *
     * @param handle the native store handle; validated with {@code checkNDArray} first
     * @param inputNum the number of input keys
     * @param inputKeys the input keys
     * @param outputNum the number of output keys
     * @param outputKey the output keys
     * @param inputs the arrays to push
     * @param outputs the arrays to pull into
     * @param priority forwarded to {@code MXKVStorePushPullEx}
     */
    public static void parameterStorePushPull(
            Pointer handle,
            int inputNum,
            String[] inputKeys,
            int outputNum,
            String[] outputKey,
            NDList inputs,
            NDList outputs,
            int priority) {
        checkNDArray(handle, "push from the parameter store with");
        PointerArray pushed = toPointerArray(inputs);
        PointerArray pulled = toPointerArray(outputs);

        int status =
                LIB.MXKVStorePushPullEx(
                        handle, inputNum, inputKeys, outputNum, outputKey, pushed, pulled, priority);
        checkCall(status);

        pushed.recycle();
        pulled.recycle();
    }

    /**
     * Registers integer-key and string-key updater callbacks on a native key-value store.
     *
     * @param handle the native store handle
     * @param updater the updater for integer keys
     * @param stringUpdater the updater for string keys
     * @param updaterHandle opaque handle passed through to {@code MXKVStoreSetUpdaterEx}
     */
    public static void parameterStoreSetUpdater(
            Pointer handle,
            MxnetLibrary.MXKVStoreUpdater updater,
            MxnetLibrary.MXKVStoreStrUpdater stringUpdater,
            Pointer updaterHandle) {
        int status = LIB.MXKVStoreSetUpdaterEx(handle, updater, stringUpdater, updaterHandle);
        checkCall(status);
    }

    /**
     * Registers an integer-key updater callback on a native key-value store.
     *
     * @param handle the native store handle
     * @param updater the updater for integer keys
     * @param updaterHandle opaque handle passed through to {@code MXKVStoreSetUpdater}
     */
    public static void parameterStoreSetUpdater(
            Pointer handle, MxnetLibrary.MXKVStoreUpdater updater, Pointer updaterHandle) {
        int status = LIB.MXKVStoreSetUpdater(handle, updater, updaterHandle);
        checkCall(status);
    }

    /*
    int MXInitPSEnv(int num_vars, String keys[], String vals[]);

    int MXKVStoreSetGradientCompression(Pointer handle, int num_params, String keys[],
                                        String vals[]);

    int MXKVStorePullWithSparse(Pointer handle, int num, int keys[], PointerByReference vals,
                                int priority, byte ignore_sparse);

    int MXKVStorePullWithSparseEx(Pointer handle, int num, String keys[], PointerByReference vals,
                                  int priority, byte ignore_sparse);


    int MXKVStorePullRowSparse(Pointer handle, int num, int keys[], PointerByReference vals,
                               PointerByReference row_ids, int priority);

    int MXKVStorePullRowSparseEx(Pointer handle, int num, String keys[], PointerByReference vals,
                                 PointerByReference row_ids, int priority);

    int MXKVStoreGetType(Pointer handle, String type[]);


    int MXKVStoreGetRank(Pointer handle, IntBuffer ret);

    int MXKVStoreGetGroupSize(Pointer handle, IntBuffer ret);

    int MXKVStoreIsWorkerNode(IntBuffer ret);

    int MXKVStoreIsServerNode(IntBuffer ret);

    int MXKVStoreIsSchedulerNode(IntBuffer ret);


    int MXKVStoreBarrier(Pointer handle);


    int MXKVStoreSetBarrierBeforeExit(Pointer handle, int barrier_before_exit);


    int MXKVStoreRunServer(Pointer handle, MxnetLibrary.MXKVStoreServerController controller,
                           Pointer controller_handle);


    int MXKVStoreSendCommmandToServers(Pointer handle, int cmd_id, String cmd_body);

    int MXKVStoreGetNumDeadNode(Pointer handle, int node_id, IntBuffer number, int timeout_sec);
     */
    /*
    int MXImperativeInvokeEx(Pointer creator, int num_inputs, PointerByReference inputs,
                             IntBuffer num_outputs, PointerByReference outputs, int num_params,
                             String param_keys[], String param_vals[],
                             PointerByReference out_stypes);

    int MXNDArraySyncCopyFromCPU(Pointer handle, Pointer data, NativeSize size);

    int MXNDArraySyncCopyFromNDArray(Pointer handle_dst, Pointer handle_src, int i);

    int MXNDArraySyncCheckFormat(Pointer handle, byte full_check);


    int MXNDArrayReshape(Pointer handle, int ndim, IntBuffer dims, PointerByReference out);

    int MXNDArrayReshape64(Pointer handle, int ndim, LongBuffer dims, byte reverse,
                           PointerByReference out);

    int MXNDArrayGetData(Pointer handle, PointerByReference out_pdata);

    int MXNDArrayToDLPack(Pointer handle, PointerByReference out_dlpack);

    int MXNDArrayFromDLPack(Pointer dlpack, PointerByReference out_handle);

    int MXNDArrayCallDLPackDeleter(Pointer dlpack);

    int MXNDArrayGetDType(Pointer handle, IntBuffer out_dtype);

    int MXNDArrayGetAuxType(Pointer handle, int i, IntBuffer out_type);

    int MXNDArrayGetAuxNDArray(Pointer handle, int i, PointerByReference out);

    int MXNDArrayGetDataNDArray(Pointer handle, PointerByReference out);

    int MXNDArrayGetContext(Pointer handle, IntBuffer out_dev_type, IntBuffer out_dev_id);
    */
    /**
     * Calls {@code MXNDArrayDetach} on the given NDArray handle and returns the handle it produces.
     *
     * @param handle the native handle of the source NDArray
     * @return the native handle written by {@code MXNDArrayDetach}
     */
    public static Pointer detachGradient(Pointer handle) {
        PointerByReference detachedRef = REFS.acquire();
        checkCall(LIB.MXNDArrayDetach(handle, detachedRef));
        Pointer detached = detachedRef.getValue();
        // hand the reference back to the pool only after its value has been copied out
        REFS.recycle(detachedRef);
        return detached;
    }

    /*
    int MXNDArraySetGradState(Pointer handle, int state);

    int MXNDArrayGetGradState(Pointer handle, IntBuffer out);

    int MXListFunctions(IntBuffer out_size, PointerByReference out_array);


    int MXAutogradComputeGradient(int num_output, PointerByReference output_handles);


    int MXAutogradGetSymbol(Pointer handle, PointerByReference out);


    int MXCreateCachedOp(Pointer handle, PointerByReference out);


    int MXCreateCachedOpEx(Pointer handle, int num_flags, String keys[], String vals[],
                           PointerByReference out);


    int MXFreeCachedOp(Pointer handle);


    int MXInvokeCachedOp(Pointer handle, int num_inputs, PointerByReference inputs,
                         IntBuffer num_outputs, PointerByReference outputs);

    int MXInvokeCachedOpEx(Pointer handle, int num_inputs, PointerByReference inputs,
                           IntBuffer num_outputs, PointerByReference outputs,
                           PointerByReference out_stypes);


    int MXListAllOpNames(IntBuffer out_size, PointerByReference out_array);
    */

    /////////////////////////////////
    // MXNet Symbols
    /////////////////////////////////

    /**
     * Fetches the handle of one output of a symbol through {@code MXSymbolGetOutput}.
     *
     * @param symbol the native symbol handle
     * @param index the output index forwarded to MXNet
     * @return the native handle written by {@code MXSymbolGetOutput}
     */
    public static Pointer getSymbolOutput(Pointer symbol, int index) {
        PointerByReference outRef = REFS.acquire();
        checkCall(LIB.MXSymbolGetOutput(symbol, index, outRef));
        Pointer output = outRef.getValue();
        REFS.recycle(outRef);
        return output;
    }

    /**
     * Lists the output names of a symbol through {@code MXSymbolListOutputs}.
     *
     * @param symbol the native symbol handle
     * @return the output names reported by MXNet
     */
    public static String[] listSymbolOutputs(Pointer symbol) {
        PointerByReference namesRef = REFS.acquire();
        IntBuffer count = IntBuffer.allocate(1);
        checkCall(LIB.MXSymbolListOutputs(symbol, count, namesRef));
        String[] outputs = toStringArray(namesRef, count.get());
        REFS.recycle(namesRef);
        return outputs;
    }

    /* Need tests
    public static String symbolToJson(Pointer symbol) {
        String[] out = new String[1];
        checkCall(LIB.MXSymbolSaveToJSON(symbol, out));
        return out[0];
    }
     */

    /**
     * Releases a native symbol handle via {@code MXSymbolFree}.
     *
     * @param symbol the native symbol handle to free
     */
    public static void freeSymbol(Pointer symbol) {
        int status = LIB.MXSymbolFree(symbol);
        checkCall(status);
    }

    /* Need tests
    public static void saveSymbol(Pointer symbol, String path) {
        checkCall(LIB.MXSymbolSaveToFile(symbol, path));
    }

    public static Pointer copySymbol(Pointer symbol) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCopy(symbol, ref));
        return ref.getValue();
    }

    public static String getSymbolDebugString(Pointer symbol) {
        String[] out = new String[1];
        checkCall(LIB.MXSymbolPrint(symbol, out));
        return out[0];
    }

    public static String getSymbolName(Pointer symbol) {
        String[] out = new String[1];
        IntBuffer success = IntBuffer.allocate(1);
        checkCall(LIB.MXSymbolGetName(symbol, out, success));
        if (success.get() == 1) {
            return out[0];
        }
        return null;
    }

    public static String getSymbolAttr(Pointer symbol, String key) {
        String[] out = new String[1];
        IntBuffer success = IntBuffer.allocate(1);
        checkCall(LIB.MXSymbolGetAttr(symbol, key, out, success));
        if (success.get() == 1) {
            return out[0];
        }
        return null;
    }

    public static void setSymbolAttr(Pointer symbol, String key, String value) {
        checkCall(LIB.MXSymbolSetAttr(symbol, key, value));
    }

    public static PairList listSymbolAttr(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolListAttr(symbol, size, ref));

        return toPairList(ref, size.get());
    }

    public static PairList listSymbolAttrShallow(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolListAttrShallow(symbol, size, ref));

        return toPairList(ref, size.get());
    }
     */

    /**
     * Lists the input names of a symbol through {@code NNSymbolListInputNames}.
     *
     * <p>The option argument is fixed to {@code 0} (presumably NNVM's "all inputs" mode,
     * including auxiliary states — TODO confirm against the NNVM C API).
     *
     * @param symbol the native symbol handle
     * @return the input names reported by the native call
     */
    public static String[] listSymbolNames(Pointer symbol) {
        PointerByReference namesRef = REFS.acquire();
        IntBuffer count = IntBuffer.allocate(1);
        checkCall(LIB.NNSymbolListInputNames(symbol, 0, count, namesRef));
        String[] names = toStringArray(namesRef, count.get());
        REFS.recycle(namesRef);
        return names;
    }

    /**
     * Lists the argument names of a symbol through {@code MXSymbolListArguments}.
     *
     * @param symbol the native symbol handle
     * @return the argument names reported by MXNet
     */
    public static String[] listSymbolArguments(Pointer symbol) {
        PointerByReference argNamesRef = REFS.acquire();
        IntBuffer argCount = IntBuffer.allocate(1);
        checkCall(LIB.MXSymbolListArguments(symbol, argCount, argNamesRef));
        String[] arguments = toStringArray(argNamesRef, argCount.get());
        REFS.recycle(argNamesRef);
        return arguments;
    }

    /**
     * Lists the auxiliary-state names of a symbol through {@code MXSymbolListAuxiliaryStates}.
     *
     * @param symbol the native symbol handle
     * @return the auxiliary-state names reported by MXNet
     */
    public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
        PointerByReference auxNamesRef = REFS.acquire();
        IntBuffer auxCount = IntBuffer.allocate(1);
        checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, auxCount, auxNamesRef));
        String[] auxStates = toStringArray(auxNamesRef, auxCount.get());
        REFS.recycle(auxNamesRef);
        return auxStates;
    }

    /**
     * Obtains the internals symbol of the given symbol through {@code MXSymbolGetInternals}.
     *
     * @param symbol the native symbol handle
     * @return the native handle written by {@code MXSymbolGetInternals}
     */
    public static Pointer getSymbolInternals(Pointer symbol) {
        PointerByReference internalsRef = REFS.acquire();
        checkCall(LIB.MXSymbolGetInternals(symbol, internalsRef));
        Pointer internals = internalsRef.getValue();
        REFS.recycle(internalsRef);
        return internals;
    }

    /* Need tests
    public static String[] listSymbolArguments(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolListArguments(symbol, size, ref));

        return toStringArray(ref, size.get());
    }

    public static int getSymbolNumOutputs(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        checkCall(LIB.MXSymbolGetNumOutputs(symbol, size));
        return size.get();
    }

    public static Pointer getSymbolInternals(Pointer symbol) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolGetInternals(symbol, ref));
        return ref.getValue();
    }

    public static String getSymbolChildren(Pointer symbol) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolGetChildren(symbol, ref));
        return ref.getValue().getString(0, StandardCharsets.UTF_8.name());
    }

    public static String[] listSymbolAuxiliaryStates(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolListAuxiliaryStates(symbol, size, ref));

        return toStringArray(ref, size.get());
    }

    public static Pointer[] listAtomicSymbolCreators() {
        IntBuffer outSize = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolListAtomicSymbolCreators(outSize, ref));

        int size = outSize.get();
        return ref.getValue().getPointerArray(0, size);
    }

    public static String getAtomicSymbolName(Pointer symbol) {
        String[] ret = new String[1];
        checkCall(LIB.MXSymbolGetAtomicSymbolName(symbol, ret));
        return ret[0];
    }

    public static String getInputSymbols(Pointer symbol) {
        IntBuffer size = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolGetInputSymbols(symbol, ref, size));
        return ref.getValue().getString(0, StandardCharsets.UTF_8.name());
    }

    public static String cutSubgraph(Pointer symbol) {
        IntBuffer inputSize = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCutSubgraph(symbol, ref, inputSize));
        return ref.getValue().getString(0, StandardCharsets.UTF_8.name());
    }

    public static Pointer createAtomicSymbol(Pointer symbol, String[] keys, String[] values) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCreateAtomicSymbol(symbol, keys.length, keys, values, ref));
        return ref.getValue();
    }

    public static Pointer createVariable(String name) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCreateVariable(name, ref));
        return ref.getValue();
    }

    public static Pointer createGroup(int numOfSymbols, Pointer symbols) {
        PointerByReference symbolsRef = new PointerByReference(symbols);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCreateGroup(numOfSymbols, symbolsRef, ref));
        return ref.getValue();
    }
     */

    /**
     * Loads a symbol from a file through {@code MXSymbolCreateFromFile}.
     *
     * @param path the file path forwarded to MXNet
     * @return the native handle of the loaded symbol
     */
    public static Pointer createSymbolFromFile(String path) {
        PointerByReference symbolRef = REFS.acquire();
        checkCall(LIB.MXSymbolCreateFromFile(path, symbolRef));
        Pointer loaded = symbolRef.getValue();
        REFS.recycle(symbolRef);
        return loaded;
    }

    /**
     * Builds a symbol from its JSON text through {@code MXSymbolCreateFromJSON}.
     *
     * @param json the symbol's JSON representation
     * @return the native handle of the created symbol
     */
    public static Pointer createSymbolFromString(String json) {
        PointerByReference symbolRef = REFS.acquire();
        checkCall(LIB.MXSymbolCreateFromJSON(json, symbolRef));
        Pointer created = symbolRef.getValue();
        REFS.recycle(symbolRef);
        return created;
    }

    /**
     * Serializes a symbol to JSON through {@code MXSymbolSaveToJSON}.
     *
     * @param symbol the native symbol handle
     * @return the JSON string written by MXNet into the single-element out array
     */
    public static String getSymbolString(Pointer symbol) {
        String[] json = {null};
        checkCall(LIB.MXSymbolSaveToJSON(symbol, json));
        return json[0];
    }

    /**
     * Converts one group of shapes returned by {@code MXSymbolInferShapeEx64} into {@link Shape}s.
     *
     * <p>Each group is described by a count, a native {@code int} array holding the number of
     * dimensions of every shape, and a native array with one pointer per shape to that shape's
     * {@code int64} dimension values.
     *
     * @param size the number of shapes in the group
     * @param nDim reference to the native array of per-shape dimension counts
     * @param data reference to the native array of per-shape dimension-value pointers
     * @return the recovered shapes in native order; empty when {@code size} is zero
     * @throws IllegalStateException if a shape reports a negative (unknown) dimension count
     */
    private static List<Shape> recoverShape(
            NativeSizeByReference size, PointerByReference nDim, PointerByReference data) {
        int shapeLength = (int) size.getValue().longValue();
        if (shapeLength == 0) {
            return new ArrayList<>();
        }
        int[] dims = nDim.getValue().getIntArray(0, shapeLength);
        // Read every shape through its own data pointer rather than assuming all dimension
        // values are laid out contiguously behind the first pointer.
        Pointer[] shapePointers = data.getValue().getPointerArray(0, shapeLength);
        List<Shape> result = new ArrayList<>(shapeLength);
        for (int i = 0; i < shapeLength; i++) {
            int dim = dims[i];
            if (dim < 0) {
                throw new IllegalStateException(
                        "Shape at index " + i + " has an unknown number of dimensions: " + dim);
            }
            // A scalar shape has no dimension values; its data pointer is not guaranteed to
            // reference readable memory (it may be null), so it is never dereferenced.
            long[] shape = dim == 0 ? new long[0] : shapePointers[i].getLongArray(0, dim);
            result.add(new Shape(shape));
        }
        return result;
    }

    public static List> inferShape(Symbol symbol, PairList args) {
        Pointer handler = symbol.getHandle();
        int numArgs = args.size();
        String[] keys = args.keys().toArray(Utils.EMPTY_ARRAY);
        // the following two is also the representation of
        // CSR NDArray
        long[] indPtr = new long[numArgs + 1];
        Shape flattened = new Shape();
        indPtr[0] = 0;
        for (int i = 0; i < args.size(); i++) {
            Shape shape = args.valueAt(i);
            indPtr[i + 1] = shape.dimension();
            flattened = flattened.addAll(shape);
        }
        long[] flattenedShapeArray = flattened.getShape();

        NativeSizeByReference inShapeSize = new NativeSizeByReference();
        PointerByReference inShapeNDim = REFS.acquire();
        PointerByReference inShapeData = REFS.acquire();
        NativeSizeByReference outShapeSize = new NativeSizeByReference();
        PointerByReference outShapeNDim = REFS.acquire();
        PointerByReference outShapeData = REFS.acquire();
        NativeSizeByReference auxShapeSize = new NativeSizeByReference();
        PointerByReference auxShapeNDim = REFS.acquire();
        PointerByReference auxShapeData = REFS.acquire();
        IntBuffer complete = IntBuffer.allocate(1);
        checkCall(
                LIB.MXSymbolInferShapeEx64(
                        handler,
                        numArgs,
                        keys,
                        indPtr,
                        flattenedShapeArray,
                        inShapeSize,
                        inShapeNDim,
                        inShapeData,
                        outShapeSize,
                        outShapeNDim,
                        outShapeData,
                        auxShapeSize,
                        auxShapeNDim,
                        auxShapeData,
                        complete));
        if (complete.get() != 0) {
            return Arrays.asList(
                    recoverShape(inShapeSize, inShapeNDim, inShapeData),
                    recoverShape(outShapeSize, outShapeNDim, outShapeData),
                    recoverShape(auxShapeSize, auxShapeNDim, auxShapeData));
        }
        throw new IllegalArgumentException("Cannot infer shape based on the data provided!");
    }

    /**
     * Loads an external MXNet library through {@code MXLoadLib}.
     *
     * @param path the library path forwarded to MXNet
     * @param verbose whether to request verbose output; sent to MXNet as {@code 1} or {@code 0}
     */
    public static void loadLib(String path, boolean verbose) {
        checkCall(LIB.MXLoadLib(path, verbose ? 1 : 0));
    }

    /**
     * Partitions a symbol for the given backend through {@code MXOptimizeForBackend}.
     *
     * <p>No argument arrays, auxiliary arrays or options are passed (all lengths are
     * {@code 0}). The native outputs describing new arguments and auxiliary states are
     * collected into scratch references and never read.
     *
     * @param current the symbol to optimize
     * @param backend the backend name forwarded to MXNet
     * @param device the device whose MXNet device type is forwarded
     * @return the native handle of the optimized symbol
     */
    public static Pointer optimizeFor(Symbol current, String backend, Device device) {
        // TODO: Support partition on parameters
        PointerByReference optimizedRef = REFS.acquire();
        // scratch output slots the native call requires but whose contents are ignored
        PointerByReference[] scratch = new PointerByReference[6];
        for (int i = 0; i < scratch.length; i++) {
            scratch[i] = REFS.acquire();
        }
        IntBuffer newArgCount = IntBuffer.allocate(1);
        IntBuffer newAuxCount = IntBuffer.allocate(1);
        // there is no need to update parameters
        checkCall(
                LIB.MXOptimizeForBackend(
                        current.getHandle(),
                        backend,
                        MxDeviceType.toDeviceType(device),
                        optimizedRef,
                        0,
                        scratch[0],
                        0,
                        scratch[1],
                        0,
                        Utils.EMPTY_ARRAY,
                        Utils.EMPTY_ARRAY,
                        newArgCount,
                        scratch[2],
                        scratch[3],
                        newAuxCount,
                        scratch[4],
                        scratch[5]));
        Pointer optimized = optimizedRef.getValue();
        REFS.recycle(optimizedRef);
        for (PointerByReference slot : scratch) {
            REFS.recycle(slot);
        }
        return optimized;
    }

    /* Need tests
    public static Pointer createSymbolFromJson(String json) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXSymbolCreateFromJSON(json, ref));
        return ref.getValue();
    }

    public static Pointer compose(Pointer symbol, String name, String[] keys) {
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolCompose(symbol, name, keys.length, keys, ref));
        return ref.getValue();
    }

    public static Pointer grad(Pointer symbol, String name, int numWrt, String[] wrt) {
        PointerByReference ref = new PointerByReference();

        checkCall(LIB.MXSymbolCompose(symbol, name, numWrt, wrt, ref));
        return ref.getValue();
    }

    public static Shape[] inferShape(Pointer symbol, String[] keys) {
        IntBuffer argIndex = IntBuffer.allocate(1);
        IntBuffer argShapeData = IntBuffer.allocate(1);
        IntBuffer inShapeSize = IntBuffer.allocate(1);
        PointerByReference inShapeNDim = new PointerByReference();
        PointerByReference inShapeData = new PointerByReference();
        IntBuffer outShapeSize = IntBuffer.allocate(1);
        PointerByReference outShapeNDim = new PointerByReference();
        PointerByReference outShapeData = new PointerByReference();
        IntBuffer auxShapeSize = IntBuffer.allocate(1);
        PointerByReference auxShapeNDim = new PointerByReference();
        PointerByReference auxShapeData = new PointerByReference();
        IntBuffer complete = IntBuffer.allocate(1);

        checkCall(
                LIB.MXSymbolInferShape(
                        symbol,
                        keys.length,
                        keys,
                        argIndex.array(),
                        argShapeData.array(),
                        inShapeSize,
                        inShapeNDim,
                        inShapeData,
                        outShapeSize,
                        outShapeNDim,
                        outShapeData,
                        auxShapeSize,
                        auxShapeNDim,
                        auxShapeData,
                        complete));
        if (complete.get() == 1) {
            Shape[] ret = new Shape[keys.length];
            // TODO: add implementation
            return ret; // NOPMD
        }
        return null;
    }

    public static Pointer inferType(Pointer symbol, String[] keys) {
        int[] argTypeData = new int[1];
        IntBuffer inTypeSize = IntBuffer.allocate(1);
        PointerByReference inTypeData = new PointerByReference();
        IntBuffer outTypeSize = IntBuffer.allocate(1);
        PointerByReference outTypeData = new PointerByReference();
        IntBuffer auxTypeSize = IntBuffer.allocate(1);
        PointerByReference auxTypeData = new PointerByReference();
        IntBuffer complete = IntBuffer.allocate(1);

        checkCall(
                LIB.MXSymbolInferType(
                        symbol,
                        keys.length,
                        keys,
                        argTypeData,
                        inTypeSize,
                        inTypeData,
                        outTypeSize,
                        outTypeData,
                        auxTypeSize,
                        auxTypeData,
                        complete));
        if (complete.get() == 1) {
            return outTypeData.getValue();
        }
        return null;
    }

    public static Pointer quantizeSymbol(
            Pointer symbol,
            String[] excludedSymbols,
            String[] offlineParams,
            String quantizedDType,
            byte calibQuantize) {
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXQuantizeSymbol(
                        symbol,
                        ref,
                        excludedSymbols.length,
                        excludedSymbols,
                        offlineParams.length,
                        offlineParams,
                        quantizedDType,
                        calibQuantize));
        return ref.getValue();
    }

    public static Pointer setCalibTableToQuantizedSymbol(
            Pointer symbol,
            String[] layerNames,
            FloatBuffer lowQuantiles,
            FloatBuffer highQuantiles) {
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXSetCalibTableToQuantizedSymbol(
                        symbol, layerNames.length, layerNames, lowQuantiles, highQuantiles, ref));
        return ref.getValue();
    }

    public static Pointer genBackendSubgraph(Pointer symbol, String backend) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXGenBackendSubgraph(symbol, backend, ref));
        return ref.getValue();
    }
     */

    /////////////////////////////////
    // MXNet Executors
    /////////////////////////////////

    /* Need tests
    public static void freeExecutor(Pointer executor) {
        checkCall(LIB.MXExecutorFree(executor));
    }

    public static String getExecutorDebugString(Pointer executor) {
        String[] ret = new String[1];
        checkCall(LIB.MXExecutorPrint(executor, ret));
        return ret[0];
    }

    public static void forward(Pointer executor, boolean isTrain) {
        checkCall(LIB.MXExecutorForward(executor, isTrain ? 1 : 0));
    }

    public static Pointer backward(Pointer executor, int length) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorBackward(executor, length, ref));
        return ref.getValue();
    }

    public static Pointer backwardEx(Pointer executor, int length, boolean isTrain) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorBackwardEx(executor, length, ref, isTrain ? 1 : 0));
        return ref.getValue();
    }

    public static NDArray[] getExecutorOutputs(MxNDManager manager, Pointer executor) {
        IntBuffer outSize = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorOutputs(executor, outSize, ref));
        int size = outSize.get();
        Pointer[] pointers = ref.getValue().getPointerArray(0, size);
        NDArray[] ndArrays = new NDArray[size];
        for (int i = 0; i < size; ++i) {
            ndArrays[i] = manager.create(pointers[i]);
        }
        return ndArrays;
    }

    public static Pointer bindExecutorSimple(
            Symbol symbol,
            Device device,
            String[] g2cKeys,
            int[] g2cDeviceTypes,
            int[] g2cDeviceIds,
            String[] argParams,
            String[] argParamGradReqs,
            String[] inputArgNames,
            IntBuffer inputShapeData,
            IntBuffer inputShapeIdx,
            String[] inputDataTypeNames,
            int[] inputDataTypes,
            String[] inputStorageTypeNames,
            int[] inputStorageTypes,
            String[] sharedArgParams,
            IntBuffer sharedBufferLen,
            String[] sharedBufferNames,
            PointerByReference sharedBufferHandles,
            PointerByReference updatedSharedBufferNames,
            PointerByReference updatedSharedBufferHandles,
            IntBuffer numInArgs,
            PointerByReference inArgs,
            PointerByReference argGrads,
            IntBuffer numAuxStates,
            PointerByReference auxStates) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);

        PointerByReference ref = new PointerByReference();

        checkCall(
                LIB.MXExecutorSimpleBind(
                        symbol.getHandle(),
                        deviceType,
                        deviceId,
                        g2cKeys == null ? 0 : g2cKeys.length,
                        g2cKeys,
                        g2cDeviceTypes,
                        g2cDeviceIds,
                        argParams.length,
                        argParams,
                        argParamGradReqs,
                        inputArgNames.length,
                        inputArgNames,
                        inputShapeData.array(),
                        inputShapeIdx.array(),
                        inputDataTypeNames.length,
                        inputDataTypeNames,
                        inputDataTypes,
                        inputStorageTypeNames == null ? 0 : inputStorageTypeNames.length,
                        inputStorageTypeNames,
                        inputStorageTypes,
                        sharedArgParams.length,
                        sharedArgParams,
                        sharedBufferLen,
                        sharedBufferNames,
                        sharedBufferHandles,
                        updatedSharedBufferNames,
                        updatedSharedBufferHandles,
                        numInArgs,
                        inArgs,
                        argGrads,
                        numAuxStates,
                        auxStates,
                        null,
                        ref));
        return ref.getValue();
    }

    public static Pointer bindExecutor(
            Pointer executor, Device device, int len, int auxStatesLen) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference inArgs = new PointerByReference();
        PointerByReference argGradStore = new PointerByReference();
        IntBuffer gradReqType = IntBuffer.allocate(1);
        PointerByReference auxStates = new PointerByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorBind(
                        executor,
                        deviceType,
                        deviceId,
                        len,
                        inArgs,
                        argGradStore,
                        gradReqType,
                        auxStatesLen,
                        auxStates,
                        ref));
        return ref.getValue();
    }

    public static Pointer bindExecutorX(
            Pointer executor,
            Device device,
            int len,
            int auxStatesLen,
            String[] keys,
            int[] deviceTypes,
            int[] deviceIds) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference inArgs = new PointerByReference();
        PointerByReference argGradStore = new PointerByReference();
        IntBuffer gradReqType = IntBuffer.allocate(1);
        PointerByReference auxStates = new PointerByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorBindX(
                        executor,
                        deviceType,
                        deviceId,
                        keys.length,
                        keys,
                        deviceTypes,
                        deviceIds,
                        len,
                        inArgs,
                        argGradStore,
                        gradReqType,
                        auxStatesLen,
                        auxStates,
                        ref));
        return ref.getValue();
    }

    public static Pointer bindExecutorEX(
            Pointer executor,
            Device device,
            int len,
            int auxStatesLen,
            String[] keys,
            int[] deviceTypes,
            int[] deviceIds,
            Pointer sharedExecutor) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference inArgs = new PointerByReference();
        PointerByReference argGradStore = new PointerByReference();
        IntBuffer gradReqType = IntBuffer.allocate(1);
        PointerByReference auxStates = new PointerByReference();
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorBindEX(
                        executor,
                        deviceType,
                        deviceId,
                        keys.length,
                        keys,
                        deviceTypes,
                        deviceIds,
                        len,
                        inArgs,
                        argGradStore,
                        gradReqType,
                        auxStatesLen,
                        auxStates,
                        sharedExecutor,
                        ref));
        return ref.getValue();
    }

    public static Pointer reshapeExecutor(
            boolean partialShaping,
            boolean allowUpSizing,
            Device device,
            String[] keys,
            int[] deviceTypes,
            int[] deviceIds,
            String[] providedArgShapeNames,
            IntBuffer providedArgShapeData,
            IntBuffer providedArgShapeIdx,
            IntBuffer numInArgs,
            PointerByReference inArgs,
            PointerByReference argGrads,
            IntBuffer numAuxStates,
            PointerByReference auxStates,
            Pointer sharedExecutor) {
        int deviceId = device.getDeviceId();
        int deviceType = DeviceType.toDeviceType(device);
        PointerByReference ref = new PointerByReference();
        checkCall(
                LIB.MXExecutorReshape(
                        partialShaping ? 1 : 0,
                        allowUpSizing ? 1 : 0,
                        deviceType,
                        deviceId,
                        keys.length,
                        keys,
                        deviceTypes,
                        deviceIds,
                        providedArgShapeNames.length,
                        providedArgShapeNames,
                        providedArgShapeData.array(),
                        providedArgShapeIdx.array(),
                        numInArgs,
                        inArgs,
                        argGrads,
                        numAuxStates,
                        auxStates,
                        sharedExecutor,
                        ref));
        return ref.getValue();
    }

    public static Pointer getOptimizedSymbol(Pointer executor) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXExecutorGetOptimizedSymbol(executor, ref));
        return ref.getValue();
    }

    public static void setMonitorCallback(
            Pointer executor,
            MxnetLibrary.ExecutorMonitorCallback callback,
            Pointer callbackHandle) {
        checkCall(LIB.MXExecutorSetMonitorCallback(executor, callback, callbackHandle));
    }
     */

    /////////////////////////////////
    // MXNet Data Iterators
    /////////////////////////////////

    /*
    public static Pointer[] listDataIters() {
        IntBuffer outSize = IntBuffer.allocate(1);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXListDataIters(outSize, ref));
        return ref.getValue().getPointerArray(0, outSize.get());
    }

    public static Pointer createIter(Pointer iter, String[] keys, String[] values) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterCreateIter(iter, keys.length, keys, values, ref));
        return ref.getValue();
    }

    public static String getIterInfo(Pointer iter) {
        String[] name = new String[1];
        String[] description = new String[1];
        IntBuffer numArgs = IntBuffer.allocate(1);
        PointerByReference argNames = new PointerByReference();
        PointerByReference argTypes = new PointerByReference();
        PointerByReference argDesc = new PointerByReference();
        checkCall(
                LIB.MXDataIterGetIterInfo(
                        iter, name, description, numArgs, argNames, argTypes, argDesc));
        return name[0];
    }

    public static void freeDataIter(Pointer iter) {
        checkCall(LIB.MXDataIterFree(iter));
    }

    public static int next(Pointer iter) {
        IntBuffer ret = IntBuffer.allocate(1);
        checkCall(LIB.MXDataIterNext(iter, ret));
        return ret.get();
    }

    public static void beforeFirst(Pointer iter) {
        checkCall(LIB.MXDataIterBeforeFirst(iter));
    }

    public static Pointer getData(Pointer iter) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterGetData(iter, ref));
        return ref.getValue();
    }

    public static Pointer getIndex(Pointer iter) {
        LongBuffer outSize = LongBuffer.wrap(new long[1]);
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterGetIndex(iter, ref, outSize));
        return ref.getValue();
    }

    public static int getPadNum(Pointer iter) {
        IntBuffer outSize = IntBuffer.allocate(1);
        checkCall(LIB.MXDataIterGetPadNum(iter, outSize));
        return outSize.get();
    }

    public static String getDataIterLabel(Pointer iter) {
        PointerByReference ref = new PointerByReference();
        checkCall(LIB.MXDataIterGetLabel(iter, ref));
        return ref.getValue().getString(0, StandardCharsets.UTF_8.name());
    }
     */

    /*



    int MXRecordIOWriterCreate(String uri, PointerByReference out);


    int MXRecordIOWriterFree(Pointer handle);


    int MXRecordIOWriterWriteRecord(Pointer handle, String buf, NativeSize size);


    int MXRecordIOWriterTell(Pointer handle, NativeSizeByReference pos);


    int MXRecordIOReaderCreate(String uri, PointerByReference out);


    int MXRecordIOReaderFree(Pointer handle);


    int MXRecordIOReaderReadRecord(Pointer handle, String buf[], NativeSizeByReference size);


    int MXRecordIOReaderSeek(Pointer handle, NativeSize pos);


    int MXRecordIOReaderTell(Pointer handle, NativeSizeByReference pos);


    int MXRtcCreate(ByteBuffer name, int num_input, int num_output, PointerByReference input_names,
                    PointerByReference output_names, PointerByReference inputs,
                    PointerByReference outputs, ByteBuffer kernel, PointerByReference out);


    int MXRtcPush(Pointer handle, int num_input, int num_output, PointerByReference inputs,
                  PointerByReference outputs, int gridDimX, int gridDimY, int gridDimZ,
                  int blockDimX, int blockDimY, int blockDimZ);


    int MXRtcFree(Pointer handle);


    int MXCustomOpRegister(String op_type, MxnetLibrary.CustomOpPropCreator creator);


    int MXCustomFunctionRecord(int num_inputs, PointerByReference inputs, int num_outputs,
                               PointerByReference outputs, MXCallbackList callbacks);


    int MXRtcCudaModuleCreate(String source, int num_options, String options[], int num_exports,
                              String exports[], PointerByReference out);


    int MXRtcCudaModuleFree(Pointer handle);


    int MXRtcCudaKernelCreate(Pointer handle, String name, int num_args, IntBuffer is_ndarray,
                              IntBuffer is_const, IntBuffer arg_types, PointerByReference out);


    int MXRtcCudaKernelFree(Pointer handle);


    int MXRtcCudaKernelCall(Pointer handle, int dev_id, PointerByReference args, int grid_dim_x,
                            int grid_dim_y, int grid_dim_z, int block_dim_x, int block_dim_y,
                            int block_dim_z, int shared_mem);


    int MXNDArrayGetSharedMemHandle(Pointer handle, IntBuffer shared_pid, IntBuffer shared_id);


    int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, IntBuffer shape, int ndim,
                                     int dtype, PointerByReference out);
    */

    //////////////////////////////////
    // cached Op
    //////////////////////////////////

    /**
     * Creates an MXNet {@link CachedOp} for the symbol held by the given block.
     *
     * <p>Every parameter of the block becomes an input of the cached op. Initialized parameters
     * are recorded as model parameters and uninitialized ones are assumed to be data inputs. Their
     * positions are passed to MXNet through the {@code data_indices} and {@code param_indices}
     * flags, e.g. {@code data_indices : [0, 2, 4]} labels the data input locations and {@code
     * param_indices : [1, 3]} labels the parameter locations.
     *
     * <p>{@code static_alloc} and {@code static_shape} are enabled by default; each is disabled
     * when the system property {@code ai.djl.mxnet.static_alloc} / {@code
     * ai.djl.mxnet.static_shape} is set to anything other than {@code "true"} (case-insensitive).
     *
     * @param block the {@link MxSymbolBlock} that loaded in the backend
     * @param manager the NDManager used to create NDArray
     * @param training true if CachedOp is created to forward in training otherwise, false; note
     *     that this flag is currently not used when building the cached op
     * @return a CachedOp wrapping the newly created native handle
     * @throws EngineException if the native call fails
     */
    public static CachedOp createCachedOp(
            MxSymbolBlock block, MxNDManager manager, boolean training) {
        Symbol symbol = block.getSymbol();

        List<Parameter> parameters = block.getAllParameters();

        // record data index in all inputs
        PairList<String, Integer> dataIndices = new PairList<>();
        // record parameter index in all inputs
        List<Integer> paramIndices = new ArrayList<>();
        int index = 0;
        for (Parameter parameter : parameters) {
            // We assume uninitialized parameters are data inputs
            if (parameter.isInitialized()) {
                paramIndices.add(index);
            } else {
                dataIndices.add(parameter.getName(), index);
            }
            ++index;
        }

        // static_alloc and static_shape are enabled by default
        String staticAlloc =
                Boolean.parseBoolean(System.getProperty("ai.djl.mxnet.static_alloc", "true"))
                        ? "1"
                        : "0";
        String staticShape =
                Boolean.parseBoolean(System.getProperty("ai.djl.mxnet.static_shape", "true"))
                        ? "1"
                        : "0";
        String[] keys = {"data_indices", "param_indices", "static_alloc", "static_shape"};
        String[] values = {
            dataIndices.values().toString(), paramIndices.toString(), staticAlloc, staticShape
        };

        // Creating CachedOp
        Pointer symbolHandle = symbol.getHandle();
        PointerByReference ref = REFS.acquire();
        Pointer pointer;
        try {
            checkCall(LIB.MXCreateCachedOpEx(symbolHandle, keys.length, keys, values, ref));
            pointer = ref.getValue();
        } finally {
            // return the pooled reference even when the native call fails
            REFS.recycle(ref);
        }

        return new CachedOp(pointer, manager, parameters, paramIndices, dataIndices);
    }

    /**
     * Frees a native cached op handle.
     *
     * @param handle the cached op handle to free
     * @throws EngineException if the native call fails
     */
    public static void freeCachedOp(Pointer handle) {
        checkCall(LIB.MXFreeCachedOp(handle));
    }

    /**
     * Invokes a cached op on the given inputs.
     *
     * <p>Each output handle returned by MXNet is wrapped through {@code manager}. When MXNet
     * reports a non-zero storage type for an output, it is created with the {@link SparseFormat}
     * obtained from {@link SparseFormat#fromValue}; otherwise the default creation is used.
     *
     * @param manager the manager used to create the output arrays
     * @param cachedOpHandle the native cached op handle
     * @param inputs the input arrays passed to the cached op
     * @return the output arrays produced by the cached op
     * @throws EngineException if the native call fails
     */
    public static MxNDArray[] cachedOpInvoke(
            MxNDManager manager, Pointer cachedOpHandle, MxNDArray[] inputs) {
        IntBuffer buf = IntBuffer.allocate(1);
        PointerArray array = toPointerArray(inputs);
        PointerByReference ref = REFS.acquire();
        PointerByReference outSTypeRef = REFS.acquire();
        try {
            checkCall(
                    LIB.MXInvokeCachedOpEx(
                            cachedOpHandle, inputs.length, array, buf, ref, outSTypeRef));
            int numOutputs = buf.get();
            // copy the native results out before the pooled references are recycled
            Pointer[] ptrArray = ref.getValue().getPointerArray(0, numOutputs);
            int[] sTypes = outSTypeRef.getValue().getIntArray(0, numOutputs);
            MxNDArray[] output = new MxNDArray[numOutputs];
            for (int i = 0; i < numOutputs; i++) {
                if (sTypes[i] != 0) {
                    output[i] = manager.create(ptrArray[i], SparseFormat.fromValue(sTypes[i]));
                } else {
                    output[i] = manager.create(ptrArray[i]);
                }
            }
            return output;
        } finally {
            // return pooled objects even when the native call or array creation fails
            REFS.recycle(ref);
            REFS.recycle(outSTypeRef);
            if (array != null) {
                array.recycle();
            }
        }
    }

    /**
     * Checks the return code of a native MXNet call.
     *
     * @param ret the value returned by the native call
     * @throws EngineException if {@code ret} is non-zero; the message includes MXNet's last error
     */
    public static void checkCall(int ret) {
        if (ret != 0) {
            throw new EngineException("MXNet engine call failed: " + getLastError());
        }
    }

    /**
     * Collects the native handles of the arrays in an {@link NDList}.
     *
     * @param vals the arrays; every element is cast to {@link MxNDArray}
     * @return a {@link PointerArray} holding the handles in list order
     */
    private static PointerArray toPointerArray(NDList vals) {
        Pointer[] valPointers = new Pointer[vals.size()];
        for (int i = 0; i < vals.size(); i++) {
            valPointers[i] = ((MxNDArray) vals.get(i)).getHandle();
        }
        return PointerArray.of(valPointers);
    }

    /**
     * Collects the native handles of the given arrays.
     *
     * @param vals the arrays, may be {@code null}; every element is cast to {@link MxNDArray}
     * @return a {@link PointerArray} holding the handles in array order, or {@code null} when
     *     {@code vals} is {@code null}
     */
    private static PointerArray toPointerArray(NDArray[] vals) {
        if (vals == null) {
            return null;
        }
        Pointer[] valPointers = new Pointer[vals.length];
        for (int i = 0; i < vals.length; i++) {
            valPointers[i] = ((MxNDArray) vals[i]).getHandle();
        }
        return PointerArray.of(valPointers);
    }

    /**
     * Verifies that an NDArray handle is still valid.
     *
     * @param pointer the native handle
     * @param msg the attempted action, used in the error message
     * @throws IllegalArgumentException if {@code pointer} is {@code null}
     */
    private static void checkNDArray(Pointer pointer, String msg) {
        if (pointer == null) {
            throw new IllegalArgumentException(
                    "Tried to " + msg + " an MXNet NDArray that was already closed");
        }
    }

    /**
     * Returns the last error message reported by the MXNet library.
     *
     * @return the error message from {@code MXGetLastError}
     */
    private static String getLastError() {
        return LIB.MXGetLastError();
    }

    /**
     * Reads {@code size} UTF-8 C strings from a native array of string pointers.
     *
     * @param ref the reference holding the native string-pointer array
     * @param size the number of strings to read
     * @return the decoded strings, or {@link Utils#EMPTY_ARRAY} when {@code size} is 0
     */
    private static String[] toStringArray(PointerByReference ref, int size) {
        if (size == 0) {
            return Utils.EMPTY_ARRAY;
        }

        Pointer[] pointers = ref.getValue().getPointerArray(0, size);

        String[] arr = new String[size];
        for (int i = 0; i < size; ++i) {
            arr[i] = pointers[i].getString(0, StandardCharsets.UTF_8.name());
        }

        return arr;
    }

    /*
    private static PairList<String, String> toPairList(PointerByReference ref, int size) {
        Pointer[] pointers = ref.getValue().getPointerArray(0, size);

        List<String> names = new ArrayList<>(size);
        List<String> values = new ArrayList<>(size);
        for (Pointer pointer : pointers) {
            String[] pair = pointer.getStringArray(0, 2, StandardCharsets.UTF_8.name());
            names.add(pair[0]);
            values.add(pair[1]);
        }

        return new PairList<>(names, values);
    }
    */

    /**
     * Strips the first matching entry of {@code OP_NAME_PREFIX} from an operator name.
     *
     * <p>Despite its name, this method returns the remainder of {@code name} after the prefix, not
     * the prefix itself.
     *
     * @param name the operator name
     * @return {@code name} without the first matching prefix, or {@code name} unchanged when no
     *     prefix matches
     */
    private static String getOpNamePrefix(String name) {
        for (String prefix : OP_NAME_PREFIX) {
            if (name.startsWith(prefix)) {
                return name.substring(prefix.length());
            }
        }
        return name;
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy