/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.imports.graphmapper.tf;
import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.imports.graphmapper.OpImportFilter;
import org.nd4j.imports.graphmapper.OpImportOverride;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.*;
import java.io.*;
import java.nio.ByteOrder;
import java.util.*;
/**
 * Maps TensorFlow graph protos
 * to the SameDiff intermediate representation
 * for import.
 *
 * @author Adam Gibson
 */
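/*
 * Typical usage (a minimal sketch, assuming a frozen TensorFlow GraphDef on disk and the
 * importGraph(File) overload inherited from BaseGraphMapper):
 *
 * SameDiff sd = TFGraphMapper.getInstance().importGraph(new File("frozen_model.pb"));
 */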
@Slf4j
public class TFGraphMapper extends BaseGraphMapper {
private Set<String> seenNodes = new LinkedHashSet<>();
public final static String VALUE_ATTR_KEY = "value";
public final static String SHAPE_KEY = "shape";
private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper();
private Set<String> graphMapper = new HashSet<String>(){{
//While and If
//While -> Enter
/**
* Need to work on coping with variables
* that are marked as "shouldSkip"
*
* Possibly consider replacing should skip
* with a special handler interface. Something like
*
* public interface ImportOpHandler
*/
add("LoopCond");
/**
* We should skip this for the sake of while..but not if.
* Need to be a bit more flexible here.
*/
add("Merge");
add("Exit");
add("NextIteration");
add("NoOp");
add("Switch");
}};
//singleton
private TFGraphMapper() {}
/**
 * Singleton. Get the needed instance.
 * @return the singleton TFGraphMapper instance
 */
public static TFGraphMapper getInstance() {
return MAPPER_INSTANCE;
}
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
try {
GraphDef graphDef = GraphDef.parseFrom(inputFile);
try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
for (NodeDef node : graphDef.getNodeList()) {
bufferedWriter.write(node.toString());
}
bufferedWriter.flush();
}
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public boolean isOpIgnoreException(NodeDef node) {
//if statements should not be ignored
/*
if(node.getOp().equals("Merge")) {
boolean ret = false;
for(int i = 0; i < node.getInputCount(); i++) {
//while loop
ret = ret || !node.getInput(i).endsWith("/Enter") || !node.getInput(i).endsWith("/NextIteration");
}
return ret;
}
else if(node.getOp().equals("Switch")) {
boolean ret = false;
for(int i = 0; i < node.getInputCount(); i++) {
//while loop
ret = ret || !node.getInput(i).endsWith("/Merge") || !node.getInput(i).endsWith("/LoopCond");
}
return ret;
}
*/
return true;
}
@Override
public String getTargetMappingForOp(DifferentialFunction function, NodeDef node) {
return function.opName();
}
@Override
public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
for(int i = 0; i < graph.getNodeCount(); i++) {
val node = graph.getNode(i);
if(node.getName().equals(name))
return node;
}
return null;
}
@Override
public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
if(node == null) {
throw new ND4JIllegalStateException("No node found for name " + name);
}
val mapping = propertyMappingsForFunction.get(getOpType(node)).get(name);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
if(mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) {
int tfMappingIdx = mapping.getTfInputPosition();
if(tfMappingIdx < 0)
tfMappingIdx += node.getInputCount();
val input = node.getInput(tfMappingIdx);
val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,input);
INDArray arr = getArrayFrom(inputNode,graph);
if(arr == null && sameDiff.hasVariable(input)) {
arr = sameDiff.getArrForVarName(input);
}
if(arr == null && inputNode != null) {
sameDiff.addPropertyToResolve(on,name);
sameDiff.addVariableMappingForField(on,name,getNodeName(inputNode.getName()));
return;
} else if(inputNode == null) {
//TODO need to do anything here given new design?
//sameDiff.addAsPlaceHolder(input);
return;
}
val field = fields.get(name);
val type = field.getType();
if(type.equals(int[].class)) {
on.setValueFor(field,arr.data().asInt());
}
else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
if(mapping.getShapePosition() != null) {
on.setValueFor(field,arr.size(mapping.getShapePosition()));
}
else
on.setValueFor(field,arr.getInt(0));
}
else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
on.setValueFor(field,arr.getDouble(0));
}
}
else {
val tfMappingAttrName = mapping.getTfAttrName();
if(tfMappingAttrName == null) {
return;
}
if(!node.containsAttr(tfMappingAttrName)) {
return;
}
val attr = node.getAttrOrThrow(tfMappingAttrName);
val type = attr.getType();
if(fields == null) {
throw new ND4JIllegalStateException("No fields found for op [" + mapping + "]");
}
if(mapping.getPropertyNames() == null) {
throw new ND4JIllegalStateException("no property found for [" + name + "] in op [" + on.opName()+"]");
}
val field = fields.get(mapping.getPropertyNames()[0]);
Object valueToSet = null;
switch(type) {
case DT_BOOL:
valueToSet = attr.getB();
break;
case DT_INT8:
valueToSet = attr.getI();
break;
case DT_INT16:
valueToSet = attr.getI();
break;
case DT_INT32:
valueToSet = attr.getI();
break;
case DT_FLOAT:
valueToSet = attr.getF();
break;
case DT_DOUBLE:
valueToSet = attr.getF();
break;
case DT_STRING:
valueToSet = attr.getS();
break;
case DT_INT64:
valueToSet = attr.getI();
break;
}
if(field != null && valueToSet != null)
on.setValueFor(field,valueToSet);
}
}
/**
* {@inheritDoc}
*/
@Override
public boolean isPlaceHolderNode(NodeDef node) {
return node.getOp().startsWith("Placeholder");
}
/**
* {@inheritDoc}
*/
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
try (InputStream inputStream = new BufferedInputStream(new FileInputStream(inputFile));
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
GraphDef graphDef = GraphDef.parseFrom(inputStream);
for (NodeDef node : graphDef.getNodeList()) {
bufferedWriter.write(node.toString());
}
bufferedWriter.flush();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public long[] getShapeFromAttr(AttrValue attr) {
return shapeFromShapeProto(attr.getShape());
}
@Override
public Map<String, AttrValue> getAttrMap(NodeDef nodeDef) {
return nodeDef.getAttrMap();
}
@Override
public String getName(NodeDef nodeDef) {
return nodeDef.getName();
}
@Override
public boolean alreadySeen(NodeDef nodeDef) {
return seenNodes.contains(nodeDef.getName());
}
@Override
public boolean isVariableNode(NodeDef nodeDef) {
boolean isVar = nodeDef.getOp().startsWith("VariableV") || nodeDef.getOp().equalsIgnoreCase("const");
return isVar;
}
@Override
public boolean shouldSkip(NodeDef opType) {
if(opType == null)
return true;
boolean endsWithRead = opType.getName().endsWith("/read");
return endsWithRead;
}
@Override
public boolean hasShape(NodeDef nodeDef) {
return nodeDef.containsAttr(SHAPE_KEY);
}
@Override
public long[] getShape(NodeDef nodeDef) {
return getShapeFromAttr(nodeDef.getAttrOrThrow(SHAPE_KEY));
}
@Override
public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) {
if(nodeDef == null) {
return null;
}
return getNDArrayFromTensor(nodeDef.getName(),nodeDef, graph);
}
@Override
public String getOpType(NodeDef nodeDef) {
return nodeDef.getOp();
}
/**
 * Get the node list from the given graph
 * @param graphDef the graph to get the nodes from
 * @return the graph's node list
 */
@Override
public List<NodeDef> getNodeList(GraphDef graphDef) {
return graphDef.getNodeList();
}
/**
 * Get the samediff op for the given op name
 * @param name the tensorflow or onnx name
 * @return the mapped {@link DifferentialFunction}
 */
@Override
public DifferentialFunction getMappedOp(String name) {
return DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(name);
}
/**
 * Map a tensorflow node name
 * to the samediff equivalent
 * for import
 * @param name the tensorflow name to normalize
 * @return the normalized samediff name
 */
public String getNodeName(String name) {
//strip the control-dependency prefix ("^"), the "/read" suffix, and the ":0" output index that tensorflow appends
String ret = name;
if(ret.startsWith("^"))
ret = ret.substring(1);
if(ret.endsWith("/read")) {
ret = ret.replace("/read","");
}
if(ret.endsWith(":0")){
ret = ret.substring(0, ret.length()-2);
}
return ret;
}
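//e.g. getNodeName("^foo") -> "foo", getNodeName("bar/read") -> "bar", getNodeName("baz:0") -> "baz"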
public boolean isControlDependency(String name){
return name.startsWith("^");
}
@Override
public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
Map<String, NodeDef> ret = new LinkedHashMap<>();
List<NodeDef> nodeList = graphDef.getNodeList();
for(NodeDef nodeDef : nodeList) {
if(nodeDef.getName().endsWith("/read")) {
continue;
}
val name = translateToSameDiffName(nodeDef.getName(), nodeDef);
ret.put(name,nodeDef);
}
return ret;
}
@Override
public String translateToSameDiffName(String name, NodeDef node) {
if(isVariableNode(node) || isPlaceHolder(node)) {
return name;
}
StringBuilder stringBuilder = new StringBuilder();
//strip arg number
if(name.contains(":")) {
name = name.substring(0,name.lastIndexOf(':'));
stringBuilder.append(name);
}
else {
stringBuilder.append(name);
}
return stringBuilder.toString();
}
//Strip the variable suffix to give the node name: "Unique:1" -> "Unique"
public String varNameToOpName(String varName){
int idx = varName.lastIndexOf(':');
if(idx < 0)
return varName;
return varName.substring(0, idx);
}
public static int varNameToOpOutputNumber(String varName){
int idx = varName.lastIndexOf(':');
if(idx < 0)
return 0;
String n = varName.substring(idx+1);
return Integer.parseInt(n);
}
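//e.g. varNameToOpName("Unique:1") -> "Unique" and varNameToOpOutputNumber("Unique:1") -> 1;
//a name without a ":" suffix maps to itself, with output number 0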
@Override
public Message.Builder getNewGraphBuilder() {
return GraphDef.newBuilder();
}
@Override
public GraphDef parseGraphFrom(byte[] inputStream) throws IOException {
return GraphDef.parseFrom(inputStream);
}
@Override
public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
return GraphDef.parseFrom(inputStream);
}
protected void importCondition(String conditionName, NodeDef tfNode, ImportState importState) {
/**
* Cond structure:
*
*/
}
@Override
public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState,
OpImportOverride<GraphDef, NodeDef, AttrValue> importOverride,
OpImportFilter<GraphDef, NodeDef, AttrValue> opFilter) {
if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) {
return;
}
SameDiff diff = importState.getSameDiff();
if (isVariableNode(tfNode)) {
List<Long> dimensions = new ArrayList<>();
Map<String, AttrValue> attributes = getAttrMap(tfNode);
if (attributes.containsKey(VALUE_ATTR_KEY)) {
diff.var(getName(tfNode),getArrayFrom(tfNode,importState.getGraph()));
}
else if (attributes.containsKey(SHAPE_KEY)) {
AttrValue shape = attributes.get(SHAPE_KEY);
long[] shapeArr = getShapeFromAttr(shape);
int dims = shapeArr.length;
if (dims > 0) {
// even vector is 2d in nd4j
if (dims == 1)
dimensions.add(1L);
for (int e = 0; e < dims; e++) {
// TODO: eventually we want long shapes :(
dimensions.add(getShapeFromAttr(shape)[e]);
}
}
}
}
else if(isPlaceHolder(tfNode)) {
SDVariable var = diff.getVariable(getName(tfNode));
Preconditions.checkState(var.isPlaceHolder(), "Variable should be marked as placeholder at this point: %s", var);
} else {
val opName = tfNode.getOp();
if(importOverride != null){
//First, get inputs:
int numInputs = tfNode.getInputCount();
List<SDVariable> inputs = new ArrayList<>(numInputs);
List<SDVariable> controlDeps = null;
for (int i = 0; i < numInputs; i++) {
String inName = tfNode.getInput(i);
boolean controlDep = isControlDependency(inName);
String name = getNodeName(inName);
SDVariable v = diff.getVariable(name);
if (v == null) {
//Check 'op skip' edge case
boolean shouldSkip = false;
if(opFilter != null){
//Get the input node
List<NodeDef> l = importState.getGraph().getNodeList();
NodeDef inputNodeDef = null;
for(NodeDef nd : l){
if(inName.equals(nd.getName())){
inputNodeDef = nd;
break;
}
}
Preconditions.checkState(inputNodeDef != null, "Could not find node with name \"%s\"", inName);
shouldSkip = true;
}
if(!shouldSkip) {
//First: try to work out the datatype of this input node
//Given we haven't already imported it at this point, it must be the 2nd or later output of an op
String inputOpName = varNameToOpName(inName);
NodeDef inputOp = importState.getVariables().get(inputOpName);
int outputIdx = varNameToOpOutputNumber(name);
org.nd4j.linalg.api.buffer.DataType dt = dataTypeForTensor(inputOp, outputIdx);
if (dt == org.nd4j.linalg.api.buffer.DataType.UNKNOWN)
dt = null; //Infer it later
v = diff.var(name, VariableType.ARRAY, null, dt, (long[]) null);
}
}
if(controlDep){
if(controlDeps == null)
controlDeps = new ArrayList<>();
controlDeps.add(v);
} else {
inputs.add(v);
}
}
log.info("Importing op {} using override {}", opName, importOverride);
importOverride.initFromTensorFlow(inputs, controlDeps, tfNode, diff, getAttrMap(tfNode), importState.getGraph());
} else {
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
if (differentialFunction == null) {
throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
}
try {
DifferentialFunction newInstance = differentialFunction.getClass().newInstance();
List<SDVariable> args = new ArrayList<>();
List<String> controlDeps = null;
newInstance.setOwnName(tfNode.getName());
int x = 0;
for (int i = 0; i < tfNode.getInputCount(); i++) {
String inName = tfNode.getInput(i);
String inputOpName = varNameToOpName(inName);
NodeDef inputNode = importState.getVariables().get(inputOpName);
if (shouldSkip(inputNode) && !inName.endsWith("/read"))
continue;
boolean controlDep = isControlDependency(inName);
String name = getNodeName(inName);
SDVariable v = diff.getVariable(name);
//At this point, all placeholders, variables and constants should have been imported
//This means: this should be an array type variable (i.e., activations)
if (v == null) {
//First: try to work out the datatype of this input node
//Given we haven't already imported it at this point, it must be the 2nd or later output of an op
NodeDef inputOp = importState.getVariables().get(inputOpName);
int outputIdx = varNameToOpOutputNumber(name);
org.nd4j.linalg.api.buffer.DataType dt = dataTypeForTensor(inputOp, outputIdx);
if (dt == org.nd4j.linalg.api.buffer.DataType.UNKNOWN)
dt = null; //Infer it later
v = diff.var(name, VariableType.ARRAY, null, dt, (long[]) null);
}
if (controlDep) {
//Is only a control dependency input to op, not a real data input
if (controlDeps == null)
controlDeps = new ArrayList<>();
if (!controlDeps.contains(name))
controlDeps.add(name);
} else {
//Is a standard/"real" op input
args.add(v);
}
}
diff.addArgsFor(args.toArray(new SDVariable[args.size()]), newInstance);
newInstance.setSameDiff(importState.getSameDiff());
if (controlDeps != null) {
SameDiffOp op = diff.getOps().get(newInstance.getOwnName());
op.setControlDeps(controlDeps);
//Also record this on the variables:
for (String s : controlDeps) {
Variable v = diff.getVariables().get(s);
if (v.getControlDepsForOp() == null)
v.setControlDepsForOp(new ArrayList<String>());
List<String> l = v.getControlDepsForOp();
if (!l.contains(op.getName()))
l.add(op.getName());
}
}
newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
//ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
} catch (Exception e) {
log.error("Failed to import op [{}]", opName);
throw new RuntimeException(e);
}
}
}
}
/**
 * Calls {@link #initFunctionFromProperties(String, DifferentialFunction, Map, NodeDef, GraphDef)}
 * using {@link DifferentialFunction#tensorflowName()}
 * @param on the function to initialize
 * @param attributesForNode the attributes for the node
 * @param node the node to map from
 * @param graph the graph the node is a part of
 */
public void initFunctionFromProperties(DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
initFunctionFromProperties(on.tensorflowName(),on,attributesForNode,node,graph);
}
/**
 * Init a function's attributes
 * @param mappedTfName the tensorflow op name to pick (some ops have multiple names)
 * @param on the function to map
 * @param attributesForNode the attributes for the node
 * @param node the node to map from
 * @param graph the graph the node is a part of
 */
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
val properties = on.mappingsForFunction();
val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
val attributeAdapters = on.attributeAdaptersForFunction();
// if there's no properties announced for this function - just return
if (tfProperties == null)
return;
//Can't execute in just any order: sometimes there are dependencies between attribute mappings
//For example, conv2d strides depend on data format -> need to map data format before mapping strides
//Solution: map nodes without adapters before nodes with adapters. This doesn't guarantee we'll always be
// mapping in the right order (for example, we might have adapter(x) depends on adapter(y)) but it should catch most cases
Map<String, PropertyMapping> map;
if(attributeAdapters == null || !attributeAdapters.containsKey(mappedTfName)) {
map = tfProperties;
} else {
map = new LinkedHashMap<>();
for (Map.Entry<String, PropertyMapping> e : tfProperties.entrySet()) {
if (!attributeAdapters.get(mappedTfName).containsKey(e.getKey())) {
//No adapter for this attribute
map.put(e.getKey(), e.getValue());
}
}
for (Map.Entry<String, PropertyMapping> e : tfProperties.entrySet()) {
if (!map.containsKey(e.getKey())) {
//Not added on first pass -> must have attribute mapper
map.put(e.getKey(), e.getValue());
}
}
}
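//Illustration of the two passes, using the conv2d example above: a "dataFormat"-style property
//with no adapter is added (and later mapped) in the first pass; a "strides"-style property whose
//adapter depends on the data format is added in the second pass. (Property names here are illustrative.)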
for(Map.Entry<String, PropertyMapping> entry : map.entrySet()){
val tfAttrName = entry.getValue().getTfAttrName();
val currentField = fields.get(entry.getKey());
AttributeAdapter adapter = null;
if(attributeAdapters != null && !attributeAdapters.isEmpty()) {
val mappers = attributeAdapters.get(mappedTfName);
val adapterFor = mappers.get(entry.getKey());
adapter = adapterFor;
}
if(tfAttrName != null) {
if(currentField == null) {
continue;
}
if(attributesForNode.containsKey(tfAttrName)) {
val attr = attributesForNode.get(tfAttrName);
switch (attr.getValueCase()) {
case B:
if (adapter != null) {
adapter.mapAttributeFor(attr.getB(), currentField, on);
}
break;
case F: break;
case FUNC: break;
case S:
val setString = attr.getS().toStringUtf8();
if(adapter != null) {
adapter.mapAttributeFor(setString,currentField,on);
}
else
on.setValueFor(currentField,setString);
break;
case I:
val setInt = (int) attr.getI();
if(adapter != null) {
adapter.mapAttributeFor(setInt,currentField,on);
}
else
on.setValueFor(currentField,setInt);
break;
case SHAPE:
val shape = attr.getShape().getDimList();
int[] dimsToSet = new int[shape.size()];
for(int i = 0; i < dimsToSet.length; i++) {
dimsToSet[i] = (int) shape.get(i).getSize();
}
if(adapter != null) {
adapter.mapAttributeFor(dimsToSet,currentField,on);
}
else
on.setValueFor(currentField,dimsToSet);
break;
case VALUE_NOT_SET:break;
case PLACEHOLDER: break;
case LIST:
val setList = attr.getList();
if(!setList.getIList().isEmpty()) {
val intList = Ints.toArray(setList.getIList());
if(adapter != null) {
adapter.mapAttributeFor(intList,currentField,on);
}
else
on.setValueFor(currentField,intList);
}
else if(!setList.getBList().isEmpty()) {
break;
}
else if(!setList.getFList().isEmpty()) {
val floats = Floats.toArray(setList.getFList());
if(adapter != null) {
adapter.mapAttributeFor(floats,currentField,on);
}
else
on.setValueFor(currentField,floats);
break;
}
else if(!setList.getFuncList().isEmpty()) {
break;
}
else if(!setList.getTensorList().isEmpty()) {
break;
}
break;
case TENSOR:
val tensorToGet = TFGraphMapper.getInstance().mapTensorProto(attr.getTensor());
if(adapter != null) {
adapter.mapAttributeFor(tensorToGet,currentField,on);
}
else
on.setValueFor(currentField,tensorToGet);
break;
case TYPE:
if (adapter != null) {
adapter.mapAttributeFor(attr.getType(), currentField, on);
}
break;
}
}
}
else if(entry.getValue().getTfInputPosition() != null) {
int position = entry.getValue().getTfInputPosition();
if(position < 0) {
position += node.getInputCount();
}
val inputFromNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,node.getInput(position));
INDArray tensor = inputFromNode != null ? TFGraphMapper.getInstance().getNDArrayFromTensor("value",inputFromNode,graph) : null;
if(tensor == null) {
tensor = on.getSameDiff().getArrForVarName(getNodeName(node.getInput(position)));
}
if(tensor != null) {
//use adapter instead of direct mapping just like above
if(adapter != null) {
adapter.mapAttributeFor(tensor,currentField,on);
}
else {
if(currentField.getType().equals(int[].class)) {
on.setValueFor(currentField,tensor.data().asInt());
}
else if(currentField.getType().equals(double[].class)) {
on.setValueFor(currentField,tensor.data().asDouble());
}
else if(currentField.getType().equals(float[].class)) {
on.setValueFor(currentField,tensor.data().asFloat());
}
else if(currentField.getType().equals(INDArray.class)) {
on.setValueFor(currentField,tensor);
}
else if(currentField.getType().equals(int.class)) {
on.setValueFor(currentField,tensor.getInt(0));
}
else if(currentField.getType().equals(double.class)) {
on.setValueFor(currentField,tensor.getDouble(0));
}
else if(currentField.getType().equals(float.class)) {
on.setValueFor(currentField,tensor.getFloat(0));
}
}
} else {
on.getSameDiff().addPropertyToResolve(on,entry.getKey());
}
}
}
}
@Override
public org.nd4j.linalg.api.buffer.DataType dataTypeForTensor(NodeDef tensorProto, int outNum) {
//First: work out what attribute we should be looking at to determine output type
String opName = tensorProto.getOp();
OpDef opDef = TensorflowDescriptorParser.opDescs().get(opName);
org.tensorflow.framework.DataType tfType;
/*
Note: "OutputArgCount" doesn't account for repeated or variable-length outputs.
For example, the "ShapeN" op has getOutputArgCount() == 1, but its output_arg has a
"number_attr" attribute that specifies the actual number of output variables.
*/
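//Example: a "ShapeN" node with attribute N=3 has getOutputArgCount() == 1, but the single
//output arg carries number_attr "N", so outVarsPerOutputArg == [3] and actualOutputCount == 3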
int outputArgCount = opDef == null ? 0 : opDef.getOutputArgCount();
int[] outVarsPerOutputArg = outputArgCount == 0 ? null : new int[outputArgCount];
int actualOutputCount = 0;
if(outputArgCount > 0){
for(int i=0; i 0){ //Looks like a few OpDef instances have outputs but don't actually list them... example: NoOp
Preconditions.checkState(outNum < actualOutputCount, "Cannot get output argument %s from op %s with %s output variables - variable %s", outNum, tensorProto.getName(), actualOutputCount, tensorProto.getName());
int argIdx = outNum;
if(outputArgCount != actualOutputCount){
//Map backwards, accounting for the fact that each output arg might correspond to multiple variables: for output variable x, which argument is this?
int idx = 0;
int soFar = 0;
while(soFar + outVarsPerOutputArg[idx] <= outNum){
soFar += outVarsPerOutputArg[idx++];
}
argIdx = idx;
}
OpDef.ArgDef argDef = opDef.getOutputArg(argIdx);
String typeAttr = argDef.getTypeAttr();
if(typeAttr != null && tensorProto.containsAttr(typeAttr)){
tfType = tensorProto.getAttrOrThrow(typeAttr).getType();
} else {
return org.nd4j.linalg.api.buffer.DataType.UNKNOWN;
}
} else {
if(tensorProto.getOp().equals("NoOp")){
return org.nd4j.linalg.api.buffer.DataType.UNKNOWN;
}
log.warn("No TensorFlow descriptor found for tensor \"{}\", op \"{}\"", tensorProto.getName(), tensorProto.getOp());
//No descriptor... try to fall back on common type attribute names
if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))
return org.nd4j.linalg.api.buffer.DataType.UNKNOWN;
tfType = tensorProto.containsAttr("dtype") ? tensorProto.getAttrOrThrow("dtype").getType()
: tensorProto.containsAttr("T") ? tensorProto.getAttrOrThrow("T").getType() : tensorProto
.getAttrOrThrow("Tidx").getType();
}
return convertType(tfType);
}
public static org.nd4j.linalg.api.buffer.DataType convertType(org.tensorflow.framework.DataType tfType){
switch(tfType) {
case DT_DOUBLE: return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
case DT_FLOAT: return org.nd4j.linalg.api.buffer.DataType.FLOAT;
case DT_HALF: return org.nd4j.linalg.api.buffer.DataType.HALF;
case DT_BFLOAT16: return org.nd4j.linalg.api.buffer.DataType.HALF;
case DT_INT8: return org.nd4j.linalg.api.buffer.DataType.BYTE;
case DT_INT16: return org.nd4j.linalg.api.buffer.DataType.SHORT;
case DT_INT32: return org.nd4j.linalg.api.buffer.DataType.INT;
case DT_INT64: return org.nd4j.linalg.api.buffer.DataType.LONG;
case DT_UINT8: return org.nd4j.linalg.api.buffer.DataType.UBYTE;
case DT_STRING: return org.nd4j.linalg.api.buffer.DataType.UTF8;
case DT_BOOL: return org.nd4j.linalg.api.buffer.DataType.BOOL;
default: return org.nd4j.linalg.api.buffer.DataType.UNKNOWN;
}
}
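//Note: DT_HALF and DT_BFLOAT16 both map to HALF, and any type not listed above
//(e.g. DT_UINT16, DT_UINT32) falls through to UNKNOWN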
@Override
public boolean isStringType(NodeDef tensorProto){
DataType dt = null;
if(tensorProto.containsAttr("dtype")){
dt = tensorProto.getAttrOrThrow("dtype").getType();
} else if(tensorProto.containsAttr("T")){
dt = tensorProto.getAttrOrThrow("T").getType();
} else if(tensorProto.containsAttr("Tidx")){
dt = tensorProto.getAttrOrThrow("Tidx").getType();
}
return dt == DataType.DT_STRING || dt == DataType.DT_STRING_REF;
}
@Override
public String getAttrValueFromNode(NodeDef nodeDef, String key) {
return nodeDef.getAttrOrThrow(key).getS().toStringUtf8();
}
@Override
public long[] getShapeFromAttribute(AttrValue attrValue) {
TensorShapeProto shape = attrValue.getShape();
long[] ret = new long[shape.getDimCount()];
for(int i = 0; i < ret.length; i++) {
ret[i] = shape.getDim(i).getSize();
}
return ret;
}
@Override
public boolean isPlaceHolder(NodeDef nodeDef) {
return nodeDef.getOp().startsWith("Placeholder");
}
@Override
public boolean isConstant(NodeDef nodeDef) {
return nodeDef.getOp().startsWith("Const");
}
@Override
public List getControlDependencies(NodeDef node){
int numInputs = node.getInputCount();
if(numInputs == 0)
return null;
List<String> out = null;
for (int i = 0; i < numInputs; i++) {
String in = node.getInput(i);
if (isControlDependency(in)) {
if (out == null)
out = new ArrayList<>();
out.add(getNodeName(in)); //Remove "^" prefix
}
}
return out;
}
@Override
public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) {
//placeholder of some kind
if(!node.getAttrMap().containsKey("value")) {
return null;
}
val tfTensor = node.getAttrOrThrow("value").getTensor();
return mapTensorProto(tfTensor);
}
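/**
 * Convert a TensorFlow TensorProto into an INDArray, reading either the
 * repeated *_val fields or the packed tensor_content bytes.
 *
 * A minimal sketch of building an input by hand (the proto values here are
 * illustrative, not taken from any real graph):
 * <pre>{@code
 * TensorProto proto = TensorProto.newBuilder()
 *         .setDtype(DataType.DT_FLOAT)
 *         .setTensorShape(TensorShapeProto.newBuilder()
 *                 .addDim(TensorShapeProto.Dim.newBuilder().setSize(2))
 *                 .addDim(TensorShapeProto.Dim.newBuilder().setSize(2)))
 *         .addFloatVal(1f).addFloatVal(2f).addFloatVal(3f).addFloatVal(4f)
 *         .build();
 * INDArray arr = TFGraphMapper.getInstance().mapTensorProto(proto); //2x2 FLOAT array
 * }</pre>
 * @param tfTensor the tensor proto to convert
 * @return the mapped INDArray
 */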
public INDArray mapTensorProto(TensorProto tfTensor) {
// building shape first
int dims = tfTensor.getTensorShape().getDimCount();
long[] arrayShape = null;
List<Integer> dimensions = new ArrayList<>();
for (int e = 0; e < dims; e++) {
// TODO: eventually we want long shapes :(
int dim = (int) tfTensor.getTensorShape().getDim(e).getSize();
dimensions.add(dim);
}
arrayShape = ArrayUtil.toLongArray(Ints.toArray(dimensions));
if (tfTensor.getDtype() == DataType.DT_INT32 || tfTensor.getDtype() == DataType.DT_INT16 || tfTensor.getDtype() == DataType.DT_INT8) {
// valueOf
if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getIntValCount() < 1)
return Nd4j.scalar( ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), 0);
//should be scalar otherwise
int val = tfTensor.getIntVal(0);
if (arrayShape == null || arrayShape.length == 0)
return Nd4j.scalar( ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), val);
return Nd4j.valueArrayOf(arrayShape, val, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()));
} else if (tfTensor.getIntValCount() > 0) {
val jArray = new int[tfTensor.getIntValCount()];
for (int e = 0; e < tfTensor.getIntValCount(); e++) {
jArray[e] = tfTensor.getIntVal(e);
}
// TF arrays are always C
return Nd4j.create(Nd4j.createTypedBuffer(jArray, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()));
} else {
// FIXME: INT bytebuffers should be converted to floating point
//throw new UnsupportedOperationException("To be implemented yet");
long length = ArrayUtil.prodLong(arrayShape);
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val fb = bb.order(ByteOrder.nativeOrder()).asIntBuffer();
val fa = new int[fb.capacity()];
for (int e = 0; e < fb.capacity(); e++)
fa[e] = fb.get(e);
if (fa.length == 0)
return Nd4j.empty(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()));
//throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
if (fa.length == 1)
return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), fa[0]);
if (arrayShape.length == 1)
return Nd4j.create(fa, new long[]{fa.length}, new long[]{1}, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()));
val array = Nd4j.create(Nd4j.createTypedBuffer(fa, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()));
//log.debug("SUM1: {}", array.sumNumber());
//log.debug("Data: {}", Arrays.toString(array.data().asFloat()));
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_FLOAT) {
if (tfTensor.getFloatValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getFloatValCount() < 1)
return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, 0.0f);
float val = tfTensor.getFloatVal(0);
if (arrayShape == null || arrayShape.length == 0)
arrayShape = new long[]{};
INDArray array = Nd4j.valueArrayOf(arrayShape, val, org.nd4j.linalg.api.buffer.DataType.FLOAT);
return array;
} else if (tfTensor.getFloatValCount() > 0) {
float[] jArray = new float[tfTensor.getFloatValCount()];
for (int e = 0; e < tfTensor.getFloatValCount(); e++) {
jArray[e] = tfTensor.getFloatVal(e);
}
INDArray array = Nd4j.create(Nd4j.createTypedBuffer(jArray, org.nd4j.linalg.api.buffer.DataType.FLOAT), arrayShape, Nd4j.getStrides(arrayShape), 0, 'c');
return array;
} else if (tfTensor.getTensorContent().size() > 0){
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val fb = bb.order(ByteOrder.nativeOrder()).asFloatBuffer();
val fa = new float[fb.capacity()];
for (int e = 0; e < fb.capacity(); e++)
fa[e] = fb.get(e);
if (fa.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you forgot to freeze the graph before saving?");
if (fa.length == 1)
return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, fa[0]);
if (arrayShape.length == 1)
return Nd4j.create(fa, new long[]{fa.length}, new long[]{1}, 'c', org.nd4j.linalg.api.buffer.DataType.FLOAT);
val array = Nd4j.create(fa, arrayShape, Nd4j.getStrides(arrayShape, 'c'), 'c', org.nd4j.linalg.api.buffer.DataType.FLOAT);
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_DOUBLE) {
if (tfTensor.getDoubleValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getDoubleValCount() < 1)
return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.DOUBLE, 0.0);
double val = tfTensor.getDoubleVal(0);
INDArray array = Nd4j.trueScalar(val);
return array;
} else if (tfTensor.getDoubleValCount() > 0) {
val jArray = new double[tfTensor.getDoubleValCount()];
for (int e = 0; e < tfTensor.getDoubleValCount(); e++) {
jArray[e] = tfTensor.getDoubleVal(e);
}
// TF arrays are always C
val array = Nd4j.create(jArray, arrayShape, Nd4j.getStrides(arrayShape, 'c'), 'c', org.nd4j.linalg.api.buffer.DataType.DOUBLE);
return array;
} else if (tfTensor.getTensorContent().size() > 0) {
// binary representation
//DataBuffer buffer = Nd4j.createBuffer(tfTensor.getTensorContent().asReadOnlyByteBuffer(), DataType.FLOAT, (int) length);
//INDArray array = Nd4j.createArrayFromShapeBuffer(buffer, Nd4j.getShapeInfoProvider().createShapeInformation(arrayShape, 'c'));
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val fb = bb.order(ByteOrder.nativeOrder()).asDoubleBuffer();
val da = new double[fb.capacity()];
for (int e = 0; e < fb.capacity(); e++)
da[e] = fb.get(e);
if (da.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you forgot to freeze the graph before saving?");
if (da.length == 1)
return Nd4j.trueScalar(da[0]);
if (arrayShape.length == 1)
return Nd4j.trueVector(da);
val array = Nd4j.create(da, arrayShape, 0, 'c');
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_INT64) {
if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if (tfTensor.getInt64ValCount() < 1)
return Nd4j.trueScalar(0.0);
double val = (double) tfTensor.getInt64Val(0);
INDArray array = Nd4j.trueScalar(val);
return array;
} else if (tfTensor.getInt64ValCount() > 0) {
val jArray = new long[tfTensor.getInt64ValCount()];
for (int e = 0; e < tfTensor.getInt64ValCount(); e++) {
jArray[e] = tfTensor.getInt64Val(e);
}
// TF arrays are always C
INDArray array = Nd4j.create(Nd4j.createTypedBuffer(jArray, org.nd4j.linalg.api.buffer.DataType.LONG), arrayShape, Nd4j.getStrides(arrayShape, 'c'),0, 'c', org.nd4j.linalg.api.buffer.DataType.LONG);
return array;
} else if (tfTensor.getTensorContent().size() > 0) {
//throw new UnsupportedOperationException("To be implemented yet");
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val lb = bb.order(ByteOrder.nativeOrder()).asLongBuffer();
val fa = new long[lb.capacity()];
for (int e = 0; e < lb.capacity(); e++)
fa[e] = lb.get(e);
if (fa.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you forgot to freeze the graph before saving?");
if (fa.length == 1)
return Nd4j.trueScalar(fa[0]);
if (arrayShape.length == 1)
return Nd4j.trueVector(fa);
val array = Nd4j.create(Nd4j.createTypedBuffer(fa, org.nd4j.linalg.api.buffer.DataType.LONG), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', org.nd4j.linalg.api.buffer.DataType.LONG);
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_BOOL) {
if (tfTensor.getBoolValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if (tfTensor.getBoolValCount() < 1)
return Nd4j.scalar(false);
val val = tfTensor.getBoolVal(0);
val arr = Nd4j.scalar(val);
return arr;
} else if (tfTensor.getBoolValCount() > 0) {
val jArray = new boolean[tfTensor.getBoolValCount()];
for (int e = 0; e < tfTensor.getBoolValCount(); e++) {
jArray[e] = tfTensor.getBoolVal(e);
}
// TF arrays are always C
INDArray array = Nd4j.create(Nd4j.createTypedBuffer(jArray, org.nd4j.linalg.api.buffer.DataType.BOOL), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', org.nd4j.linalg.api.buffer.DataType.BOOL);
return array;
} else if (tfTensor.getTensorContent().size() > 0) {
throw new UnsupportedOperationException("Not yet implemented for DataType.DT_BOOL");
}
} else if(tfTensor.getDtype() == DataType.DT_STRING){
if (tfTensor.getStringValCount() <= 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if (tfTensor.getStringValCount() < 1)
return Nd4j.empty(org.nd4j.linalg.api.buffer.DataType.UTF8);
String val = tfTensor.getStringVal(0).toStringUtf8();
INDArray arr = Nd4j.scalar(val);
return arr;
} else if (tfTensor.getStringValCount() > 0) {
String[] sArr = new String[tfTensor.getStringValCount()];
for (int e = 0; e < sArr.length; e++) {
sArr[e] = tfTensor.getStringVal(e).toStringUtf8();
}
// TF arrays are always C
INDArray array = Nd4j.create(sArr).reshape(arrayShape);
return array;
}
} else {
throw new UnsupportedOperationException("Unknown dataType found: [" + tfTensor.getDtype() + "]");
}
throw new ND4JIllegalStateException("Invalid method state");
}
@Override
public long[] getShapeFromTensor(NodeDef tensorProto) {
if(tensorProto.containsAttr("shape")) {
return shapeFromShapeProto(tensorProto.getAttrOrThrow("shape").getShape());
}
//yet to be determined shape, or tied to an op where output shape is dynamic
else if(!tensorProto.containsAttr("value")) {
return null;
}
else
return shapeFromShapeProto(tensorProto.getAttrOrThrow("value").getTensor().getTensorShape());
}
@Override
public Set<String> opsToIgnore() {
return graphMapper;
}
@Override
public String getInputFromNode(NodeDef node, int index) {
return node.getInput(index);
}
@Override
public int numInputsFor(NodeDef nodeDef) {
return nodeDef.getInputCount();
}
private long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
long[] shape = new long[tensorShapeProto.getDimList().size()];
for(int i = 0; i < shape.length; i++) {
shape[i] = tensorShapeProto.getDim(i).getSize();
}
return shape;
}
/**
* Returns the nodes for an if statement
* @param from the starting node (a merge node that represents a conditional)
* @param graph the graph to search
* @return an import state representing the nodes for each scope
*/
public IfImportState nodesForIf(NodeDef from, GraphDef graph) {
//Assume we start with a switch statement
int currNodeIndex = graph.getNodeList().indexOf(from);
val trueDefName = from.getInput(1);
val falseDefName = from.getInput(0);
val scopeId = UUID.randomUUID().toString();
val scopeName = scopeId + "-" + trueDefName.substring(0,trueDefName.indexOf("/"));
val trueDefScopeName = scopeName + "-true-scope";
val falseDefScopeName = scopeName + "-false-scope";
//going backwards from the merge node, we encounter the false body first
boolean onFalseDefinition = true;
boolean onTrueDefinition = false;
List<NodeDef> falseBodyNodes = new ArrayList<>();
List<NodeDef> trueBodyNodes = new ArrayList<>();
List<NodeDef> conditionNodes = new ArrayList<>();
Set<String> seenNames = new LinkedHashSet<>();
/**
* Accumulate a list backwards to get proper ordering.
*
*/
for(int i = currNodeIndex; i >= 0; i--) {
//switch to false names
if(graph.getNode(i).getName().equals(trueDefName)) {
onFalseDefinition = false;
onTrueDefinition = true;
}
//on predicate now
if(graph.getNode(i).getName().contains("pred_id")) {
onTrueDefinition = false;
}
//don't re-add the same node; this causes a stack overflow
if(onTrueDefinition && !graph.getNode(i).equals(from)) {
trueBodyNodes.add(graph.getNode(i));
}
else if(onFalseDefinition && !graph.getNode(i).equals(from)) {
falseBodyNodes.add(graph.getNode(i));
}
//condition scope now
else {
val currNode = graph.getNode(i);
if(currNode.equals(from))
continue;
//break only after bootstrapping the first node (the predicate id node)
if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) {
break;
}
/**
 * Continuously add the inputs seen for each node in the subgraph.
 * Starting from the predicate id, any node that has inputs in the condition scope
 * is by definition within the scope. Any node not encountered after that is
 * considered out of scope, so we break.
 */
for(int inputIdx = 0; inputIdx < currNode.getInputCount(); inputIdx++) {
seenNames.add(currNode.getInput(inputIdx));
}
//ensure the "current node" is added as well
seenNames.add(graph.getNode(i).getName());
conditionNodes.add(graph.getNode(i));
}
}
/**
* Since we are going over the graph backwards,
* we need to reverse the nodes to ensure proper ordering.
*/
Collections.reverse(falseBodyNodes);
Collections.reverse(trueBodyNodes);
Collections.reverse(conditionNodes);
return IfImportState.builder()
.condNodes(conditionNodes)
.falseNodes(falseBodyNodes)
.trueNodes(trueBodyNodes)
.falseBodyScopeName(falseDefScopeName)
.trueBodyScopeName(trueDefScopeName)
.conditionBodyScopeName(scopeName)
.build();
}
}