
com.simiacryptus.tensorflow.TFUtil Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tensorflow-model Show documentation
Show all versions of tensorflow-model Show documentation
General Utilities for TensorFlow
The newest version!
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.tensorflow;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.google.common.primitives.Floats;
import com.simiacryptus.util.Util;
import org.apache.commons.io.IOUtils;
import org.tensorflow.*;
import org.tensorflow.util.Event;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.awt.*;
import java.io.*;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
public class TFUtil {
/**
 * Imports a serialized graph definition, lets {@code operator} edit it, and returns the
 * re-serialized result.
 *
 * <p>The bytes are loaded into a fresh {@link Graph} via {@link Graph#importGraphDef(byte[])}.
 * The callback runs on that graph, which is then exported with {@link Graph#toGraphDef()}.
 * The temporary graph is closed before this method returns (try-with-resources), so
 * {@code operator} must not keep a reference to it.
 *
 * @param protobufBinaryData serialized graph definition to start from
 * @param operator           callback that mutates the graph; a {@code Consumer<Object>} or a
 *                           raw {@code Consumer} is also accepted
 * @return the serialized graph definition after {@code operator} has run
 */
public static byte[] editGraph(@Nonnull byte[] protobufBinaryData, @Nonnull Consumer<? super Graph> operator) {
  try (Graph graph = new Graph()) {
    graph.importGraphDef(protobufBinaryData);
    operator.accept(graph);
    // Export while the graph is still open; it is closed when the try block exits.
    return graph.toGraphDef();
  }
}
/**
 * Finds an operation in {@code graph} by exact name using a linear scan of
 * {@link Graph#operations()}.
 *
 * @param graph the graph to search
 * @param name  the operation name to match; a {@code null} name never matches
 * @return the first operation whose {@code name()} equals {@code name}, or {@code null} if none does
 */
@Nullable
public static Operation find(@Nonnull Graph graph, String name) {
  // Typed iterator: with a raw Iterator, next() returns Object and the assignment below
  // would not compile.
  final Iterator<Operation> operations = graph.operations();
  while (operations.hasNext()) {
    final Operation candidate = operations.next();
    if (candidate.name().equals(name)) {
      return candidate;
    }
  }
  return null;
}
/**
 * Reads a single entry of a zip archive, identified by URI, fully into memory.
 *
 * <p>The archive is resolved to a local file through the project helper
 * {@code Util.cacheFile(URI)} (presumably downloading and caching remote archives; confirm
 * against that helper).
 *
 * @param uri  URI of the zip archive
 * @param file name of the entry inside the archive
 * @return the uncompressed bytes of the entry
 * @throws FileNotFoundException if the archive contains no entry named {@code file}
 * @throws Exception             if the URI is malformed or the archive cannot be obtained or read
 */
public static byte[] loadZipUrl(@Nonnull String uri, @Nonnull String file) throws Exception {
  try (ZipFile zipFile = new ZipFile(Util.cacheFile(new URI(uri)))) {
    final ZipEntry entry = zipFile.getEntry(file);
    if (entry == null) {
      // getEntry signals "not found" with null; passing that on to getInputStream would only
      // yield a NullPointerException that does not say which entry was missing.
      throw new FileNotFoundException("No entry '" + file + "' in zip archive " + uri);
    }
    try (InputStream in = zipFile.getInputStream(entry)) {
      return IOUtils.toByteArray(in);
    }
  }
}
/**
 * Serializes {@code output} to an indented JSON string with Jackson.
 *
 * <p>A new {@link ObjectMapper} is created and configured on every call.
 *
 * @param output the value to serialize
 * @return the JSON text, indented because {@link SerializationFeature#INDENT_OUTPUT} is enabled
 * @throws JsonProcessingException if Jackson cannot serialize {@code output}
 */
public static String toJson(Object output) throws JsonProcessingException {
  final ObjectMapper prettyMapper = new ObjectMapper();
  prettyMapper.enable(SerializationFeature.INDENT_OUTPUT);
  return prettyMapper.writeValueAsString(output);
}
/**
 * Copies a rank-1 or rank-2 tensor of floats into a flat {@code double[]}.
 *
 * <p>A rank-2 tensor is flattened row by row: all of row 0, then all of row 1, and so on.
 * Each element is widened from {@code float} to {@code double}.
 *
 * @param result the tensor to read; assumed to hold {@code float} elements (the copy is
 *               delegated to {@link Tensor#copyTo}, which presumably rejects other element
 *               types; confirm)
 * @return the tensor's elements as doubles
 * @throws RuntimeException    if the tensor's rank is neither 1 nor 2 (scalars included); the
 *                             message is the tensor's shape
 * @throws ArithmeticException if a dimension or the total element count does not fit in an
 *                             {@code int}
 */
public static double[] getFloatValues(@Nonnull Tensor<?> result) {
  final long[] shape = result.shape();
  // Dispatch on rank before using any dimension, so a scalar (empty shape) reaches the
  // descriptive error below instead of failing on a product over zero dimensions.
  if (shape.length == 1) {
    return widenToDoubles(result.copyTo(new float[Math.toIntExact(shape[0])]));
  } else if (shape.length == 2) {
    final int rows = Math.toIntExact(shape[0]);
    final int cols = Math.toIntExact(shape[1]);
    final float[][] matrix = result.copyTo(new float[rows][cols]);
    final double[] flat = new double[Math.multiplyExact(rows, cols)];
    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        flat[r * cols + c] = matrix[r][c];
      }
    }
    return flat;
  } else {
    throw new RuntimeException(Arrays.toString(shape));
  }
}

/** Widens every element of {@code values} from {@code float} to {@code double}, preserving order. */
private static double[] widenToDoubles(float[] values) {
  final double[] widened = new double[values.length];
  for (int i = 0; i < values.length; i++) {
    widened[i] = values[i];
  }
  return widened;
}
public static String describeGraph(@Nonnull Graph graph) {
OutputStream stringOutputStream = new ByteArrayOutputStream();
try (PrintStream outputStream = new PrintStream(stringOutputStream)) {
for (Iterator iter = graph.operations(); iter.hasNext(); ) {
Operation operation = iter.next();
outputStream.println(String.format("Operation %s (type %s) with %s outputs", operation.name(), operation.type(), operation.numOutputs()));
for (int i = 0; i < operation.numOutputs(); i++) {
Output
© 2015 - 2025 Weber Informatics LLC | Privacy Policy