com.simiacryptus.mindseye.lang.tensorflow.TFIO Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang.tensorflow;
import com.google.common.primitives.Floats;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.mindseye.lang.TensorArray;
import com.simiacryptus.mindseye.lang.TensorList;
import com.simiacryptus.ref.lang.RecycleBin;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.util.Util;
import org.jetbrains.annotations.NotNull;
import org.tensorflow.DataType;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.stream.Stream;
public class TFIO {
@NotNull
public static TensorArray getTensorList(org.tensorflow.Tensor> tensor) {
return getTensorList(tensor, true);
}
@NotNull
public static TensorArray getTensorList(org.tensorflow.Tensor> tensor, boolean invertRanks) {
if (tensor.dataType() == DataType.DOUBLE) {
return getTensorArray_Double(tensor.expect(Double.class), tensor.shape(), invertRanks);
} else if (tensor.dataType() == DataType.FLOAT) {
return getTensorArray_Float(tensor.expect(Float.class), tensor.shape(), invertRanks);
} else {
throw new IllegalArgumentException(tensor.dataType().toString());
}
}
@NotNull
public static Tensor getTensor(org.tensorflow.Tensor> tensor) {
return getTensor(tensor, true);
}
@NotNull
public static Tensor getTensor(org.tensorflow.Tensor> tensor, boolean invertRanks) {
if (tensor.dataType() == DataType.DOUBLE) {
return getTensor_Double(tensor.expect(Double.class), tensor.shape(), invertRanks);
} else if (tensor.dataType() == DataType.FLOAT) {
return getTensor_Float(tensor.expect(Float.class), tensor.shape(), invertRanks);
} else {
throw new IllegalArgumentException(tensor.dataType().toString());
}
}
@NotNull
public static org.tensorflow.Tensor getFloatTensor(@Nullable Tensor data) {
return getFloatTensor(data, true);
}
@NotNull
public static org.tensorflow.Tensor getFloatTensor(@NotNull Tensor data, boolean invertRanks) {
final Tensor invertDimensions;
double[] buffer;
if (invertRanks) {
invertDimensions = data.invertDimensions();
buffer = invertDimensions.getData();
} else {
invertDimensions = null;
buffer = data.getData();
}
org.tensorflow.Tensor tfTensor = org.tensorflow.Tensor.create(
Util.toLong(data.getDimensions()),
FloatBuffer.wrap(Util.getFloats(buffer))
);
if (null != invertDimensions)
invertDimensions.freeRef();
data.freeRef();
return tfTensor;
}
@NotNull
public static org.tensorflow.Tensor getFloatTensor(@Nullable TensorList data) {
return getFloatTensor(data, true);
}
@NotNull
public static org.tensorflow.Tensor getFloatTensor(@Nullable TensorList data, boolean invertRanks) {
long[] shape = RefLongStream.concat(
RefLongStream.of(data.length()),
RefArrays.stream(data.getDimensions()).mapToLong(x -> x)
).toArray();
double[] buffer = getDoubles(data, invertRanks);
org.tensorflow.Tensor tensor = org.tensorflow.Tensor.create(shape,
FloatBuffer.wrap(Util.getFloats(buffer)));
RecycleBin.DOUBLES.recycle(buffer, buffer.length);
return tensor;
}
@NotNull
public static org.tensorflow.Tensor getDoubleTensor(@Nullable Tensor data) {
org.tensorflow.Tensor temp_03_0009 = getDoubleTensor(data == null ? null : data.addRef(), true);
if (null != data)
data.freeRef();
return temp_03_0009;
}
@NotNull
public static org.tensorflow.Tensor getDoubleTensor(@NotNull Tensor data, boolean invertRanks) {
double[] buffer;
final Tensor invertDimensions;
if (invertRanks) {
invertDimensions = data.invertDimensions();
buffer = invertDimensions.getData();
} else {
invertDimensions = null;
buffer = data.getData();
}
org.tensorflow.Tensor tfTensor = org.tensorflow.Tensor.create(
Util.toLong(data.getDimensions()),
DoubleBuffer.wrap(buffer)
);
data.freeRef();
if (null != invertDimensions)
invertDimensions.freeRef();
return tfTensor;
}
@NotNull
public static org.tensorflow.Tensor getDoubleTensor(@Nullable TensorList data) {
org.tensorflow.Tensor temp_03_0011 = getDoubleTensor(data == null ? null : data.addRef(), true);
if (null != data)
data.freeRef();
return temp_03_0011;
}
@NotNull
public static org.tensorflow.Tensor getDoubleTensor(@Nullable TensorList data, boolean invertRanks) {
long[] shape = RefLongStream.concat(
RefLongStream.of(data.length()),
RefArrays.stream(data.getDimensions()).mapToLong(x -> x)
).toArray();
double[] buffer = getDoubles(data, invertRanks);
org.tensorflow.Tensor tensor = org.tensorflow.Tensor.create(shape, DoubleBuffer.wrap(buffer));
RecycleBin.DOUBLES.recycle(buffer, buffer.length);
return tensor;
}
private static void free(Object obj) {
if (obj instanceof double[]) {
double[] doubles = (double[]) obj;
RecycleBin.DOUBLES.recycle(doubles, doubles.length);
} else if (obj instanceof float[]) {
float[] floats = (float[]) obj;
RecycleBin.FLOATS.recycle(floats, floats.length);
} else {
RefArrays.stream((Object[]) obj).forEach(x -> free(x));
}
}
private static Object createFloatArray(@NotNull long[] shape) {
if (shape.length == 1) {
return RecycleBin.FLOATS.obtain(shape[0]);
} else if (shape.length == 2) {
return RefIntStream.range(0, (int) shape[0]).mapToObj(i -> new float[(int) shape[1]])
.toArray(s -> new float[s][]);
} else if (shape.length == 3) {
return RefIntStream.range(0, (int) shape[0]).mapToObj(i -> RefIntStream.range(0, (int) shape[1])
.mapToObj(j -> new float[(int) shape[2]]).toArray(s -> new float[s][])).toArray(s -> new float[s][][]);
} else if (shape.length == 4) {
return RefIntStream.range(0, (int) shape[0])
.mapToObj(i -> RefIntStream
.range(0, (int) shape[1]).mapToObj(j -> RefIntStream.range(0, (int) shape[2])
.mapToObj(k -> new float[(int) shape[3]]).toArray(s -> new float[s][]))
.toArray(s -> new float[s][][]))
.toArray(s -> new float[s][][][]);
} else if (shape.length == 5) {
return RefIntStream.range(0, (int) shape[0])
.mapToObj(i -> RefIntStream.range(0, (int) shape[1])
.mapToObj(j -> RefIntStream.range(0, (int) shape[2])
.mapToObj(k -> RefIntStream.range(0, (int) shape[3]).mapToObj(l -> new float[(int) shape[4]])
.toArray(s -> new float[s][]))
.toArray(s -> new float[s][][]))
.toArray(s -> new float[s][][][]))
.toArray(s -> new float[s][][][][]);
} else if (shape.length == 6) {
return RefIntStream.range(0, (int) shape[0]).mapToObj(i -> RefIntStream.range(0, (int) shape[1])
.mapToObj(j -> RefIntStream.range(0, (int) shape[2])
.mapToObj(k -> RefIntStream.range(0, (int) shape[3])
.mapToObj(l -> RefIntStream.range(0, (int) shape[4]).mapToObj(m -> new float[(int) shape[5]])
.toArray(s -> new float[s][]))
.toArray(s -> new float[s][][]))
.toArray(s -> new float[s][][][]))
.toArray(s -> new float[s][][][][])).toArray(s -> new float[s][][][][][]);
} else {
throw new RuntimeException("Rank " + shape.length);
}
}
private static Object createDoubleArray(@NotNull long[] shape) {
if (shape.length == 1) {
return RecycleBin.DOUBLES.obtain(shape[0]);
} else if (shape.length == 2) {
return RefIntStream.range(0, (int) shape[0]).mapToObj(i -> new double[(int) shape[1]])
.toArray(s -> new double[s][]);
} else if (shape.length == 3) {
return RefIntStream.range(0, (int) shape[0]).mapToObj(i -> RefIntStream.range(0, (int) shape[1])
.mapToObj(j -> new double[(int) shape[2]]).toArray(s -> new double[s][])).toArray(s -> new double[s][][]);
} else if (shape.length == 4) {
return RefIntStream.range(0, (int) shape[0])
.mapToObj(
i -> RefIntStream.range(0, (int) shape[1])
.mapToObj(j -> RefIntStream.range(0, (int) shape[2]).mapToObj(k -> new double[(int) shape[3]])
.toArray(s -> new double[s][]))
.toArray(s -> new double[s][][]))
.toArray(s -> new double[s][][][]);
} else {
throw new RuntimeException("Rank " + shape.length);
}
}
@NotNull
private static RefDoubleStream flattenDoubles(Object obj) {
if (obj instanceof double[]) {
return RefArrays.stream((double[]) obj);
} else if (obj instanceof Double) {
return RefDoubleStream.of((double) obj);
} else {
return RefArrays.stream((Object[]) obj).flatMapToDouble(obj1 -> flattenDoubles(obj1));
}
}
private static Stream flattenFloats(Object floats) {
if (floats instanceof float[]) {
float[] array = (float[]) floats;
return Floats.asList(array).stream();
} else {
return RefArrays.stream((Object[]) floats).flatMap(x -> flattenFloats(x));
}
}
private static double[] getDoubles(@NotNull TensorList data, boolean invertRanks) {
double[] buffer = RecycleBin.DOUBLES.obtain(data.length() * Tensor.length(data.getDimensions()));
DoubleBuffer inputBuffer = DoubleBuffer.wrap(buffer);
if (invertRanks) {
data.stream().map(tensor -> {
Tensor invertDimensions = tensor.invertDimensions();
tensor.freeRef();
return invertDimensions;
}).forEach(tensor -> {
inputBuffer.put(tensor.getData());
tensor.freeRef();
});
} else {
data.stream().forEach(tensor -> {
inputBuffer.put(tensor.getData());
tensor.freeRef();
});
}
data.freeRef();
return buffer;
}
@NotNull
private static TensorArray getTensorArray_Float(org.tensorflow.Tensor tensor, @NotNull long[] shape,
boolean invertRanks) {
float[] doubles = getFloats(tensor);
int[] dims = RefArrays.stream(shape).skip(1).mapToInt(x -> (int) x).toArray();
int batches = (int) shape[0];
TensorArray resultData = new TensorArray(RefIntStream.range(0, batches).mapToObj(i -> {
int offset = i * Tensor.length(dims);
if (invertRanks) {
Tensor returnValue = new Tensor(Tensor.reverse(dims));
returnValue.set(j -> doubles[j + offset]);
Tensor invertDimensions = returnValue.invertDimensions();
returnValue.freeRef();
return invertDimensions;
} else {
Tensor returnValue = new Tensor(dims);
returnValue.set(j -> doubles[j + offset]);
return returnValue;
}
}).toArray(i -> new Tensor[i]));
RecycleBin.FLOATS.recycle(doubles, doubles.length);
return resultData;
}
@NotNull
private static Tensor getTensor_Float(org.tensorflow.Tensor tensor, @NotNull long[] shape, boolean invertRanks) {
if (0 == tensor.numElements())
return new Tensor(RefArrays.stream(shape).mapToInt(x -> (int) x).toArray());
float[] doubles = getFloats(tensor);
int[] dims = RefArrays.stream(shape).mapToInt(x -> (int) x).toArray();
if (invertRanks) {
Tensor returnValue = new Tensor(Tensor.reverse(dims));
returnValue.set(j -> doubles[j]);
RecycleBin.FLOATS.recycle(doubles, doubles.length);
Tensor invertDimensions = returnValue.invertDimensions();
returnValue.freeRef();
return invertDimensions;
} else {
Tensor returnValue = new Tensor(dims);
returnValue.set(j -> doubles[j]);
RecycleBin.FLOATS.recycle(doubles, doubles.length);
return returnValue;
}
}
@NotNull
private static TensorArray getTensorArray_Double(org.tensorflow.Tensor tensor, @NotNull long[] shape,
boolean invertRanks) {
double[] doubles = getDoubles(tensor);
int[] dims = RefArrays.stream(shape).skip(1).mapToInt(x -> (int) x).toArray();
int batches = (int) shape[0];
TensorArray resultData = new TensorArray(RefIntStream.range(0, batches).mapToObj(i -> {
if (invertRanks) {
Tensor returnValue = new Tensor(Tensor.reverse(dims));
RefSystem.arraycopy(doubles, i * returnValue.length(), returnValue.getData(), 0,
returnValue.length());
Tensor invertDimensions = returnValue.invertDimensions();
returnValue.freeRef();
return invertDimensions;
} else {
Tensor returnValue = new Tensor(dims);
RefSystem.arraycopy(doubles, i * returnValue.length(), returnValue.getData(), 0,
returnValue.length());
return returnValue;
}
}).toArray(i -> new Tensor[i]));
RecycleBin.DOUBLES.recycle(doubles, doubles.length);
return resultData;
}
@NotNull
private static Tensor getTensor_Double(org.tensorflow.Tensor tensor, @NotNull long[] shape, boolean invertRanks) {
double[] doubles = getDoubles(tensor);
int[] dims = RefArrays.stream(shape).mapToInt(x -> (int) x).toArray();
if (invertRanks) {
Tensor returnValue = new Tensor(Tensor.reverse(dims));
RefSystem.arraycopy(doubles, 0, returnValue.getData(), 0, returnValue.length());
RecycleBin.DOUBLES.recycle(doubles, doubles.length);
Tensor invertDimensions = returnValue.invertDimensions();
returnValue.freeRef();
return invertDimensions;
} else {
Tensor returnValue = new Tensor(dims);
RefSystem.arraycopy(doubles, 0, returnValue.getData(), 0, returnValue.length());
RecycleBin.DOUBLES.recycle(doubles, doubles.length);
return returnValue;
}
}
private static double[] getDoubles(org.tensorflow.Tensor result) {
Object deepArray = result.copyTo(createDoubleArray(result.shape()));
double[] doubles = flattenDoubles(deepArray).toArray();
free(deepArray);
return doubles;
}
@NotNull
private static float[] getFloats(org.tensorflow.Tensor result) {
if (0 == result.numElements())
return new float[]{};
Object deepArray = result.copyTo(createFloatArray(result.shape()));
double[] doubles = flattenFloats(deepArray).mapToDouble(x -> x).toArray();
free(deepArray);
return Util.getFloats(doubles);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy