All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.onnxruntime.TensorInfo Maven / Gradle / Ivy

/*
 * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
 * Licensed under the MIT License.
 */
package ai.onnxruntime;

import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;

/** Describes an {@link OnnxTensor}, including its size, shape and element type. */
public class TensorInfo implements ValueInfo {

  /** Maximum number of dimensions supported by the Java interface methods. */
  public static final int MAX_DIMENSIONS = 8;

  /** The native element types supported by the ONNX runtime. */
  public enum OnnxTensorType {
    ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED(0),
    ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8(1), // maps to c type uint8_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8(2), // maps to c type int8_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16(3), // maps to c type uint16_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16(4), // maps to c type int16_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32(5), // maps to c type uint32_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32(6), // maps to c type int32_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64(7), // maps to c type uint64_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64(8), // maps to c type int64_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16(9), // stored as a uint16_t
    ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT(10), // maps to c type float
    ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE(11), // maps to c type double
    ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING(12), // maps to c++ type std::string
    ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL(13),
    ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64(
        14), // complex with float32 real and imaginary components
    ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128(
        15), // complex with float64 real and imaginary components
    ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16(
        16); // Non-IEEE floating-point format based on IEEE754 single-precision

    /** The int id on the native side. */
    public final int value;

    private static final OnnxTensorType[] values = new OnnxTensorType[17];

    static {
      for (OnnxTensorType ot : OnnxTensorType.values()) {
        values[ot.value] = ot;
      }
    }

    OnnxTensorType(int value) {
      this.value = value;
    }

    /**
     * Maps from the C API's int enum to the Java enum.
     *
     * @param value The index of the Java enum.
     * @return The Java enum.
     */
    public static OnnxTensorType mapFromInt(int value) {
      if ((value > 0) && (value < values.length)) {
        return values[value];
      } else {
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
      }
    }

    /**
     * Maps a OnnxJavaType into the appropriate native element type.
     *
     * @param type The type of the Java input/output.
     * @return The native element type.
     */
    public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
      switch (type) {
        case FLOAT:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
        case DOUBLE:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
        case INT8:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
        case UINT8:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
        case INT16:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
        case INT32:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
        case INT64:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
        case BOOL:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
        case STRING:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
        case UNKNOWN:
        default:
          return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
      }
    }
  }

  final long[] shape;

  /** The Java type of this tensor. */
  public final OnnxJavaType type;

  /** The native type of this tensor. */
  public final OnnxTensorType onnxType;

  /**
   * Constructs a TensorInfo with the specified shape, Java type and native type.
   *
   * @param shape The tensor shape.
   * @param type The Java type.
   * @param onnxType The native type.
   */
  TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
    this.shape = shape;
    this.type = type;
    this.onnxType = onnxType;
  }

  /**
   * Get a copy of the tensor's shape.
   *
   * @return A copy of the tensor's shape.
   */
  public long[] getShape() {
    return Arrays.copyOf(shape, shape.length);
  }

  @Override
  public String toString() {
    return "TensorInfo(javaType="
        + type.toString()
        + ",onnxType="
        + onnxType.toString()
        + ",shape="
        + Arrays.toString(shape)
        + ")";
  }

  /**
   * Returns true if the shape represents a scalar value (i.e. it has zero dimensions).
   *
   * @return True if the shape is a scalar.
   */
  public boolean isScalar() {
    return shape.length == 0;
  }

  /**
   * Checks that the shape of this tensor info is valid (i.e. positive and within the bounds of a
   * Java int).
   *
   * @return True if the shape is valid.
   */
  private boolean validateShape() {
    return OrtUtil.validateShape(shape);
  }

  /**
   * Constructs an array the right shape and type to hold this tensor.
   *
   * 

Note for String tensors, this carrier is a single dimensional array with enough space for * all elements as that's the expected format of the native code. It can be reshaped to the * correct shape using {@link OrtUtil#reshape(String[],long[])}. * * @return A multidimensional array of the appropriate primitive type (or String). * @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is * greater than an int). */ public Object makeCarrier() throws OrtException { if (!validateShape()) { throw new OrtException( "This tensor is not representable in Java, it's too big - shape = " + Arrays.toString(shape)); } switch (type) { case FLOAT: return OrtUtil.newFloatArray(shape); case DOUBLE: return OrtUtil.newDoubleArray(shape); case UINT8: case INT8: return OrtUtil.newByteArray(shape); case INT16: return OrtUtil.newShortArray(shape); case INT32: return OrtUtil.newIntArray(shape); case INT64: return OrtUtil.newLongArray(shape); case BOOL: return OrtUtil.newBooleanArray(shape); case STRING: return new String[(int) OrtUtil.elementCount(shape)]; case UNKNOWN: throw new OrtException("Can't construct a carrier for an invalid type."); default: throw new OrtException("Unsupported type - " + type); } } /** * Constructs a TensorInfo from the supplied multidimensional Java array, used to allocate the * appropriate amount of native memory. * * @param obj The object to inspect. * @return A TensorInfo which can be used to make the right size Tensor. * @throws OrtException If the supplied Object isn't an array, or is an invalid type. */ public static TensorInfo constructFromJavaArray(Object obj) throws OrtException { Class objClass = obj.getClass(); // Check if it's an array or a scalar. 
if (!objClass.isArray()) { // Check if it's a valid non-array type OnnxJavaType javaType = OnnxJavaType.mapFromClass(objClass); if (javaType == OnnxJavaType.UNKNOWN) { throw new OrtException("Cannot convert " + objClass + " to a OnnxTensor."); } else { // scalar primitive return new TensorInfo(new long[0], javaType, OnnxTensorType.mapFromJavaType(javaType)); } } // Figure out base type and number of dimensions. int dimensions = 0; while (objClass.isArray()) { objClass = objClass.getComponentType(); dimensions++; } if (!objClass.isPrimitive() && !objClass.equals(String.class)) { throw new OrtException("Cannot create an OnnxTensor from a base type of " + objClass); } else if (dimensions > MAX_DIMENSIONS) { throw new OrtException( "Cannot create an OnnxTensor with more than " + MAX_DIMENSIONS + " dimensions. Found " + dimensions + " dimensions."); } OnnxJavaType javaType = OnnxJavaType.mapFromClass(objClass); // Now we extract the shape and validate that the java array is rectangular (i.e. not ragged). // this is pretty nasty as we have to look at every object array recursively. // Thanks Java! long[] shape = new long[dimensions]; extractShape(shape, 0, obj); return new TensorInfo(shape, javaType, OnnxTensorType.mapFromJavaType(javaType)); } /** * Constructs a TensorInfo from the supplied byte buffer. * * @param buffer The buffer to inspect. * @param shape The shape of the tensor. * @param type The Java type. * @return A TensorInfo for a tensor. * @throws OrtException If the supplied buffer doesn't match the shape. 
*/ public static TensorInfo constructFromBuffer(Buffer buffer, long[] shape, OnnxJavaType type) throws OrtException { if ((type == OnnxJavaType.STRING) || (type == OnnxJavaType.UNKNOWN)) { throw new OrtException("Cannot create a tensor from a string or unknown buffer."); } long elementCount = OrtUtil.elementCount(shape); long bufferRemaining = buffer.remaining(); if (elementCount != bufferRemaining) { throw new OrtException( "Shape " + Arrays.toString(shape) + ", requires " + elementCount + " elements but the buffer has " + bufferRemaining + " elements."); } return new TensorInfo( Arrays.copyOf(shape, shape.length), type, OnnxTensorType.mapFromJavaType(type)); } /** * Extracts the shape from a multidimensional array. Checks to see if the array is ragged or not. * * @param shape The shape array to write to. * @param curDim The current dimension to check. * @param obj The multidimensional array to inspect. * @throws OrtException If the array has a zero dimension, or is ragged. */ private static void extractShape(long[] shape, int curDim, Object obj) throws OrtException { if (shape.length != curDim) { int curLength = Array.getLength(obj); if (curLength == 0) { throw new OrtException( "Supplied array has a zero dimension at " + curDim + ", all dimensions must be positive"); } else if (shape[curDim] == 0L) { shape[curDim] = curLength; } else if (shape[curDim] != curLength) { throw new OrtException( "Supplied array is ragged, expected " + shape[curDim] + ", found " + curLength); } for (int i = 0; i < curLength; i++) { extractShape(shape, curDim + 1, Array.get(obj, i)); } } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy