All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.mxnet.jna.FunctionInfo Maven / Gradle / Ivy

/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.mxnet.jna;

import ai.djl.Device;
import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.Trainer;
import ai.djl.util.PairList;

import com.sun.jna.Pointer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

/** A FunctionInfo represents an operator (ie function) within the MXNet Engine. */
public class FunctionInfo {

    private final Pointer handle;
    private final String name;
    private final PairList<String, String> arguments;

    // NOTE(review): the logger is keyed on Trainer.class rather than FunctionInfo.class. Because
    // checkDevices() only does work when this logger has debug enabled, this may be deliberate
    // (device checks become opt-in via Trainer debug logging) -- confirm before renaming it.
    private static final Logger logger = LoggerFactory.getLogger(Trainer.class);

    /**
     * Creates a {@code FunctionInfo} describing a single operator.
     *
     * @param pointer the native handle of the operator, passed to {@code imperativeInvoke}
     * @param functionName the name of the operator
     * @param arguments the operator's argument names (keys) paired with their types (values)
     */
    FunctionInfo(Pointer pointer, String functionName, PairList<String, String> arguments) {
        this.handle = pointer;
        this.name = functionName;
        this.arguments = arguments;
    }

    /**
     * Calls an operator with the given arguments, overwriting {@code dest} with the result.
     *
     * @param manager the manager to attach the result to; not used by this overload
     * @param src the input NDArray(s) to the operator
     * @param dest the destination NDArray(s) to be overwritten with the result of the operator
     * @param params the non-NDArray arguments to the operator
     * @return the number of entries in the output list reported by the native invocation
     *     (presumably the number of operator outputs)
     */
    public int invoke(
            NDManager manager, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
        checkDevices(src);
        checkDevices(dest);
        return JnaUtils.imperativeInvoke(handle, src, dest, params).size();
    }

    /**
     * Calls an operator with the given arguments and wraps its outputs as new NDArrays.
     *
     * @param manager the manager to attach the result to; must be an {@link MxNDManager}
     * @param src the input NDArray(s) to the operator
     * @param params the non-NDArray arguments to the operator
     * @return the operator's output NDArrays, created through {@code manager}
     * @throws ClassCastException if {@code manager} is not an {@link MxNDManager}
     */
    public NDArray[] invoke(NDManager manager, NDArray[] src, PairList<String, ?> params) {
        checkDevices(src);
        // No destination arrays are supplied, so the native call allocates the outputs and
        // returns their handles together with each output's storage format.
        PairList<Pointer, SparseFormat> outputs =
                JnaUtils.imperativeInvoke(handle, src, null, params);
        final MxNDManager mxManager = (MxNDManager) manager;
        return outputs.stream()
                .map(
                        pair -> {
                            // Non-dense outputs need their sparse format passed explicitly.
                            if (pair.getValue() != SparseFormat.DENSE) {
                                return mxManager.create(pair.getKey(), pair.getValue());
                            }
                            return mxManager.create(pair.getKey());
                        })
                .toArray(MxNDArray[]::new);
    }

    /**
     * Returns the name of the operator.
     *
     * @return the name of the operator
     */
    public String getFunctionName() {
        return name;
    }

    /**
     * Returns the names of the params to the operator.
     *
     * @return the names of the params to the operator (the keys of the argument list)
     */
    public List<String> getArgumentNames() {
        return arguments.keys();
    }

    /**
     * Returns the types of the operator arguments.
     *
     * @return the types of the operator arguments (the values of the argument list)
     */
    public List<String> getArgumentTypes() {
        return arguments.values();
    }

    /**
     * Logs a warning when the given NDArrays are not all on the same {@link Device}.
     *
     * <p>The device comparison is skipped unless debug logging is enabled on {@link #logger}.
     * At most one warning is logged per call.
     *
     * @param arrays the NDArrays to check; {@code null} or fewer than two elements is a no-op
     */
    private void checkDevices(NDArray[] arrays) {
        if (arrays == null || arrays.length < 2 || !logger.isDebugEnabled()) {
            return;
        }
        Device device = arrays[0].getDevice();
        for (int i = 1; i < arrays.length; ++i) {
            if (!device.equals(arrays[i].getDevice())) {
                logger.warn(
                        "Please make sure all the NDArrays are in the same device. You can call"
                                + " toDevice() to move the NDArray to the desired Device.");
                // One warning is enough to flag the problem; repeating it per element is noise.
                return;
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy