All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.tensorflow.Tensor Maven / Gradle / Ivy

There is a newer version: 1.15.0
Show newest version
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;

/**
 * A statically typed multi-dimensional array whose elements are of a type described by T.
 *
 * 

Instances of a Tensor are not thread-safe. * *

WARNING: Resources consumed by the Tensor object must be explicitly freed by * invoking the {@link #close()} method when the object is no longer needed. For example, using a * try-with-resources block: * *

{@code
 * try (Tensor t = Tensor.create(...)) {
 *   doSomethingWith(t);
 * }
 * }
*/ public final class Tensor implements AutoCloseable { /** * Creates a Tensor from a Java object. * *

A {@code Tensor} is a multi-dimensional array of elements of a limited set of types. Not all * Java objects can be converted to a {@code Tensor}. In particular, the argument {@code obj} must * be either a primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of * one of those primitives. The argument {@code type} specifies how to interpret the first * argument as a TensorFlow type. For example: * *

{@code
   * // Valid: A 64-bit integer scalar.
   * Tensor s = Tensor.create(42L, Long.class);
   *
   * // Valid: A 3x2 matrix of floats.
   * float[][] matrix = new float[3][2];
   * Tensor m = Tensor.create(matrix, Float.class);
   *
   * // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
   * // does not fit into the TensorFlow type system.
   * Tensor o = Tensor.create(new Object())
   *
   * // Invalid: Will throw an IllegalArgumentException since there are
   * // a differing number of elements in each row of this 2-D array.
   * int[][] twoD = new int[2][];
   * twoD[0] = new int[1];
   * twoD[1] = new int[2];
   * Tensor x = Tensor.create(twoD, Integer.class);
   * }
* * {@link String}-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can * be initialized from arrays of {@code byte[]} elements. For example: * *
{@code
   * // Valid: A String tensor.
   * Tensor s = Tensor.create(new byte[]{1, 2, 3}, String.class);
   *
   * // Java Strings will need to be encoded into a byte-sequence.
   * String mystring = "foo";
   * Tensor s = Tensor.create(mystring.getBytes("UTF-8"), String.class);
   *
   * // Valid: Matrix of String tensors.
   * // Each element might have a different length.
   * byte[][][] matrix = new byte[2][2][];
   * matrix[0][0] = "this".getBytes("UTF-8");
   * matrix[0][1] = "is".getBytes("UTF-8");
   * matrix[1][0] = "a".getBytes("UTF-8");
   * matrix[1][1] = "matrix".getBytes("UTF-8");
   * Tensor m = Tensor.create(matrix, String.class);
   * }
* * @param obj The object to convert to a {@code Tensor}. Note that whether it is compatible * with the type T is not checked by the type system. For type-safe creation of tensors, use * {@link Tensors}. * @param type The class object representing the type T. * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type * system. */ @SuppressWarnings("unchecked") public static Tensor create(Object obj, Class type) { DataType dtype = DataType.fromClass(type); if (!objectCompatWithType(obj, dtype)) { throw new IllegalArgumentException( "DataType of object does not match T (expected " + dtype + ", got " + dataTypeOf(obj) + ")"); } return (Tensor) create(obj, dtype); } /** * Creates a tensor from an object whose class is inspected to figure out what the underlying data * type should be. * * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type * system. */ public static Tensor create(Object obj) { return create(obj, dataTypeOf(obj)); } /** * Create a Tensor of data type {@code dtype} from a Java object. Requires the parameter {@code T} * to match {@code type}, but this condition is not checked. * * @param obj the object supplying the tensor data. * @param dtype the data type of the tensor to create. It must be compatible with the run-time * type of the object. * @return the new tensor */ private static Tensor create(Object obj, DataType dtype) { @SuppressWarnings("rawtypes") Tensor t = new Tensor(dtype); t.shapeCopy = new long[numDimensions(obj, dtype)]; fillShape(obj, 0, t.shapeCopy); if (t.dtype != DataType.STRING) { int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy); t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); setValue(t.nativeHandle, obj); } else if (t.shapeCopy.length != 0) { t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); } else { t.nativeHandle = allocateScalarBytes((byte[]) obj); } return t; } /** * Create a {@link Integer} Tensor with data from the given buffer. * *

Creates a Tensor with the given shape by copying elements from the buffer (starting from its * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this * method. * * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Tensor create(long[] shape, IntBuffer data) { Tensor t = allocateForBuffer(DataType.INT32, shape, data.remaining()); t.buffer().asIntBuffer().put(data); return t; } /** * Create a {@link Float} Tensor with data from the given buffer. * *

Creates a Tensor with the given shape by copying elements from the buffer (starting from its * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this * method. * * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Tensor create(long[] shape, FloatBuffer data) { Tensor t = allocateForBuffer(DataType.FLOAT, shape, data.remaining()); t.buffer().asFloatBuffer().put(data); return t; } /** * Create a {@link Double} Tensor with data from the given buffer. * *

Creates a Tensor with the given shape by copying elements from the buffer (starting from its * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this * method. * * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Tensor create(long[] shape, DoubleBuffer data) { Tensor t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining()); t.buffer().asDoubleBuffer().put(data); return t; } /** * Create an {@link Long} Tensor with data from the given buffer. * *

Creates a Tensor with the given shape by copying elements from the buffer (starting from its * current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a * 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this * method. * * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Tensor create(long[] shape, LongBuffer data) { Tensor t = allocateForBuffer(DataType.INT64, shape, data.remaining()); t.buffer().asLongBuffer().put(data); return t; } /** * Create a Tensor of any type with data from the given buffer. * *

Creates a Tensor with the provided shape of any type where the tensor's data has been * encoded into {@code data} as per the specification of the TensorFlow C * API. * * @param the tensor element type * @param type the tensor element type, represented as a class object. * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ public static Tensor create(Class type, long[] shape, ByteBuffer data) { @SuppressWarnings("unchecked") Tensor ret = (Tensor) create(DataType.fromClass(type), shape, data); return ret; } /** * Creates a Tensor of any type with data from the given buffer. * *

Creates a Tensor with the provided shape of any type where the tensor's data has been * encoded into {@code data} as per the specification of the TensorFlow C * API. * * @param The tensor element type * @param type the tensor element type, specified as a DataType. This must agree with T. * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ private static Tensor create(DataType dtype, long[] shape, ByteBuffer data) { int nremaining = 0; if (dtype != DataType.STRING) { int elemBytes = elemByteSize(dtype); if (data.remaining() % elemBytes != 0) { throw new IllegalArgumentException( String.format( "ByteBuffer with %d bytes is not compatible with a %s Tensor (%d bytes/element)", data.remaining(), dtype.toString(), elemBytes)); } nremaining = data.remaining() / elemBytes; } else { nremaining = data.remaining(); } Tensor t = allocateForBuffer(dtype, shape, nremaining); t.buffer().put(data); return t; } /** * Returns this Tensor object with the type {@code Tensor}. This method is useful when given a * value of type {@code Tensor}. * * @param type any (non-null) array of the correct type. * @throws IllegalArgumentException if the actual data type of this object does not match the type * {@code U}. */ @SuppressWarnings("unchecked") public Tensor expect(Class type) { DataType dt = DataType.fromClass(type); if (!dt.equals(dtype)) { throw new IllegalArgumentException( "Cannot cast from tensor of " + dtype + " to tensor of " + dt); } return ((Tensor) this); } // Helper function to allocate a Tensor for the create() methods that create a Tensor from // a java.nio.Buffer. // Requires: dataType matches T private static Tensor allocateForBuffer(DataType dataType, long[] shape, int nBuffered) { final int nflattened = numElements(shape); int nbytes = 0; if (dataType != DataType.STRING) { if (nBuffered != nflattened) { throw incompatibleBuffer(nBuffered, shape); } nbytes = nflattened * elemByteSize(dataType); } else { // DT_STRING tensor encoded in a ByteBuffer. nbytes = nBuffered; } Tensor t = new Tensor(dataType); t.shapeCopy = Arrays.copyOf(shape, shape.length); t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); return t; } /** * Release resources associated with the Tensor. * *

WARNING:If not invoked, memory will be leaked. * *

The Tensor object is no longer usable after {@code close} returns. */ @Override public void close() { if (nativeHandle != 0) { delete(nativeHandle); nativeHandle = 0; } } /** Returns the {@link DataType} of elements stored in the Tensor. */ public DataType dataType() { return dtype; } /** * Returns the number of dimensions (sometimes referred to as rank) of the Tensor. * *

Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc. */ public int numDimensions() { return shapeCopy.length; } /** Returns the size, in bytes, of the tensor data. */ public int numBytes() { return buffer().remaining(); } /** Returns the number of elements in a flattened (1-D) view of the tensor. */ public int numElements() { return numElements(shapeCopy); } /** * Returns the shape of * the Tensor, i.e., the sizes of each dimension. * * @return an array where the i-th element is the size of the i-th dimension of the tensor. */ public long[] shape() { return shapeCopy; } /** * Returns the value in a scalar {@link Float} tensor. * * @throws IllegalArgumentException if the Tensor does not represent a float scalar. */ public float floatValue() { return scalarFloat(nativeHandle); } /** * Returns the value in a scalar {@link Double} tensor. * * @throws IllegalArgumentException if the Tensor does not represent a double scalar. */ public double doubleValue() { return scalarDouble(nativeHandle); } /** * Returns the value in a scalar {@link Integer} tensor. * * @throws IllegalArgumentException if the Tensor does not represent a int scalar. */ public int intValue() { return scalarInt(nativeHandle); } /** * Returns the value in a scalar {@link Long} tensor. * * @throws IllegalArgumentException if the Tensor does not represent a long scalar. */ public long longValue() { return scalarLong(nativeHandle); } /** * Returns the value in a scalar {@link Boolean} tensor. * * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. */ public boolean booleanValue() { return scalarBoolean(nativeHandle); } /** * Returns the value in a scalar {@link String} tensor. * * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. */ public byte[] bytesValue() { return scalarBytes(nativeHandle); } /** * Copies the contents of the tensor to {@code dst} and returns {@code dst}. * *

For non-scalar tensors, this method copies the contents of the underlying tensor to a Java * array. For scalar tensors, use one of {@link #bytesValue()}, {@link #floatValue()}, {@link * #doubleValue()}, {@link #intValue()}, {@link #longValue()} or {@link #booleanValue()} instead. * The type and shape of {@code dst} must be compatible with the tensor. For example: * *

{@code
   * int matrix[2][2] = {{1,2},{3,4}};
   * try(Tensor t = Tensor.create(matrix)) {
   *   // Succeeds and prints "3"
   *   int[][] copy = new int[2][2];
   *   System.out.println(t.copyTo(copy)[1][0]);
   *
   *   // Throws IllegalArgumentException since the shape of dst does not match the shape of t.
   *   int[][] dst = new int[4][1];
   *   t.copyTo(dst);
   * }
   * }
* * @throws IllegalArgumentException if the tensor is a scalar or if {@code dst} is not compatible * with the tensor (for example, mismatched data types or shapes). */ public U copyTo(U dst) { throwExceptionIfTypeIsIncompatible(dst); readNDArray(nativeHandle, dst); return dst; } /** * Write the data of a {@link Integer} tensor into the given buffer. * *

Copies {@code numElements()} elements to the buffer. * * @param dst the destination buffer * @throws BufferOverflowException If there is insufficient space in the given buffer for the data * in this tensor * @throws IllegalArgumentException If the tensor data type is not {@link Integer} */ public void writeTo(IntBuffer dst) { if (dtype != DataType.INT32) { throw incompatibleBuffer(dst, dtype); } ByteBuffer src = buffer(); dst.put(src.asIntBuffer()); } /** * Write the data of a {@link Float} tensor into the given buffer. * *

Copies {@code numElements()} elements to the buffer. * * @param dst the destination buffer * @throws BufferOverflowException If there is insufficient space in the given buffer for the data * in this tensor * @throws IllegalArgumentException If the tensor datatype is not {@link Float} */ public void writeTo(FloatBuffer dst) { if (dtype != DataType.FLOAT) { throw incompatibleBuffer(dst, dtype); } ByteBuffer src = buffer(); dst.put(src.asFloatBuffer()); } /** * Write the data of a {@link Double} tensor into the given buffer. * *

Copies {@code numElements()} elements to the buffer. * * @param dst the destination buffer * @throws BufferOverflowException If there is insufficient space in the given buffer for the data * in this tensor * @throws IllegalArgumentException If the tensor datatype is not {@link Double} */ public void writeTo(DoubleBuffer dst) { if (dtype != DataType.DOUBLE) { throw incompatibleBuffer(dst, dtype); } ByteBuffer src = buffer(); dst.put(src.asDoubleBuffer()); } /** * Write the data of a {@link Long} tensor into the given buffer. * *

Copies {@code numElements()} elements to the buffer. * * @param dst the destination buffer * @throws BufferOverflowException If there is insufficient space in the given buffer for the data * in this tensor * @throws IllegalArgumentException If the tensor datatype is not {@link Long} */ public void writeTo(LongBuffer dst) { if (dtype != DataType.INT64) { throw incompatibleBuffer(dst, dtype); } ByteBuffer src = buffer(); dst.put(src.asLongBuffer()); } /** * Write the tensor data into the given buffer. * *

Copies {@code numBytes()} bytes to the buffer in native byte order for primitive types. * * @param dst the destination buffer * @throws BufferOverflowException If there is insufficient space in the given buffer for the data * in this tensor */ public void writeTo(ByteBuffer dst) { ByteBuffer src = buffer(); dst.put(src); } /** Returns a string describing the type and shape of the Tensor. */ @Override public String toString() { return String.format("%s tensor with shape %s", dtype.toString(), Arrays.toString(shape())); } /** * Create a Tensor object from a handle to the C TF_Tensor object. * *

Takes ownership of the handle. */ static Tensor fromHandle(long handle) { @SuppressWarnings("rawtypes") Tensor t = new Tensor(DataType.fromC(dtype(handle))); t.shapeCopy = shape(handle); t.nativeHandle = handle; return t; } long getNativeHandle() { return nativeHandle; } private long nativeHandle; private DataType dtype; private long[] shapeCopy = null; private Tensor(DataType t) { dtype = t; } private ByteBuffer buffer() { return buffer(nativeHandle).order(ByteOrder.nativeOrder()); } private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) { return new IllegalArgumentException( String.format("cannot use %s with Tensor of type %s", buf.getClass().getName(), dataType)); } private static IllegalArgumentException incompatibleBuffer(int numElements, long[] shape) { return new IllegalArgumentException( String.format( "buffer with %d elements is not compatible with a Tensor with shape %s", numElements, Arrays.toString(shape))); } private static int numElements(long[] shape) { // assumes a fully-known shape int n = 1; for (int i = 0; i < shape.length; i++) { n *= (int) shape[i]; } return n; } private static int elemByteSize(DataType dataType) { switch (dataType) { case FLOAT: case INT32: return 4; case DOUBLE: case INT64: return 8; case BOOL: case UINT8: return 1; case STRING: throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); } throw new IllegalArgumentException("DataType " + dataType + " is not supported yet"); } private static void throwExceptionIfNotByteOfByteArrays(Object array) { if (!array.getClass().getName().equals("[[B")) { throw new IllegalArgumentException( "object cannot be converted to a Tensor as it includes an array with null elements"); } } private static HashMap, DataType> classDataTypes = new HashMap<>(); static { classDataTypes.put(int.class, DataType.INT32); classDataTypes.put(Integer.class, DataType.INT32); classDataTypes.put(long.class, DataType.INT64); classDataTypes.put(Long.class, DataType.INT64); classDataTypes.put(float.class, DataType.FLOAT); classDataTypes.put(Float.class, DataType.FLOAT); classDataTypes.put(double.class, DataType.DOUBLE); classDataTypes.put(Double.class, DataType.DOUBLE); classDataTypes.put(byte.class, DataType.STRING); classDataTypes.put(Byte.class, DataType.STRING); classDataTypes.put(boolean.class, DataType.BOOL); classDataTypes.put(Boolean.class, DataType.BOOL); } /** The class for the data type to which Java object o corresponds. */ private static Class baseObjType(Object o) { Class c = o.getClass(); while (c.isArray()) { c = c.getComponentType(); } return c; } /** * The default TensorFlow data type to which Java object o corresponds. Some Java objects * represent more than one TensorFlow data type; for example, 'byte' can represent both {@code * uint8} and {@code string}, with the latter being the default interpretation. */ private static DataType dataTypeOf(Object o) { Class c = baseObjType(o); return dataTypeFromClass(c); } private static DataType dataTypeFromClass(Class c) { DataType ret = classDataTypes.get(c); if (ret != null) { return ret; } throw new IllegalArgumentException("cannot create Tensors of type " + c.getName()); } /** * Return the number of dimensions of the tensor that object {@code o} represents as a tensor * whose datatype is {@code dtype}. Normally this is the same as the number of dimensions of o * itself, but is one smaller for tensors of strings. * * @param o The object to inspect. It must be a valid representation of the given data type. * @param dtype The expected data type of the tensor. */ private static int numDimensions(Object o, DataType dtype) { int ret = numArrayDimensions(o); if (dtype == DataType.STRING && ret > 0) { return ret - 1; } return ret; } /** Returns the number of dimensions of the array object o. Returns 0 if o is not an array. */ private static int numArrayDimensions(Object o) { Class c = o.getClass(); int i = 0; while (c.isArray()) { c = c.getComponentType(); i++; } return i; } /** * Fills in the remaining entries in the shape array starting from position {@code dim} with the * dimension sizes of the multidimensional array o. Checks that all arrays reachable from o have * sizes consistent with the filled-in shape, throwing IllegalArgumentException otherwise. */ private static void fillShape(Object o, int dim, long[] shape) { if (shape == null || dim == shape.length) { return; } final int len = Array.getLength(o); if (len == 0) { throw new IllegalArgumentException("cannot create Tensors with a 0 dimension"); } if (shape[dim] == 0) { shape[dim] = len; } else if (shape[dim] != len) { throw new IllegalArgumentException( String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); } for (int i = 0; i < len; ++i) { fillShape(Array.get(o, i), dim + 1, shape); } } /** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */ private static boolean objectCompatWithType(Object obj, DataType dtype) { Class c = baseObjType(obj); DataType dto = dataTypeFromClass(c); int nd = numDimensions(obj, dto); if (!c.isPrimitive() && c != String.class && nd != 0) { throw new IllegalArgumentException( "cannot create non-scalar Tensors from arrays of boxed values"); } if (dto.equals(dtype)) { return true; } if (dto == DataType.STRING && dtype == DataType.UINT8) { return true; } return false; } private void throwExceptionIfTypeIsIncompatible(Object o) { final int rank = numDimensions(); final int oRank = numDimensions(o, dtype); if (oRank != rank) { throw new IllegalArgumentException( String.format( "cannot copy Tensor with %d dimensions into an object with %d", rank, oRank)); } if (!objectCompatWithType(o, dtype)) { throw new IllegalArgumentException( String.format( "cannot copy Tensor with DataType %s into an object of type %s", dtype.toString(), o.getClass().getName())); } long[] oShape = new long[rank]; fillShape(o, 0, oShape); for (int i = 0; i < oShape.length; ++i) { if (oShape[i] != shape()[i]) { throw new IllegalArgumentException( String.format( "cannot copy Tensor with shape %s into object with shape %s", Arrays.toString(shape()), Arrays.toString(oShape))); } } } private static native long allocate(int dtype, long[] shape, long byteSize); private static native long allocateScalarBytes(byte[] value); private static native long allocateNonScalarBytes(long[] shape, Object[] value); private static native void delete(long handle); private static native ByteBuffer buffer(long handle); private static native int dtype(long handle); private static native long[] shape(long handle); private static native void setValue(long handle, Object value); private static native float scalarFloat(long handle); private static native double scalarDouble(long handle); private static native int scalarInt(long handle); private static native long scalarLong(long handle); private static native boolean scalarBoolean(long handle); private static native byte[] scalarBytes(long handle); private static native void readNDArray(long handle, Object value); static { TensorFlow.init(); } }