All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.onnxruntime.OnnxTensor Maven / Gradle / Ivy

/*
 * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
 * Licensed under the MIT License.
 */
package ai.onnxruntime;

import ai.onnxruntime.platform.Fp16Conversions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
 * returned as outputs.
 */
public class OnnxTensor extends OnnxTensorLike {
  private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName());

  /**
   * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does
   * not go out of scope while the OnnxTensor exists.
   */
  private final Buffer buffer;

  /**
   * Denotes if the OnnxTensor made a copy of the buffer on construction (i.e. it may have the only
   * reference).
   */
  private final boolean ownsBuffer;

  OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info) {
    this(nativeHandle, allocatorHandle, info, null, false);
  }

  OnnxTensor(
      long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer, boolean ownsBuffer) {
    super(nativeHandle, allocatorHandle, info);
    this.buffer = buffer;
    this.ownsBuffer = ownsBuffer;
  }

  /**
   * Returns true if the buffer in this OnnxTensor was created on construction of this tensor, i.e.,
   * it is a copy of a user supplied buffer or array and may hold the only reference to that buffer.
   *
   * 

When this is true the backing buffer was copied from the user input, so users cannot mutate * the state of this buffer without first getting the reference via {@link #getBufferRef()}. * * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is * a copy of a user buffer.) */ public boolean ownsBuffer() { return this.ownsBuffer; } /** * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by * ORT) this method returns an empty {@link Optional}. * *

Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be * used to repeatedly update a single tensor for multiple different inferences without allocating * new tensors, though the inputs must remain the same size and shape. * *

Note: the tensor could refer to a contiguous range of elements in this buffer, not the whole * buffer. It is up to the user to manage this information by respecting the position and limit. * As a consequence, accessing this reference should be considered problematic when multiple * threads hold references to the buffer. * * @return A reference to the buffer. */ public Optional getBufferRef() { return Optional.ofNullable(buffer); } @Override public OnnxValueType getType() { return OnnxValueType.ONNX_TYPE_TENSOR; } /** * Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of * primitives if it has multiple dimensions. * *

Java multidimensional arrays are quite slow for more than 2 dimensions, in that case it is * recommended you use the {@link java.nio.Buffer} extractors below (e.g., {@link * #getFloatBuffer}). * * @return A Java value. * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the * native code encountered an error. */ @Override public Object getValue() throws OrtException { checkClosed(); if (info.isScalar()) { switch (info.type) { case FLOAT: return getFloat(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case DOUBLE: return getDouble(OnnxRuntime.ortApiHandle, nativeHandle); case UINT8: case INT8: return getByte(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case INT16: return getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case INT32: return getInt(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case INT64: return getLong(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value); case BOOL: return getBool(OnnxRuntime.ortApiHandle, nativeHandle); case STRING: return getString(OnnxRuntime.ortApiHandle, nativeHandle); case FLOAT16: return Fp16Conversions.fp16ToFloat( getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value)); case BFLOAT16: return Fp16Conversions.bf16ToFloat( getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value)); case UNKNOWN: default: throw new OrtException("Extracting the value of an invalid Tensor."); } } else { Object carrier = info.makeCarrier(); if (info.getNumElements() > 0) { // If the tensor has values copy them out getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); } if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { // We read the strings out from native code in a flat array and then reshape // to the desired output shape. return OrtUtil.reshape((String[]) carrier, info.shape); } else { return carrier; } } } @Override public String toString() { return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")"; } /** * Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). If it * is backed by a buffer then the memory is released when the buffer is GC'd. */ @Override public synchronized void close() { if (!closed) { close(OnnxRuntime.ortApiHandle, nativeHandle); closed = true; } else { logger.warning("Closing an already closed tensor."); } } /** * Returns a copy of the underlying OnnxTensor as a ByteBuffer. * *

This method returns null if the OnnxTensor contains Strings as they are stored externally to * the OnnxTensor. * * @return A ByteBuffer copy of the OnnxTensor. */ public ByteBuffer getByteBuffer() { checkClosed(); if (info.type != OnnxJavaType.STRING) { ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); output.put(buffer); output.rewind(); return output; } else { return null; } } /** * Returns a copy of the underlying OnnxTensor as a FloatBuffer if it can be losslessly converted * into a float (i.e. it's a float, fp16 or bf16), otherwise it returns null. * * @return A FloatBuffer copy of the OnnxTensor. */ public FloatBuffer getFloatBuffer() { checkClosed(); if (info.type == OnnxJavaType.FLOAT) { // if it's fp32 use the efficient copy. FloatBuffer buffer = getBuffer().asFloatBuffer(); FloatBuffer output = FloatBuffer.allocate(buffer.capacity()); output.put(buffer); output.rewind(); return output; } else if (info.type == OnnxJavaType.FLOAT16) { // if it's fp16 we need to copy it out by hand. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer); } else if (info.type == OnnxJavaType.BFLOAT16) { // if it's bf16 we need to copy it out by hand. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer); } else { return null; } } /** * Returns a copy of the underlying OnnxTensor as a DoubleBuffer if the underlying type is a * double, otherwise it returns null. * * @return A DoubleBuffer copy of the OnnxTensor. */ public DoubleBuffer getDoubleBuffer() { checkClosed(); if (info.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = getBuffer().asDoubleBuffer(); DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity()); output.put(buffer); output.rewind(); return output; } else { return null; } } /** * Returns a copy of the underlying OnnxTensor as a ShortBuffer if the underlying type is int16, * uint16, fp16 or bf16, otherwise it returns null. * * @return A ShortBuffer copy of the OnnxTensor. */ public ShortBuffer getShortBuffer() { checkClosed(); if ((info.type == OnnxJavaType.INT16) || (info.type == OnnxJavaType.FLOAT16) || (info.type == OnnxJavaType.BFLOAT16)) { ShortBuffer buffer = getBuffer().asShortBuffer(); ShortBuffer output = ShortBuffer.allocate(buffer.capacity()); output.put(buffer); output.rewind(); return output; } else { return null; } } /** * Returns a copy of the underlying OnnxTensor as an IntBuffer if the underlying type is int32 or * uint32, otherwise it returns null. * * @return An IntBuffer copy of the OnnxTensor. */ public IntBuffer getIntBuffer() { checkClosed(); if (info.type == OnnxJavaType.INT32) { IntBuffer buffer = getBuffer().asIntBuffer(); IntBuffer output = IntBuffer.allocate(buffer.capacity()); output.put(buffer); output.rewind(); return output; } else { return null; } } /** * Returns a copy of the underlying OnnxTensor as a LongBuffer if the underlying type is int64 or * uint64, otherwise it returns null. * * @return A LongBuffer copy of the OnnxTensor. */ public LongBuffer getLongBuffer() { checkClosed(); if (info.type == OnnxJavaType.INT64) { LongBuffer buffer = getBuffer().asLongBuffer(); LongBuffer output = LongBuffer.allocate(buffer.capacity()); output.put(buffer); output.rewind(); return output; } else { return null; } } /** * Wraps the OrtTensor pointer in a direct byte buffer of the native platform endian-ness. Unless * you really know what you're doing, you want this one rather than the native call {@link * OnnxTensor#getBuffer(long,long)}. * * @return A ByteBuffer wrapping the data. */ private ByteBuffer getBuffer() { return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); } /** * Wraps the OrtTensor pointer in a direct byte buffer. * * @param apiHandle The OrtApi pointer. * @param nativeHandle The OrtTensor pointer. * @return A ByteBuffer wrapping the data. */ private native ByteBuffer getBuffer(long apiHandle, long nativeHandle); private native float getFloat(long apiHandle, long nativeHandle, int onnxType) throws OrtException; private native double getDouble(long apiHandle, long nativeHandle) throws OrtException; private native byte getByte(long apiHandle, long nativeHandle, int onnxType) throws OrtException; private native short getShort(long apiHandle, long nativeHandle, int onnxType) throws OrtException; private native int getInt(long apiHandle, long nativeHandle, int onnxType) throws OrtException; private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException; private native String getString(long apiHandle, long nativeHandle) throws OrtException; private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; private native void getArray(long apiHandle, long nativeHandle, Object carrier) throws OrtException; private native void close(long apiHandle, long nativeHandle); /** * Create a Tensor from a Java primitive, primitive multidimensional array or String * multidimensional array. The shape is inferred from the object using reflection. The default * allocator is used. * *

Note: Java multidimensional arrays are not dense and this method requires traversing a large * number of pointers for high dimensional arrays. For types other than Strings it is recommended * to use one of the {@code createTensor} methods which accepts a {@link java.nio.Buffer}, e.g. * {@link #createTensor(OrtEnvironment, FloatBuffer, long[])} as those methods are zero copy to * transfer data into ORT when using direct buffers. * * @param env The current OrtEnvironment. * @param data The data to store in a tensor. * @return An OnnxTensor storing the data. * @throws OrtException If the onnx runtime threw an error. */ public static OnnxTensor createTensor(OrtEnvironment env, Object data) throws OrtException { return createTensor(env, env.defaultAllocator, data); } /** * Create a Tensor from a Java primitive, String, primitive multidimensional array or String * multidimensional array. The shape is inferred from the object using reflection. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The data to store in a tensor. * @return An OnnxTensor storing the data. * @throws OrtException If the onnx runtime threw an error. */ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data) throws OrtException { if (!allocator.isClosed()) { TensorInfo info = TensorInfo.constructFromJavaArray(data); if (info.type == OnnxJavaType.STRING) { if (info.shape.length == 0) { return new OnnxTensor( createString(OnnxRuntime.ortApiHandle, allocator.handle, (String) data), allocator.handle, info); } else { return new OnnxTensor( createStringTensor( OnnxRuntime.ortApiHandle, allocator.handle, OrtUtil.flattenString(data), info.shape), allocator.handle, info); } } else { if (info.shape.length == 0) { data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data); if (data == null) { throw new OrtException( "Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = " + info.type + ", object = " + data); } } return new OnnxTensor( createTensor( OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value), allocator.handle, info); } } else { throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator."); } } /** * Create a tensor from a flattened string array. * *

Requires the array to be flattened in row-major order. Uses the default allocator. * * @param env The current OrtEnvironment. * @param data The tensor data * @param shape the shape of the tensor * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, String[] data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create a tensor from a flattened string array. * *

Requires the array to be flattened in row-major order. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data * @param shape the shape of the tensor * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException { if (!allocator.isClosed()) { TensorInfo info = new TensorInfo( shape, OnnxJavaType.STRING, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); return new OnnxTensor( createStringTensor(OnnxRuntime.ortApiHandle, allocator.handle, data, shape), allocator.handle, info); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Create an OnnxTensor backed by a direct FloatBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, FloatBuffer data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create an OnnxTensor backed by a direct FloatBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape) throws OrtException { if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.FLOAT; return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Create an OnnxTensor backed by a direct DoubleBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, DoubleBuffer data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create an OnnxTensor backed by a direct DoubleBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape) throws OrtException { if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.DOUBLE; return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Create an OnnxTensor backed by a direct ByteBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. Tells the runtime it's {@link OnnxJavaType#INT8}. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, ByteBuffer data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create an OnnxTensor backed by a direct ByteBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Tells the runtime it's {@link OnnxJavaType#INT8}. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape) throws OrtException { return createTensor(env, allocator, data, shape, OnnxJavaType.INT8); } /** * Create an OnnxTensor backed by a direct ByteBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. Tells the runtime it's the specified type. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @param type The type to use for the byte buffer elements. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor( OrtEnvironment env, ByteBuffer data, long[] shape, OnnxJavaType type) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape, type); } /** * Create an OnnxTensor backed by a direct ByteBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Tells the runtime it's the specified type. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @param type The type to use for the byte buffer elements. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type) throws OrtException { if (!allocator.isClosed()) { return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape) throws OrtException { if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.INT16; return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Create an OnnxTensor backed by a direct IntBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, IntBuffer data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create an OnnxTensor backed by a direct IntBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape) throws OrtException { if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.INT32; return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Create an OnnxTensor backed by a direct LongBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the default allocator. * * @param env The current OrtEnvironment. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ public static OnnxTensor createTensor(OrtEnvironment env, LongBuffer data, long[] shape) throws OrtException { return createTensor(env, env.defaultAllocator, data, shape); } /** * Create an OnnxTensor backed by a direct LongBuffer. The buffer should be in nativeOrder. * *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime * of the tensor. Uses the supplied allocator. * * @param env The current OrtEnvironment. * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape) throws OrtException { if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.INT64; return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } } /** * Creates a tensor by wrapping the data in a direct byte buffer and passing it to JNI. * *

Throws IllegalStateException if the buffer is too large to create a direct byte buffer copy, * which is more than approximately (Integer.MAX_VALUE - 5) / type.size elements. * * @param type The buffer type. * @param allocator The OrtAllocator. * @param data The data. * @param shape The tensor shape. * @return An OnnxTensor instance. * @throws OrtException If the create call failed. */ private static OnnxTensor createTensor( OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException { OrtUtil.BufferTuple tuple = OrtUtil.prepareBuffer(data, type); TensorInfo info = TensorInfo.constructFromBuffer(tuple.data, shape, type); return new OnnxTensor( createTensorFromBuffer( OnnxRuntime.ortApiHandle, allocator.handle, tuple.data, tuple.pos, tuple.byteSize, shape, info.onnxType.value), allocator.handle, info, tuple.data, tuple.isCopy); } private static native long createTensor( long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType) throws OrtException; private static native long createTensorFromBuffer( long apiHandle, long allocatorHandle, Buffer data, int bufferPos, long bufferSize, long[] shape, int onnxType) throws OrtException; private static native long createString(long apiHandle, long allocatorHandle, String data) throws OrtException; private static native long createStringTensor( long apiHandle, long allocatorHandle, Object[] data, long[] shape) throws OrtException; }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy