
com.simiacryptus.tensorflow.GraphModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tensorflow-model Show documentation
Show all versions of tensorflow-model Show documentation
General Utilities for TensorFlow
The newest version!
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.tensorflow;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.primitives.Floats;
import com.google.protobuf.InvalidProtocolBufferException;
import com.simiacryptus.util.Util;
import org.tensorflow.framework.*;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class GraphModel {
/** Graph definition produced by {@link #parseGraph(byte[])} in the constructor. */
@JsonIgnore
public final GraphDef graphDef;
/** Node definitions of {@link #graphDef}, keyed by node name. */
@JsonIgnore
public final Map<String, NodeDef> nodeMap;
/** Cache of {@link GraphNode} wrappers, keyed by the name they were requested under. */
private final ConcurrentHashMap<String, GraphNode> nodeCache = new ConcurrentHashMap<>();
/**
 * Builds a model from a serialized graph definition.
 *
 * @param bytes serialized graph, decoded with {@link #parseGraph(byte[])}
 */
public GraphModel(byte[] bytes) {
  this.graphDef = parseGraph(bytes);
  // Index every node definition by its name for direct lookup.
  this.nodeMap = this.graphDef.getNodeList().stream()
      .collect(Collectors.toMap(NodeDef::getName, node -> node));
}
/**
 * Returns a freshly built map with one {@link GraphNode} wrapper per entry of {@code nodeMap}.
 *
 * @return map from node name to the wrapper obtained from {@link #getChild(String)}
 */
public Map getNodes() {
  return this.nodeMap.keySet().stream()
      .collect(Collectors.toMap(nodeName -> nodeName, this::getChild));
}
/**
 * Lists the names of nodes that no node in the graph names as an input.
 *
 * <p>The set of consumed input keys is collected once, instead of rebuilding the node map and
 * rescanning every node for each candidate.
 *
 * <p>NOTE(review): input keys are compared verbatim against node names; references carrying a
 * {@code ":"} suffix or {@code "^"} prefix are not normalized here - confirm this is intended.
 *
 * @return names of all nodes whose name does not appear among any node's input keys
 */
public List<String> getRootNodes() {
  final Map<String, GraphNode> nodes = getNodes();
  final Set<String> consumed = new HashSet<>();
  for (GraphNode node : nodes.values()) {
    consumed.addAll(node.getInputKeys());
  }
  return nodes.keySet().stream()
      .filter(candidate -> !consumed.contains(candidate))
      .collect(Collectors.toList());
}
/**
 * Decodes a serialized graph definition.
 *
 * @param bytes serialized form accepted by {@code GraphDef.parseFrom}
 * @return the decoded graph definition
 */
public static GraphDef parseGraph(byte[] bytes) {
  try {
    return GraphDef.parseFrom(bytes);
  } catch (InvalidProtocolBufferException parseFailure) {
    // Malformed input is handed to Util.throwException, whose result is thrown.
    throw Util.throwException(parseFailure);
  }
}
/**
 * Widens a float array to a double array.
 *
 * <p>Uses a primitive loop so that elements are not boxed along the way.
 *
 * @param values source values, may be {@code null}
 * @return a new array of the same length, or {@code null} when {@code values} is {@code null}
 */
@Nullable
public static double[] toDoubles(@Nullable float[] values) {
  if (null == values) {
    return null;
  }
  final double[] widened = new double[values.length];
  for (int i = 0; i < values.length; i++) {
    widened[i] = values[i];
  }
  return widened;
}
/**
 * Widens an int array to a double array.
 *
 * @param values source values, may be {@code null}
 * @return a new array of the same length, or {@code null} when {@code values} is {@code null}
 */
@Nullable
public static double[] toDoubles(@Nullable int[] values) {
  if (values == null) {
    return null;
  }
  return Arrays.stream(values).asDoubleStream().toArray();
}
/**
 * Converts a long array to a double array using the primitive long-to-double conversion.
 *
 * @param values source values, may be {@code null}
 * @return a new array of the same length, or {@code null} when {@code values} is {@code null}
 */
@Nullable
public static double[] toDoubles(@Nullable long[] values) {
  if (values == null) {
    return null;
  }
  return Arrays.stream(values).asDoubleStream().toArray();
}
/**
 * Summarizes an array of doubles.
 *
 * <p>The result holds {@code min}, {@code max}, {@code sum} and {@code avg} of the finite
 * values, plus {@code finiteCount}, {@code nanCount} and {@code infCount}; every entry is
 * stored as a double.
 *
 * @param values values to summarize, may be {@code null}
 * @return the statistics keyed by name, or {@code null} when {@code values} is {@code null}
 */
@Nullable
public static HashMap summary(@Nullable double[] values) {
  if (values == null) {
    return null;
  }
  final DoubleSummaryStatistics finite = new DoubleSummaryStatistics();
  long nans = 0;
  long infinities = 0;
  // One pass: each value is exactly one of NaN, infinite, or finite.
  for (final double value : values) {
    if (Double.isNaN(value)) {
      nans++;
    } else if (Double.isInfinite(value)) {
      infinities++;
    } else {
      finite.accept(value);
    }
  }
  final HashMap<String, Double> result = new HashMap<>();
  result.put("min", finite.getMin());
  result.put("max", finite.getMax());
  result.put("sum", finite.getSum());
  result.put("avg", finite.getAverage());
  result.put("finiteCount", (double) finite.getCount());
  result.put("nanCount", (double) nans);
  result.put("infCount", (double) infinities);
  return result;
}
/**
 * Decodes the buffer as 32-bit integers stored least-significant byte first.
 *
 * <p>The element count is {@code limit() / 4} and bytes are consumed with relative reads, so
 * the buffer is expected to be positioned at 0.
 *
 * @param byteBuffer source buffer; its position is advanced past the bytes read
 * @return the decoded values
 */
@Nonnull
public static int[] getInts(@Nonnull ByteBuffer byteBuffer) {
  final int[] decoded = new int[byteBuffer.limit() / 4];
  final byte[] word = new byte[4];
  int index = 0;
  while (index < decoded.length) {
    byteBuffer.get(word);
    decoded[index++] = (word[0] & 0xFF)
        | (word[1] & 0xFF) << 8
        | (word[2] & 0xFF) << 16
        | (word[3] & 0xFF) << 24;
  }
  return decoded;
}
/**
 * Decodes the buffer as 64-bit integers stored least-significant byte first.
 *
 * <p>All eight bytes of each element are assembled; stopping after four bytes would drop the
 * upper 32 bits of every value.
 *
 * <p>The element count is {@code limit() / 8} and bytes are consumed with relative reads, so
 * the buffer is expected to be positioned at 0.
 *
 * @param byteBuffer source buffer; its position is advanced past the bytes read
 * @return the decoded values
 */
@Nonnull
public static long[] getLongs(@Nonnull ByteBuffer byteBuffer) {
  long[] values = new long[byteBuffer.limit() / Long.BYTES];
  byte[] in = new byte[Long.BYTES];
  for (int i = 0; i < values.length; i++) {
    byteBuffer.get(in);
    long value = 0;
    for (int b = 0; b < Long.BYTES; b++) {
      // Byte b contributes bits [8b, 8b + 7].
      value |= (in[b] & 0xFFL) << (8 * b);
    }
    values[i] = value;
  }
  return values;
}
/**
 * Encodes the values into a new heap buffer of exactly {@code values.length * 8} bytes.
 *
 * @param values values to encode
 * @return the filled buffer, flipped so it is ready for reading
 */
@Nonnull
public static ByteBuffer putDoubles(@Nonnull double[] values) {
  final ByteBuffer target = ByteBuffer.allocate(values.length * 8);
  return (ByteBuffer) putDoubles(target, values).flip();
}

/**
 * Appends each value's IEEE-754 bit pattern to the buffer, least-significant byte first.
 *
 * @param byteBuffer destination; written with relative puts of eight bytes per value
 * @param values     values to encode
 * @return {@code byteBuffer}
 */
@Nonnull
public static ByteBuffer putDoubles(@Nonnull ByteBuffer byteBuffer, @Nonnull double[] values) {
  final byte[] encoded = new byte[8];
  for (final double value : values) {
    final long bits = Double.doubleToLongBits(value);
    int shift = 0;
    for (int b = 0; b < encoded.length; b++) {
      encoded[b] = (byte) ((bits >> shift) & 0xFFL);
      shift += 8;
    }
    byteBuffer.put(encoded);
  }
  return byteBuffer;
}
/**
 * Decodes the buffer as 64-bit IEEE-754 doubles stored least-significant byte first.
 *
 * <p>The element count is {@code limit() / 8} and bytes are consumed with relative reads, so
 * the buffer is expected to be positioned at 0.
 *
 * @param byteBuffer source buffer; its position is advanced past the bytes read
 * @return the decoded values
 */
@Nonnull
public static double[] getDoubles(@Nonnull ByteBuffer byteBuffer) {
  final double[] decoded = new double[byteBuffer.limit() / 8];
  final byte[] word = new byte[8];
  for (int index = 0; index < decoded.length; index++) {
    byteBuffer.get(word);
    long bits = 0;
    // Walk from the most significant byte down, shifting earlier bytes up as we go.
    for (int b = word.length - 1; b >= 0; b--) {
      bits = (bits << 8) | (word[b] & 0xFFL);
    }
    decoded[index] = Double.longBitsToDouble(bits);
  }
  return decoded;
}
/**
 * Encodes the values into a new heap buffer of exactly {@code values.length * 4} bytes.
 *
 * <p>A float occupies four bytes; sizing the buffer at eight bytes per value would leave the
 * capacity and backing array twice as large as the encoded data.
 *
 * @param values values to encode
 * @return the filled buffer, flipped so it is ready for reading
 */
@Nonnull
public static ByteBuffer putFloats(@Nonnull float[] values) {
  return (ByteBuffer) putFloats(ByteBuffer.allocate(values.length * Float.BYTES), values).flip();
}

/**
 * Appends each value's IEEE-754 bit pattern to the buffer, least-significant byte first.
 *
 * @param byteBuffer destination; written with relative puts of four bytes per value
 * @param values     values to encode
 * @return {@code byteBuffer}
 */
@Nonnull
public static ByteBuffer putFloats(@Nonnull ByteBuffer byteBuffer, @Nonnull float[] values) {
  final byte[] in = new byte[Float.BYTES];
  for (int i = 0; i < values.length; i++) {
    final int value = Float.floatToIntBits(values[i]);
    for (int b = 0; b < Float.BYTES; b++) {
      // The byte cast keeps only the low eight bits of the shifted pattern.
      in[b] = (byte) (value >> (8 * b));
    }
    byteBuffer.put(in);
  }
  return byteBuffer;
}
/**
 * Decodes the buffer as 32-bit IEEE-754 floats stored least-significant byte first.
 *
 * <p>The element count is {@code limit() / 4} and bytes are consumed with relative reads, so
 * the buffer is expected to be positioned at 0.
 *
 * @param byteBuffer source buffer; its position is advanced past the bytes read
 * @return the decoded values
 */
@Nonnull
public static float[] getFloats(@Nonnull ByteBuffer byteBuffer) {
  final float[] decoded = new float[byteBuffer.limit() / 4];
  final byte[] word = new byte[4];
  int index = 0;
  while (index < decoded.length) {
    byteBuffer.get(word);
    final int bits = (word[0] & 0xFF)
        | (word[1] & 0xFF) << 8
        | (word[2] & 0xFF) << 16
        | (word[3] & 0xFF) << 24;
    decoded[index++] = Float.intBitsToFloat(bits);
  }
  return decoded;
}
/**
 * Compares this model's nodes with another model's, name by name.
 *
 * <p>A name yields a {@link DeltaRecord} when it is present in only one model, or when the two
 * nodes differ in properties, shape, data, input keys, or the set of input names. All five
 * comparisons are always evaluated, even after an earlier one has already failed.
 *
 * @param other model to compare against
 * @return the delta records for differing nodes, keyed by node name
 */
public Map compare(@Nonnull GraphModel other) {
  final Map<String, GraphNode> mine = getNodes();
  final Map<String, GraphNode> theirs = other.getNodes();
  return Stream.concat(mine.keySet().stream(), theirs.keySet().stream())
      .distinct()
      .filter(nodeName -> {
        final GraphNode left = mine.get(nodeName);
        final GraphNode right = theirs.get(nodeName);
        if (left == null || right == null) {
          return true;
        }
        // Evaluate every aspect eagerly, in this order, before combining the results.
        final boolean sameProperties = left.getProperties().equals(right.getProperties());
        final boolean sameShape = Arrays.equals(left.getShape(), right.getShape());
        final boolean sameData = Arrays.equals(left.getData(), right.getData());
        final boolean sameInputKeys = left.getInputKeys().equals(right.getInputKeys());
        final boolean sameInputNames =
            left.getInputs().stream().map(x -> x.name).collect(Collectors.toSet()).equals(
                right.getInputs().stream().map(x -> x.name).collect(Collectors.toSet()));
        return !(sameProperties && sameShape && sameData && sameInputKeys && sameInputNames);
      })
      .map(nodeName -> new DeltaRecord(nodeName, mine.get(nodeName), theirs.get(nodeName)))
      .collect(Collectors.toMap(delta -> delta.name, delta -> delta));
}
/**
 * Returns the cached {@link GraphNode} for the given name, creating and caching it on first
 * request.
 *
 * @param x name to look up; used verbatim as the cache key
 * @return the wrapper for {@code x}
 */
public GraphNode getChild(@Nonnull String x) {
  return nodeCache.computeIfAbsent(x, GraphNode::new);
}
/**
 * Immutable triple of a node name and the two nodes found under that name.
 *
 * <p>{@code compare} creates one of these per differing name, passing {@code null} for a side
 * that has no node of that name.
 */
public static class DeltaRecord {
  /** Name shared by both sides. */
  public final String name;
  /** Node from the left-hand model; may be {@code null}. */
  public final GraphNode left;
  /** Node from the right-hand model; may be {@code null}. */
  public final GraphNode right;

  /**
   * @param name  name shared by both sides
   * @param left  left-hand node, may be {@code null}
   * @param right right-hand node, may be {@code null}
   */
  public DeltaRecord(final String name, final GraphNode left, final GraphNode right) {
    this.name = name;
    this.left = left;
    this.right = right;
  }
}
public class GraphNode {
/** Name this node was requested under; may still carry a {@code "^"} prefix or {@code ":"} suffix. */
public final String name;
// The four caches below are read outside the lock by their double-checked getters, so all of
// them are volatile.
/** Lazily resolved definition; stays {@code null} when the graph has no matching node. */
@Nullable
@JsonIgnore
private transient volatile NodeDef nodeDef = null;
/** Lazily resolved operation name. */
@Nullable
private transient volatile String op = null;
/** Lazily computed result of {@code getOrder()}. */
@Nullable
private volatile Integer order = null;
/** Lazily computed result of {@code getRootInputs()}. */
@Nullable
private volatile List<GraphNode> rootInputs = null;

/**
 * @param name name to wrap; stored verbatim
 */
protected GraphNode(String name) {
  this.name = name;
}
/**
 * Returns this node's tensor values converted to doubles.
 *
 * <p>Supports {@code DT_DOUBLE}, {@code DT_FLOAT}, {@code DT_INT32} and {@code DT_INT64}.
 *
 * @return the values, or {@code null} when the data type is {@code UNRECOGNIZED} or the typed
 *     getter returns {@code null}
 * @throws RuntimeException for any other data type
 */
@Nullable
@JsonIgnore
public double[] getData() {
  final DataType dataType = getDataType();
  switch (dataType) {
    case DT_DOUBLE:
      return getDoubles();
    case DT_FLOAT:
      return toDoubles(getFloats());
    case DT_INT32:
      return toDoubles(getInts());
    case DT_INT64:
      return toDoubles(getLongs());
    case UNRECOGNIZED:
      return null;
    default:
      throw new RuntimeException("Unsupported Data Type: " + dataType);
  }
}
/**
 * Produces a compact description of this node's data.
 *
 * @return {@code ""} for string tensors or when there is no data; the raw values when there
 *     are fewer than 32 of them; otherwise the map built by {@code summary}
 */
@Nullable
public Object getDataSummary() {
  if (DataType.DT_STRING == getDataType()) {
    return "";
  }
  final double[] values = getData();
  if (values == null || values.length == 0) {
    return "";
  }
  // Short arrays are returned as-is; longer ones are reduced to statistics.
  return values.length < 32 ? values : summary(values);
}
/**
 * Resolves this node's data type from its attributes.
 *
 * <p>The {@code "dtype"} attribute is consulted first, then the dtype of the tensor in the
 * {@code "value"} attribute. A node with no definition in the graph is reported as
 * {@code UNRECOGNIZED} rather than failing with a {@code NullPointerException}, in line with
 * how {@code getInputs()} and {@code getOp()} treat a missing definition.
 *
 * @return the resolved type, or {@code DataType.UNRECOGNIZED} when none can be determined
 */
@JsonIgnore
@Nonnull
public final DataType getDataType() {
  final NodeDef nodeDef = getNodeDef();
  if (null == nodeDef) {
    return DataType.UNRECOGNIZED;
  }
  final Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
  DataType type = null;
  final AttrValue dtypeAttr = attrMap.get("dtype");
  if (null != dtypeAttr) {
    type = dtypeAttr.getType();
  }
  if (null == type) {
    final AttrValue valueAttr = attrMap.get("value");
    if (null != valueAttr) {
      type = valueAttr.getTensor().getDtype();
    }
  }
  return null == type ? DataType.UNRECOGNIZED : type;
}
/**
 * Returns the data type's string form.
 *
 * @return {@code toString()} of the data type, or {@code null} when it is {@code UNRECOGNIZED}
 */
@Nullable
public String getDataTypeString() {
  final DataType resolved = getDataType();
  return DataType.UNRECOGNIZED == resolved ? null : resolved.toString();
}
/**
 * Decodes the tensor content of this node's {@code "value"} attribute as doubles.
 *
 * @return the decoded values, or {@code null} when the node has no tensor
 */
@Nullable
@JsonIgnore
public double[] getDoubles() {
  final TensorProto valueTensor = getTensor();
  if (valueTensor == null) {
    return null;
  }
  return GraphModel.getDoubles(valueTensor.getTensorContent().asReadOnlyByteBuffer());
}
/**
 * Decodes the tensor content of this node's {@code "value"} attribute as floats.
 *
 * @return the decoded values, or {@code null} when the node has no tensor
 */
@Nullable
@JsonIgnore
public float[] getFloats() {
  final TensorProto valueTensor = getTensor();
  if (valueTensor == null) {
    return null;
  }
  return GraphModel.getFloats(valueTensor.getTensorContent().asReadOnlyByteBuffer());
}
/**
 * Lists the names of this node's inputs, in input order.
 *
 * @return a new list holding the {@code name} of every node returned by {@code getInputs()}
 */
public List getInputKeys() {
  final List<GraphNode> inputNodes = getInputs();
  final List<String> keys = new ArrayList<>();
  for (GraphNode inputNode : inputNodes) {
    keys.add(inputNode.name);
  }
  return keys;
}
/**
 * Resolves this node's declared inputs to {@link GraphNode} wrappers.
 *
 * @return one wrapper per entry of the definition's input list, in order; an empty list when
 *     this node has no definition
 */
@JsonIgnore
public List getInputs() {
  final NodeDef definition = this.getNodeDef();
  if (definition == null) {
    return Collections.emptyList();
  }
  final List<GraphNode> resolved = new ArrayList<>();
  for (String inputName : definition.getInputList()) {
    resolved.add(getChild(inputName));
  }
  return resolved;
}
/**
 * Decodes the tensor content of this node's {@code "value"} attribute as 32-bit integers.
 *
 * @return the decoded values, or {@code null} when the node has no tensor
 */
@Nullable
@JsonIgnore
public int[] getInts() {
  final TensorProto valueTensor = getTensor();
  if (valueTensor == null) {
    return null;
  }
  return GraphModel.getInts(valueTensor.getTensorContent().asReadOnlyByteBuffer());
}
/**
 * Decodes the tensor content of this node's {@code "value"} attribute as 64-bit integers.
 *
 * @return the decoded values, or {@code null} when the node has no tensor
 */
@Nullable
@JsonIgnore
public long[] getLongs() {
  final TensorProto valueTensor = getTensor();
  if (valueTensor == null) {
    return null;
  }
  return GraphModel.getLongs(valueTensor.getTensorContent().asReadOnlyByteBuffer());
}
/**
 * Returns this node's definition, looking it up in {@code nodeMap} under the normalized name
 * on first use and caching a successful lookup.
 *
 * @return the definition, or {@code null} when {@code nodeMap} has no entry for the
 *     normalized name
 */
@Nullable
@JsonIgnore
public NodeDef getNodeDef() {
  NodeDef resolved = this.nodeDef;
  if (resolved == null) {
    synchronized (this) {
      resolved = this.nodeDef;
      if (resolved == null) {
        // A failed lookup stores null again, so it is retried on the next call.
        resolved = nodeMap.get(normalizeKey());
        this.nodeDef = resolved;
      }
    }
  }
  return resolved;
}
/**
 * Returns the operation name of this node's definition, cached after the first successful
 * lookup.
 *
 * @return the operation name, or the literal string {@code "null"} (not cached) when this
 *     node has no definition
 */
@Nullable
public String getOp() {
  String resolved = this.op;
  if (resolved == null) {
    synchronized (this) {
      resolved = this.op;
      if (resolved == null) {
        final NodeDef definition = getNodeDef();
        if (definition == null) {
          return "null";
        }
        resolved = definition.getOp();
        this.op = resolved;
      }
    }
  }
  return resolved;
}
/**
 * Returns this node's depth in the graph: 0 for a node without inputs, otherwise one more
 * than the largest order among its inputs. The value is computed once and cached.
 *
 * @return the cached depth
 */
public int getOrder() {
  Integer resolved = order;
  if (resolved == null) {
    synchronized (this) {
      resolved = order;
      if (resolved == null) {
        final List<GraphNode> inputNodes = getInputs();
        int deepest = -1;
        for (GraphNode inputNode : inputNodes) {
          deepest = Math.max(deepest, inputNode.getOrder());
        }
        resolved = deepest + 1;
        order = resolved;
      }
    }
  }
  return resolved;
}
/**
 * Renders this node's attributes as strings.
 *
 * <p>Each attribute maps to its {@code toString()} form; the {@code "value"} attribute is
 * truncated to at most 1024 characters. A node with no definition in the graph yields an
 * empty map rather than a {@code NullPointerException}, in line with how {@code getInputs()}
 * and {@code getOp()} treat a missing definition.
 *
 * @return attribute name to rendered value
 */
public Map<String, String> getProperties() {
  final NodeDef nodeDef = getNodeDef();
  if (null == nodeDef) {
    return Collections.emptyMap();
  }
  final Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
  return attrMap.entrySet().stream()
      .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
        final String rendered = entry.getValue().toString();
        if (entry.getKey().equals("value")) {
          // Tensor payloads can be very large; keep only a bounded prefix.
          return rendered.substring(0, Math.min(rendered.length(), 1024));
        }
        return rendered;
      }));
}
/**
 * Returns the distinct input-less {@code "Placeholder"} nodes reachable by following inputs
 * upstream from this node. The result is computed once and cached.
 *
 * <p>A node without inputs returns itself when its op is {@code "Placeholder"} and an empty
 * list otherwise; any other node returns the de-duplicated concatenation of its inputs' root
 * inputs, in first-encounter order.
 *
 * @return the cached list of root inputs
 */
@Nullable
@JsonIgnore
public List getRootInputs() {
  List<GraphNode> resolved = rootInputs;
  if (resolved == null) {
    synchronized (this) {
      resolved = rootInputs;
      if (resolved == null) {
        final List<GraphNode> inputNodes = getInputs();
        if (!inputNodes.isEmpty()) {
          // An insertion-ordered set keeps only the first occurrence of each root.
          final LinkedHashSet<GraphNode> roots = new LinkedHashSet<>();
          for (GraphNode inputNode : inputNodes) {
            roots.addAll(inputNode.getRootInputs());
          }
          resolved = new ArrayList<>(roots);
        } else if (getOp().equals("Placeholder")) {
          resolved = Arrays.asList(this);
        } else {
          resolved = Arrays.asList();
        }
        rootInputs = resolved;
      }
    }
  }
  return resolved;
}
/**
 * Lists the names of this node's root inputs.
 *
 * @return a new list holding the {@code name} of every node returned by
 *     {@code getRootInputs()}, in the same order
 */
public List getRootKeys() {
  final List<GraphNode> roots = getRootInputs();
  final List<String> keys = new ArrayList<>();
  for (GraphNode root : roots) {
    keys.add(root.name);
  }
  return keys;
}
/**
 * Returns this node's dimensions, taken from the {@code "shape"} attribute or, failing that,
 * from the shape of the tensor in the {@code "value"} attribute.
 *
 * <p>A node with no definition in the graph yields {@code null} rather than a
 * {@code NullPointerException}, in line with how {@code getInputs()} and {@code getOp()}
 * treat a missing definition.
 *
 * @return the size of each dimension, or {@code null} when neither attribute is present
 */
@Nullable
public long[] getShape() {
  final NodeDef nodeDef = getNodeDef();
  if (null == nodeDef) {
    return null;
  }
  final Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
  final TensorShapeProto shape;
  if (attrMap.containsKey("shape")) {
    shape = attrMap.get("shape").getShape();
  } else if (attrMap.containsKey("value")) {
    shape = attrMap.get("value").getTensor().getTensorShape();
  } else {
    return null;
  }
  return shape.getDimList().stream().mapToLong(dim -> dim.getSize()).toArray();
}
/**
 * Returns the tensor stored in this node's {@code "value"} attribute.
 *
 * <p>A node with no definition in the graph yields {@code null} rather than a
 * {@code NullPointerException}, in line with how {@code getInputs()} and {@code getOp()}
 * treat a missing definition.
 *
 * @return the tensor, or {@code null} when the node has no {@code "value"} attribute
 */
@Nullable
@JsonIgnore
public TensorProto getTensor() {
  final NodeDef nodeDef = getNodeDef();
  if (null == nodeDef) {
    return null;
  }
  final Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
  final AttrValue value = attrMap.get("value");
  return null == value ? null : value.getTensor();
}
/**
 * Describes this node by name, op, data type, shape and input names.
 *
 * <p>The shape is rendered with {@code Arrays.toString}, because concatenating a
 * {@code long[]} directly prints only its identity hash. Inputs are listed by name: rendering
 * the input nodes themselves would recursively expand the entire upstream graph.
 *
 * @return a single-line description of this node
 */
@Nonnull
@Override
public String toString() {
  return "GraphNode{" +
      "name='" + name + '\'' +
      ", op=" + getOp() +
      ", dataType=" + getDataType() +
      ", shape=" + Arrays.toString(getShape()) +
      ", inputs=" + getInputKeys() +
      '}';
}
/**
 * Builds a graph definition for the part of the graph upstream of this node, cut at the
 * given names.
 *
 * <p>The result contains one {@code "Placeholder"} node with a {@code DT_FLOAT}
 * {@code "dtype"} attribute per name in {@code inputs}, followed by the definitions of the
 * nodes returned by {@code subgraphNodes(inputs)}.
 *
 * @param inputs names at which the graph is cut and replaced by placeholders
 * @return the assembled graph definition
 */
@Nonnull
public GraphDef subgraph(@Nonnull Set inputs) {
  final GraphDef.Builder graph = GraphDef.newBuilder();
  final AttrValue floatType = AttrValue.newBuilder().setType(DataType.DT_FLOAT).build();
  final Set<String> placeholderNames = inputs;
  for (String placeholderName : placeholderNames) {
    final NodeDef placeholder = NodeDef.newBuilder()
        .setName(placeholderName)
        .setOp("Placeholder")
        .putAttr("dtype", floatType)
        .build();
    graph.addNode(placeholder);
  }
  final List<GraphNode> retained = subgraphNodes(inputs);
  for (GraphNode node : retained) {
    graph.addNode(node.getNodeDef());
  }
  return graph.build();
}
/**
 * Collects this node and everything upstream of it, stopping at the given names.
 *
 * @param inputs names at which traversal stops; a node whose name is in this set contributes
 *     nothing
 * @return an empty list when this node's name is in {@code inputs}; otherwise this node
 *     followed by the distinct nodes gathered from its inputs, in first-encounter order
 */
@JsonIgnore
public List subgraphNodes(@Nonnull Set inputs) {
  if (inputs.contains(name)) {
    return Arrays.asList();
  }
  final List<GraphNode> upstream = getInputs();
  if (upstream.isEmpty()) {
    return Arrays.asList(this);
  }
  // An insertion-ordered set yields "this first, then first occurrences" in one structure.
  final Set<GraphNode> collected = new LinkedHashSet<>();
  collected.add(this);
  for (GraphNode inputNode : upstream) {
    collected.addAll(inputNode.subgraphNodes(inputs));
  }
  return new ArrayList<>(collected);
}
/**
 * Two nodes are equal when they have the same runtime class and the same {@code name}.
 *
 * @param o object to compare with
 * @return {@code true} when {@code o} is a node of the same class with an equal name
 */
@Override
public boolean equals(@Nullable Object o) {
  if (o == this) {
    return true;
  }
  if (o == null) {
    return false;
  }
  if (!getClass().equals(o.getClass())) {
    return false;
  }
  final GraphNode that = (GraphNode) o;
  return Objects.equals(this.name, that.name);
}

/**
 * Hash code derived solely from {@code name}, consistent with {@link #equals(Object)}.
 *
 * @return the hash code
 */
@Override
public int hashCode() {
  // Same value as Objects.hash(name): 31 * 1 + hash of the single element.
  return 31 + Objects.hashCode(name);
}
/**
 * Reduces {@code name} to the key used for lookups in {@code nodeMap}: everything from the
 * first {@code ':'} onward is dropped, then any leading {@code '^'} characters are removed.
 *
 * <p>Uses {@code indexOf} rather than {@code split(":")[0]}, which throws
 * {@code ArrayIndexOutOfBoundsException} for a name made only of colons.
 *
 * @return the normalized key; possibly empty
 */
private String normalizeKey() {
  String key = this.name;
  final int colon = key.indexOf(':');
  if (colon >= 0) {
    key = key.substring(0, colon);
  }
  int start = 0;
  while (start < key.length() && key.charAt(start) == '^') {
    start++;
  }
  return key.substring(start);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy