All downloads are FREE. Search and download functionality uses the official Maven repository.

org.nd4j.onnxruntime.runner.OnnxRuntimeRunner Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.onnxruntime.runner;

import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.bytedeco.onnxruntime.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.onnxruntime.util.ONNXUtils;

import java.io.Closeable;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static org.bytedeco.onnxruntime.global.onnxruntime.*;
import static org.nd4j.onnxruntime.util.ONNXUtils.getDataBuffer;
import static org.nd4j.onnxruntime.util.ONNXUtils.getTensor;

@Slf4j
public class OnnxRuntimeRunner implements Closeable  {
    private Session session;
    private RunOptions runOptions;
    private MemoryInfo memoryInfo;
    private AllocatorWithDefaultOptions allocator;
    private SessionOptions sessionOptions;
    private   static Env env;
    private Pointer bp;


    @Builder
    public OnnxRuntimeRunner(String modelUri) {
        if(env == null) {
            env = new Env(ONNXUtils.getOnnxLogLevelFromLogger(log), new BytePointer("nd4j-serving-onnx-session-" + UUID.randomUUID().toString()));
            env.retainReference();
        }

        sessionOptions = new SessionOptions();
        sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
        sessionOptions.SetIntraOpNumThreads(1);
        sessionOptions.retainReference();
        allocator = new AllocatorWithDefaultOptions();
        allocator.retainReference();
        bp = Loader.getPlatform().toLowerCase().startsWith("windows") ? new CharPointer(modelUri) : new BytePointer(modelUri);
        runOptions = new RunOptions();
        memoryInfo = MemoryInfo.CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        session = new Session(env, bp, sessionOptions);
        //retain the session reference to prevent pre emptive release of the session.
        session.retainReference();

    }



    @Override
    public void close() {
        if(session != null) {
            session.close();
        }

        sessionOptions.releaseReference();
        allocator.releaseReference();
        runOptions.releaseReference();
    }


    /**
     * Execute the {@link #session}
     * using the given input {@link Map}
     * input
     * @param input the input map
     * @return a map of the names of the ndarrays
     */
    public Map exec(Map input) {
        long numInputNodes = session.GetInputCount();
        long numOutputNodes = session.GetOutputCount();

        PointerPointer inputNodeNames = new PointerPointer<>(numInputNodes);
        PointerPointer outputNodeNames = new PointerPointer<>(numOutputNodes);

        Value inputVal = new Value(numInputNodes);

        for (int i = 0; i < numInputNodes; i++) {
            BytePointer inputName = session.GetInputName(i, allocator.asOrtAllocator());
            inputNodeNames.put(i, inputName);
            INDArray arr = input.get(inputName.getString());
            Value inputTensor = getTensor(arr, memoryInfo);
            Preconditions.checkState(inputTensor.IsTensor(),"Input must be a tensor.");
            inputVal.position(i).put(inputTensor);
        }

        //reset position after iterating
        inputVal.position(0);



        for (int i = 0; i < numOutputNodes; i++) {
            BytePointer outputName = session.GetOutputName(i, allocator.asOrtAllocator());
            outputNodeNames.put(i, outputName);
        }

        ValueVector outputVector = session.Run(
                runOptions,
                inputNodeNames,
                inputVal,
                numInputNodes,
                outputNodeNames,
                numOutputNodes);

        outputVector.retainReference();
        Map ret = new LinkedHashMap<>();

        for (int i = 0; i < numOutputNodes; i++) {
            Value outValue = outputVector.get(i);
            outValue.retainReference();
            TypeInfo typeInfo = session.GetOutputTypeInfo(i);
            DataBuffer buffer = getDataBuffer(outValue);
            LongPointer longPointer = outValue.GetTensorTypeAndShapeInfo().GetShape();
            //shape info can be null
            if(longPointer != null) {
                long[] shape = new long[(int) longPointer.capacity()];
                longPointer.get(shape);
                ret.put((outputNodeNames.get(BytePointer.class, i)).getString(), Nd4j.create(buffer).reshape(shape));
            } else {
                ret.put((outputNodeNames.get(BytePointer.class, i)).getString(), Nd4j.create(buffer));

            }
        }

        return ret;


    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy