All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.pytorch.jni.IValue Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.util.NativeResource;

import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A class represent a PyTorch {@code IValue} data.
 *
 * 

DJL doesn't support creating nested IValue. */ public class IValue extends NativeResource { IValue(long handle) { super(handle); } /** * Returns the type of the IValue. * * @return the type of the IValue */ public String getType() { return PyTorchLibrary.LIB.iValueGetType(getHandle()); } /** * Returns if the IValue is a {@code Tensor} type. * * @return if the IValue is a Tensor type */ public boolean isTensor() { return PyTorchLibrary.LIB.iValueIsTensor(getHandle()); } /** * Returns if the IValue is a {@code boolean} type. * * @return if the IValue is a boolean type */ public boolean isBoolean() { return PyTorchLibrary.LIB.iValueIsBool(getHandle()); } /** * Returns if the IValue is a {@code long} type. * * @return if the IValue is a long type */ public boolean isLong() { return PyTorchLibrary.LIB.iValueIsLong(getHandle()); } /** * Returns if the IValue is a {@code double} type. * * @return if the IValue is a double type */ public boolean isDouble() { return PyTorchLibrary.LIB.iValueIsDouble(getHandle()); } /** * Returns if the IValue is a {@code String} type. * * @return if the IValue is a String type */ public boolean isString() { return PyTorchLibrary.LIB.iValueIsString(getHandle()); } /** * Returns if the IValue is a {@code boolean[]} type. * * @return if the IValue is a boolean[] type */ public boolean isBooleanList() { return PyTorchLibrary.LIB.iValueIsBoolList(getHandle()); } /** * Returns if the IValue is a {@code long[]} type. * * @return if the IValue is a long[] type */ public boolean isLongList() { return PyTorchLibrary.LIB.iValueIsLongList(getHandle()); } /** * Returns if the IValue is a {@code double[]} type. * * @return if the IValue is a double[] type */ public boolean isDoubleList() { return PyTorchLibrary.LIB.iValueIsDoubleList(getHandle()); } /** * Returns if the IValue is a {@code IValue[]} type. * *

The elements in the array must have the same type. * * @return if the IValue is a IValue[] type */ public boolean isTensorList() { return PyTorchLibrary.LIB.iValueIsTensorList(getHandle()); } /** * Returns if the IValue is a {@code IValue[]} type. * *

The elements in the array must have the same type. * * @return if the IValue is a IValue[] type */ public boolean isList() { return PyTorchLibrary.LIB.iValueIsList(getHandle()); } /** * Returns if the IValue is a {@code Map<String, V>} type. * * @return if the IValue is a Map<String, V> type */ public boolean isMap() { return PyTorchLibrary.LIB.iValueIsMap(getHandle()); } /** * Returns if the IValue is a tuple type. * * @return if the IValue is a tuple type */ public boolean isTuple() { return PyTorchLibrary.LIB.iValueIsTuple(getHandle()); } /** * Creates a new {@code IValue} of type {@code PtNDArray}. * * @param array the NDArray * @return a new {@code IValue} of type {@code PtNDArray} */ public static IValue from(PtNDArray array) { if (array.getDataType() == DataType.STRING) { Shape shape = array.getShape(); String[] strs = array.toStringArray(); if (shape.isScalar()) { return from(strs[0]); } IValue[] list = new IValue[strs.length]; PtNDManager manager = array.getManager(); for (int i = 0; i < strs.length; i++) { IValue ivalue = from(strs[i]); manager.attachUncappedInternal(ivalue.getUid(), ivalue); list[i] = ivalue; } return listFrom(list); } return new IValue(PyTorchLibrary.LIB.iValueFromTensor(array.getHandle())); } /** * Creates a new {@code IValue} of type {@code boolean}. * * @param value the boolean value * @return a new {@code IValue} of type {@code boolean} */ public static IValue from(boolean value) { return new IValue(PyTorchLibrary.LIB.iValueFromBool(value)); } /** * Creates a new {@code IValue} of type {@code long}. * * @param value the long value * @return a new {@code IValue} of type {@code long} */ public static IValue from(long value) { return new IValue(PyTorchLibrary.LIB.iValueFromLong(value)); } /** * Creates a new {@code IValue} of type {@code double}. 
* * @param value the double value * @return a new {@code IValue} of type {@code double} */ public static IValue from(double value) { return new IValue(PyTorchLibrary.LIB.iValueFromDouble(value)); } /** * Creates a new {@code IValue} of type {@code String}. * * @param value the String value * @return a new {@code IValue} of type {@code String} */ public static IValue from(String value) { return new IValue(PyTorchLibrary.LIB.iValueFromString(value)); } /** * Creates a new {@code IValue} of type {@code boolean[]}. * * @param list the boolean[] value * @return a new {@code IValue} of type {@code boolean[]} */ public static IValue listFrom(boolean... list) { return new IValue(PyTorchLibrary.LIB.iValueFromBoolList(list)); } /** * Creates a new {@code IValue} of type {@code long[]}. * * @param list the long[] value * @return a new {@code IValue} of type {@code long[]} */ public static IValue listFrom(long... list) { return new IValue(PyTorchLibrary.LIB.iValueFromLongList(list)); } /** * Creates a new {@code IValue} of type {@code double[]}. * * @param list the double[] value * @return a new {@code IValue} of type {@code double[]} */ public static IValue listFrom(double... list) { return new IValue(PyTorchLibrary.LIB.iValueFromDoubleList(list)); } /** * Creates a new {@code IValue} of type {@code NDArray[]}. * * @param list the NDArray[] value * @return a new {@code IValue} of type {@code NDArray[]} */ public static IValue listFrom(PtNDArray... list) { long[] tensors = Arrays.stream(list).mapToLong(PtNDArray::getHandle).toArray(); return new IValue(PyTorchLibrary.LIB.iValueFromTensorList(tensors)); } /** * Creates a new {@code IValue} of type {@code NDArray[]}. * * @param list the NDArray[] value * @return a new {@code IValue} of type {@code NDArray[]} */ public static IValue listFrom(IValue... 
list) { if (list.length == 0) { throw new IllegalArgumentException("Empty IValue list is not supported."); } long[] tensors = Arrays.stream(list).mapToLong(IValue::getHandle).toArray(); return new IValue(PyTorchLibrary.LIB.iValueFromList(tensors)); } /** * Creates a new {@code IValue} of type {@code NDArray[]}. * * @param list the NDArray[] value * @return a new {@code IValue} of type {@code NDArray[]} */ public static IValue tupleFrom(IValue... list) { long[] tensors = Arrays.stream(list).mapToLong(IValue::getHandle).toArray(); return new IValue(PyTorchLibrary.LIB.iValueFromTuple(tensors)); } /** * Creates a new {@code IValue} of type {@code Map[String, PtNDArray]}. * * @param map the Map[String, IValue] value * @return a new {@code IValue} of type {@code Map[String, PtNDArray]} */ public static IValue stringMapFrom(Map map) { String[] keys = new String[map.size()]; long[] handles = new long[map.size()]; int i = 0; for (Map.Entry entry : map.entrySet()) { keys[i] = entry.getKey(); handles[i] = entry.getValue().getHandle(); ++i; } return new IValue(PyTorchLibrary.LIB.iValueFromStringMap(keys, handles)); } /** * Creates a new {@code IValue} of type {@code Map[String, IValue]}. * * @param map the Map[String, IValue] value * @return a new {@code IValue} of type {@code Map[String, IValue]} */ public static IValue stringIValueMapFrom(Map map) { String[] keys = new String[map.size()]; long[] handles = new long[map.size()]; int i = 0; for (Map.Entry entry : map.entrySet()) { keys[i] = entry.getKey(); handles[i] = entry.getValue().getHandle(); ++i; } return new IValue(PyTorchLibrary.LIB.iValueFromStringIValueMap(keys, handles)); } /** * Returns the {@code boolean} value of this IValue. * * @return the boolean value of this IValue */ public boolean toBoolean() { return PyTorchLibrary.LIB.iValueToBool(getHandle()); } /** * Returns the {@code long} value of this IValue. 
* * @return the long value of this IValue */ public long toLong() { return PyTorchLibrary.LIB.iValueToLong(getHandle()); } /** * Returns the {@code double} value of this IValue. * * @return the double value of this IValue */ public double toDouble() { return PyTorchLibrary.LIB.iValueToDouble(getHandle()); } /** * Returns the {@code String} value of this IValue. * * @return the String value of this IValue */ public String toStringValue() { return PyTorchLibrary.LIB.iValueToString(getHandle()); } /** * Returns the {@code boolean[]} value of this IValue. * * @return the boolean[] value of this IValue */ public boolean[] toBooleanArray() { return PyTorchLibrary.LIB.iValueToBoolList(getHandle()); } /** * Returns the {@code long[]} value of this IValue. * * @return the long[] value of this IValue */ public long[] toLongArray() { return PyTorchLibrary.LIB.iValueToLongList(getHandle()); } /** * Returns the {@code double[]} value of this IValue. * * @return the double[] value of this IValue */ public double[] toDoubleArray() { return PyTorchLibrary.LIB.iValueToDoubleList(getHandle()); } /** * Returns the {@code NDArray} value of this IValue. * * @param manager the {@code NDManager} to create the NDArray * @return the NDArray value of this IValue */ public PtNDArray toTensor(PtNDManager manager) { return new PtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); } /** * Returns the {@code NDArray[]} value of this IValue. * * @param manager the NDManager to create NDArray * @return the NDArray[] value of this IValue */ public PtNDArray[] toTensorArray(PtNDManager manager) { long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle()); PtNDArray[] ret = new PtNDArray[handles.length]; for (int i = 0; i < ret.length; ++i) { ret[i] = new PtNDArray(manager, handles[i]); } return ret; } /** * Returns the {@code IValue[]} value of this IValue list. 
* * @return the IValue[] value of this IValue list */ public IValue[] toIValueArray() { long[] handles = PyTorchLibrary.LIB.iValueToIValueList(getHandle()); IValue[] ret = new IValue[handles.length]; for (int i = 0; i < ret.length; ++i) { ret[i] = new IValue(handles[i]); } return ret; } /** * Returns the {@code Map<String, IValue>} value of this IValue. * * @return the Map<String, IValue> value of this IValue */ public Map toIValueMap() { long[] handles = PyTorchLibrary.LIB.iValueToMap(getHandle()); Map map = new ConcurrentHashMap<>(); for (int i = 0; i < handles.length; i += 2) { IValue key = new IValue(handles[i]); map.put(key.toStringValue(), new IValue(handles[i + 1])); key.close(); } return map; } /** * Returns the {@code Map<String, IValue>} value of this IValue. * * @return the Map<String, IValue> value of this IValue */ public IValue[] toIValueTuple() { long[] handles = PyTorchLibrary.LIB.iValueToIValueTuple(getHandle()); IValue[] ret = new IValue[handles.length]; for (int i = 0; i < ret.length; ++i) { ret[i] = new IValue(handles[i]); } return ret; } /** * Returns the {@code NDList} value of this IValue. 
* * @param manager the NDManager to create NDArray * @return the {@code NDList} value of this IValue */ public NDList toNDList(PtNDManager manager) { if (isTensor()) { return new NDList(toTensor(manager)); } else if (isTensorList()) { return new NDList(toTensorArray(manager)); } else if (isMap()) { // Only allows one level type of map NDList list = new NDList(); Map map = toIValueMap(); for (Map.Entry entry : map.entrySet()) { IValue iv = entry.getValue(); if (!iv.isTensor()) { throw new UnsupportedOperationException("Only one level of map is supported."); } PtNDArray value = entry.getValue().toTensor(manager); value.setName(entry.getKey()); list.add(value); iv.close(); } return list; } else if (isList()) { NDList list = new NDList(); for (IValue ivalue : toIValueArray()) { list.addAll(ivalue.toNDList(manager)); ivalue.close(); } return list; } else if (isTuple()) { NDList list = new NDList(); for (IValue ivalue : toIValueTuple()) { list.addAll(ivalue.toNDList(manager)); ivalue.close(); } return list; } else if (isString()) { return new NDList(manager.create(toStringValue())); } throw new UnsupportedOperationException("Unsupported IValue type."); } /** {@inheritDoc} */ @Override public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { PyTorchLibrary.LIB.torchDeleteIValue(pointer); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy