Please wait. This can take a few minutes...
Downloading a project requires significant server resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.nd4j.imports.graphmapper.tf.TFGraphMapper Maven / Gradle / Ivy
package org.nd4j.imports.graphmapper.tf;
import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.*;
import java.io.*;
import java.nio.ByteOrder;
import java.util.*;
/**
* Map tensorflow graph protos
* to the intermediate representation
* for samediff.
*
* @author Adam Gibson
*/
@Slf4j
public class TFGraphMapper extends BaseGraphMapper {
private Set seenNodes = new LinkedHashSet<>();
public final static String VALUE_ATTR_KEY = "value";
public final static String SHAPE_KEY = "shape";
private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper();
private Set graphMapper = new HashSet(){{
//While and If
//While -> Enter
/**
* Need to work on coping with variables
* that are marked as "shouldSkip"
*
* Possibly consider replacing should skip
* with a special handler interface. Something like
*
* public interface ImportOpHandler
*/
add("LoopCond");
/**
* We should skip this for the sake of while..but not if.
* Need to be a bit more flexible here.
*/
add("Merge");
add("Exit");
add("NextIteration");
add("NoOp");
add("Switch");
}};
//singleton
private TFGraphMapper() {}
/**
* Singleton. Get the needed instance.
* @return
*/
public static TFGraphMapper getInstance() {
return MAPPER_INSTANCE;
}
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
try {
GraphDef graphDef = GraphDef.parseFrom(inputFile);
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(NodeDef node : graphDef.getNodeList()) {
bufferedWriter.write(node.toString());
}
bufferedWriter.flush();
bufferedWriter.close();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public boolean isOpIgnoreException(NodeDef node) {
//if statements should not be ignored
/*
if(node.getOp().equals("Merge")) {
boolean ret = false;
for(int i = 0; i < node.getInputCount(); i++) {
//while loop
ret = ret || !node.getInput(i).endsWith("/Enter") || !node.getInput(i).endsWith("/NextIteration");
}
return ret;
}
else if(node.getOp().equals("Switch")) {
boolean ret = false;
for(int i = 0; i < node.getInputCount(); i++) {
//while loop
ret = ret || !node.getInput(i).endsWith("/Merge") || !node.getInput(i).endsWith("/LoopCond");
}
return ret;
}
*/
return true;
}
@Override
public String getTargetMappingForOp(DifferentialFunction function, NodeDef node) {
return function.opName();
}
@Override
public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
for(int i = 0; i < graph.getNodeCount(); i++) {
val node = graph.getNode(i);
if(node.getName().equals(name))
return node;
}
return null;
}
@Override
public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map> propertyMappingsForFunction) {
if(node == null) {
throw new ND4JIllegalStateException("No node found for name " + name);
}
val mapping = propertyMappingsForFunction.get(getOpType(node)).get(name);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
if(mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) {
int tfMappingIdx = mapping.getTfInputPosition();
if(tfMappingIdx < 0)
tfMappingIdx += node.getInputCount();
val input = node.getInput(tfMappingIdx);
val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,input);
INDArray arr = getArrayFrom(inputNode,graph);
if(arr == null) {
arr = sameDiff.getArrForVarName(input);
}
if(arr == null && inputNode != null) {
sameDiff.addPropertyToResolve(on,name);
sameDiff.addVariableMappingForField(on,name,inputNode.getName());
return;
}
else if(inputNode == null) {
sameDiff.addAsPlaceHolder(input);
return;
}
val field = fields.get(name);
val type = field.getType();
if(type.equals(int[].class)) {
on.setValueFor(field,arr.data().asInt());
}
else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
if(mapping.getShapePosition() != null) {
on.setValueFor(field,arr.size(mapping.getShapePosition()));
}
else
on.setValueFor(field,arr.getInt(0));
}
else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
on.setValueFor(field,arr.getDouble(0));
}
}
else {
val tfMappingAttrName = mapping.getTfAttrName();
if(tfMappingAttrName == null) {
return;
}
if(!node.containsAttr(tfMappingAttrName)) {
return;
}
val attr = node.getAttrOrThrow(tfMappingAttrName);
val type = attr.getType();
if(fields == null) {
throw new ND4JIllegalStateException("No fields found for op " + mapping);
}
if(mapping.getPropertyNames() == null) {
throw new ND4JIllegalStateException("no property found for " + name + " and op " + on.opName());
}
val field = fields.get(mapping.getPropertyNames()[0]);
Object valueToSet = null;
switch(type) {
case DT_BOOL:
valueToSet = attr.getB();
break;
case DT_INT8:
valueToSet = attr.getI();
break;
case DT_INT16:
valueToSet = attr.getI();
break;
case DT_INT32:
valueToSet = attr.getI();
break;
case DT_FLOAT:
valueToSet = attr.getF();
break;
case DT_DOUBLE:
valueToSet = attr.getF();
break;
case DT_STRING:
valueToSet = attr.getS();
break;
case DT_INT64:
valueToSet = attr.getI();
break;
}
if(field != null && valueToSet != null)
on.setValueFor(field,valueToSet);
}
}
/**
* {@inheritDoc}
*/
@Override
public boolean isPlaceHolderNode(NodeDef node) {
return node.getOp().startsWith("Placeholder");
}
/**
* {@inheritDoc}
*/
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
try {
GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(NodeDef node : graphDef.getNodeList()) {
bufferedWriter.write(node.toString());
}
bufferedWriter.flush();
bufferedWriter.close();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public int[] getShapeFromAttr(AttrValue attr) {
return shapeFromShapeProto(attr.getShape());
}
@Override
public Map getAttrMap(NodeDef nodeDef) {
return nodeDef.getAttrMap();
}
@Override
public String getName(NodeDef nodeDef) {
return nodeDef.getName();
}
@Override
public boolean alreadySeen(NodeDef nodeDef) {
return seenNodes.contains(nodeDef.getName());
}
@Override
public boolean isVariableNode(NodeDef nodeDef) {
boolean isVar = nodeDef.getOp().startsWith("VariableV") || nodeDef.getOp().equalsIgnoreCase("const");
return isVar;
}
@Override
public boolean shouldSkip(NodeDef opType) {
if(opType == null)
return true;
boolean endsWithRead = opType.getName().endsWith("/read");
boolean isReductionIndices = opType.getOp().endsWith("/reduction_indices");
return endsWithRead || isReductionIndices;
}
@Override
public boolean hasShape(NodeDef nodeDef) {
return nodeDef.containsAttr(SHAPE_KEY);
}
@Override
public int[] getShape(NodeDef nodeDef) {
return getShapeFromAttr(nodeDef.getAttrOrThrow(SHAPE_KEY));
}
@Override
public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) {
if(nodeDef == null) {
return null;
}
return getNDArrayFromTensor(nodeDef.getName(),nodeDef, graph);
}
@Override
public String getOpType(NodeDef nodeDef) {
return nodeDef.getOp();
}
/**
*
* @param graphDef
* @return
*/
@Override
public List getNodeList(GraphDef graphDef) {
return graphDef.getNodeList();
}
/**
*
* @param name the tensorflow or onnx name
* @return
*/
@Override
public DifferentialFunction getMappedOp(String name) {
return DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(name);
}
/**
* Map a tensorflow node name
* to the samediff equivalent
* for import
* @param name the name to change
* @return the input tensorflow name
*/
public String getNodeName(String name) {
//tensorflow adds colons to the end of variables representing input index, this strips those off
String ret = name;
if(ret.startsWith("^"))
ret = ret.substring(1);
if(ret.endsWith("/read")) {
ret = ret.replace("/read","");
}
return ret;
}
@Override
public Map variablesForGraph(GraphDef graphDef) {
Map ret = new LinkedHashMap<>();
for(NodeDef nodeDef : graphDef.getNodeList()) {
if(nodeDef.getName().endsWith("/read")) {
continue;
}
val name = translateToSameDiffName(nodeDef.getName(), nodeDef);
ret.put(name,nodeDef);
}
return ret;
}
@Override
public String translateToSameDiffName(String name, NodeDef node) {
if(isVariableNode(node) || isPlaceHolder(node)) {
return name;
}
StringBuilder stringBuilder = new StringBuilder();
//strip arg number
if(name.contains(":")) {
name = name.substring(0,name.lastIndexOf(':'));
stringBuilder.append(name);
}
else {
stringBuilder.append(name);
}
return stringBuilder.toString();
}
@Override
public Message.Builder getNewGraphBuilder() {
return GraphDef.newBuilder();
}
@Override
public GraphDef parseGraphFrom(byte[] inputStream) throws IOException {
return GraphDef.parseFrom(inputStream);
}
@Override
public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
return GraphDef.parseFrom(inputStream);
}
protected void importCondition(String conditionName, NodeDef tfNode, ImportState importState) {
/**
* Cond structure:
*
*/
}
@Override
public void mapNodeType(NodeDef tfNode, ImportState importState) {
if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) {
return;
}
val diff = importState.getSameDiff();
if (isVariableNode(tfNode)) {
List dimensions = new ArrayList<>();
Map attributes = getAttrMap(tfNode);
if (attributes.containsKey(VALUE_ATTR_KEY)) {
diff.var(getName(tfNode),getArrayFrom(tfNode,importState.getGraph()));
}
else if (attributes.containsKey(SHAPE_KEY)) {
AttrValue shape = attributes.get(SHAPE_KEY);
int[] shapeArr = getShapeFromAttr(shape);
int dims = shapeArr.length;
if (dims > 0) {
// even vector is 2d in nd4j
if (dims == 1)
dimensions.add(1);
for (int e = 0; e < dims; e++) {
// TODO: eventually we want long shapes :(
dimensions.add(getShapeFromAttr(shape)[e]);
}
}
}
}
else if(isPlaceHolder(tfNode)) {
val vertexId = diff.getVariable(getName(tfNode));
diff.addAsPlaceHolder(vertexId.getVarName());
}
else {
val opName = tfNode.getOp();
val nodeName = tfNode.getName();
// FIXME: early draft
// conditional import
/*
if (nodeName.startsWith("cond") && nodeName.contains("/")) {
val str = nodeName.replaceAll("/.*$","");
importCondition(str, tfNode, importState);
seenNodes.add(nodeName);
return;
} else if (nodeName.startsWith("while")) {
// while loop import
return;
}
*/
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
if(differentialFunction == null) {
throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
}
try {
val newInstance = differentialFunction.getClass().newInstance();
val args = new SDVariable[tfNode.getInputCount()];
newInstance.setOwnName(tfNode.getName());
for(int i = 0; i < tfNode.getInputCount(); i++) {
val name = getNodeName(tfNode.getInput(i));
args[i] = diff.getVariable(name);
if(args[i] == null) {
args[i] = diff.var(name,null,new ZeroInitScheme('f'));
diff.addAsPlaceHolder(args[i].getVarName());
}
/**
* Note here that we are associating
* the output/result variable
* with its inputs and notifying
* the variable that it has a place holder argument
* it should resolve before trying to execute
* anything.
*/
if(diff.isPlaceHolder( args[i].getVarName())) {
diff.putPlaceHolderForVariable(args[i].getVarName(), name);
}
}
diff.addArgsFor(args,newInstance);
newInstance.setSameDiff(importState.getSameDiff());
newInstance.initFromTensorFlow(tfNode,diff,getAttrMap(tfNode),importState.getGraph());
mapProperties(newInstance,tfNode,importState.getGraph(),importState.getSameDiff(),newInstance.mappingsForFunction());
importState.getSameDiff().putFunctionForId(newInstance.getOwnName(),newInstance);
//ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance);
diff.addVarNameForImport(tfNode.getName());
} catch (Exception e) {
log.error("Failed with [{}]", opName);
throw new RuntimeException(e);
}
}
}
/**
* Calls {@link #initFunctionFromProperties(DifferentialFunction, Map, NodeDef, GraphDef)}
* using {@link DifferentialFunction#tensorflowName()}
* @param on the function to use init on
* @param attributesForNode the attributes for the node
* @param node
* @param graph
*/
public void initFunctionFromProperties(DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) {
initFunctionFromProperties(on.tensorflowName(),on,attributesForNode,node,graph);
}
/**
* Init a function's attributes
* @param mappedTfName the tensorflow name to pick (sometimes ops have multiple names
* @param on the function to map
* @param attributesForNode the attributes for the node
* @param node
* @param graph
*/
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, NodeDef node, GraphDef graph) {
val properties = on.mappingsForFunction();
val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
val attributeAdapters = on.attributeAdaptersForFunction();
// if there's no properties announced for this function - just return
if (tfProperties == null)
return;
for(val entry : tfProperties.entrySet()) {
val tfAttrName = entry.getValue().getTfAttrName();
val currentField = fields.get(entry.getKey());
AttributeAdapter adapter = null;
if(attributeAdapters != null && !attributeAdapters.isEmpty()) {
val mappers = attributeAdapters.get(mappedTfName);
val adapterFor = mappers.get(entry.getKey());
adapter = adapterFor;
}
if(tfAttrName != null) {
if(currentField == null) {
continue;
}
if(attributesForNode.containsKey(tfAttrName)) {
val attr = attributesForNode.get(tfAttrName);
switch (attr.getValueCase()) {
case B:
if (adapter != null) {
adapter.mapAttributeFor(attr.getB(), currentField, on);
}
break;
case F: break;
case FUNC: break;
case S:
val setString = attr.getS().toStringUtf8();
if(adapter != null) {
adapter.mapAttributeFor(setString,currentField,on);
}
else
on.setValueFor(currentField,setString);
break;
case I:
val setInt = (int) attr.getI();
if(adapter != null) {
adapter.mapAttributeFor(setInt,currentField,on);
}
else
on.setValueFor(currentField,setInt);
break;
case SHAPE:
val shape = attr.getShape().getDimList();
int[] dimsToSet = new int[shape.size()];
for(int i = 0; i < dimsToSet.length; i++) {
dimsToSet[i] = (int) shape.get(i).getSize();
}
if(adapter != null) {
adapter.mapAttributeFor(dimsToSet,currentField,on);
}
else
on.setValueFor(currentField,dimsToSet);
break;
case VALUE_NOT_SET:break;
case PLACEHOLDER: break;
case LIST:
val setList = attr.getList();
if(!setList.getIList().isEmpty()) {
val intList = Ints.toArray(setList.getIList());
if(adapter != null) {
adapter.mapAttributeFor(intList,currentField,on);
}
else
on.setValueFor(currentField,intList);
}
else if(!setList.getBList().isEmpty()) {
break;
}
else if(!setList.getFList().isEmpty()) {
val floats = Floats.toArray(setList.getFList());
if(adapter != null) {
adapter.mapAttributeFor(floats,currentField,on);
}
else
on.setValueFor(currentField,floats);
break;
}
else if(!setList.getFuncList().isEmpty()) {
break;
}
else if(!setList.getTensorList().isEmpty()) {
break;
}
break;
case TENSOR:
val tensorToGet = TFGraphMapper.getInstance().mapTensorProto(attr.getTensor());
if(adapter != null) {
adapter.mapAttributeFor(tensorToGet,currentField,on);
}
else
on.setValueFor(currentField,tensorToGet);
break;
case TYPE:
if (adapter != null) {
adapter.mapAttributeFor(attr.getType(), currentField, on);
}
break;
}
}
}
else if(entry.getValue().getTfInputPosition() != null) {
int position = entry.getValue().getTfInputPosition();
if(position < 0) {
position += node.getInputCount();
}
val inputFromNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,node.getInput(position));
INDArray tensor = inputFromNode != null ? TFGraphMapper.getInstance().getNDArrayFromTensor("value",inputFromNode,graph) : null;
if(tensor == null) {
tensor = on.getSameDiff().getArrForVarName(getNodeName(node.getInput(position)));
}
if(tensor != null) {
//use adapter instead of direct mapping just like above
if(adapter != null) {
adapter.mapAttributeFor(tensor,currentField,on);
}
else {
if(currentField.getType().equals(int[].class)) {
on.setValueFor(currentField,tensor.data().asInt());
}
else if(currentField.getType().equals(double[].class)) {
on.setValueFor(currentField,tensor.data().asDouble());
}
else if(currentField.getType().equals(float[].class)) {
on.setValueFor(currentField,tensor.data().asFloat());
}
else if(currentField.getType().equals(INDArray.class)) {
on.setValueFor(currentField,tensor);
}
else if(currentField.getType().equals(int.class)) {
on.setValueFor(currentField,tensor.getInt(0));
}
else if(currentField.getType().equals(double.class)) {
on.setValueFor(currentField,tensor.getDouble(0));
}
else if(currentField.getType().equals(float.class)) {
on.setValueFor(currentField,tensor.getFloat(0));
}
}
}
else {
on.getSameDiff().addPropertyToResolve(on,entry.getKey());
}
}
}
}
@Override
public DataBuffer.Type dataTypeForTensor(NodeDef tensorProto) {
if(!tensorProto.containsAttr("dtype") && !tensorProto.containsAttr("Tidx") && !tensorProto.containsAttr("T"))
return DataBuffer.Type.UNKNOWN;
val type = tensorProto.containsAttr("dtype") ? tensorProto.getAttrOrThrow("dtype").getType()
: tensorProto.containsAttr("T") ? tensorProto.getAttrOrThrow("T").getType() : tensorProto
.getAttrOrThrow("Tidx").getType();
switch(type) {
case DT_DOUBLE: return DataBuffer.Type.DOUBLE;
case DT_INT32:
case DT_INT64: return DataBuffer.Type.INT;
case DT_FLOAT: return DataBuffer.Type.FLOAT;
case DT_BFLOAT16: return DataBuffer.Type.HALF;
default: return DataBuffer.Type.UNKNOWN;
}
}
@Override
public String getAttrValueFromNode(NodeDef nodeDef, String key) {
return nodeDef.getAttrOrThrow(key).getS().toStringUtf8();
}
@Override
public int[] getShapeFromAttribute(AttrValue attrValue) {
TensorShapeProto shape = attrValue.getShape();
int[] ret = new int[shape.getDimCount()];
for(int i = 0; i < ret.length; i++) {
ret[i] = (int) shape.getDim(i).getSize();
}
return ret;
}
@Override
public boolean isPlaceHolder(NodeDef nodeDef) {
return nodeDef.getOp().startsWith("Placeholder");
}
@Override
public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) {
//placeholder of some kind
if(!node.getAttrMap().containsKey("value")) {
return null;
}
val tfTensor = node.getAttrOrThrow("value").getTensor();
return mapTensorProto(tfTensor);
}
public INDArray mapTensorProto(TensorProto tfTensor) {
// building shape first
int dims = tfTensor.getTensorShape().getDimCount();
int[] arrayShape = null;
List dimensions = new ArrayList<>();
// we allow vectors now
//if(dims == 1) {
// dimensions.add(1);
// dimensions.add( (int) Math.max(1,tfTensor.getTensorShape().getDim(0).getSize()));
// }
for (int e = 0; e < dims; e++) {
// TODO: eventually we want long shapes :(
int dim = (int) tfTensor.getTensorShape().getDim(e).getSize();
dimensions.add(dim);
}
arrayShape = Ints.toArray(dimensions);
if (tfTensor.getDtype() == DataType.DT_INT32 || tfTensor.getDtype() == DataType.DT_INT16 || tfTensor.getDtype() == DataType.DT_INT8) {
// valueOf
if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getIntValCount() < 1)
return Nd4j.trueScalar(0.0);
//should be scalar otherwise
int val = tfTensor.getIntVal(0);
if (arrayShape == null || arrayShape.length == 0)
arrayShape = new int[]{};
INDArray array = Nd4j.valueArrayOf(arrayShape, (double) val);
return array;
} else if (tfTensor.getInt64ValCount() > 0) {
double[] jArray = new double[tfTensor.getIntValCount()];
for (int e = 0; e < tfTensor.getIntValCount(); e++) {
jArray[e] = (double) tfTensor.getIntVal(e);
}
// TF arrays are always C
INDArray array = Nd4j.create(jArray, arrayShape, 0, 'c');
return array;
} else {
// FIXME: INT bytebuffers should be converted to floating point
//throw new UnsupportedOperationException("To be implemented yet");
long length = ArrayUtil.prodLong(arrayShape);
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val fb = bb.order(ByteOrder.nativeOrder()).asIntBuffer();
val fa = new float[fb.capacity()];
for (int e = 0; e < fb.capacity(); e++)
fa[e] = (float) fb.get(e);
if (fa.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
if (fa.length == 1)
return Nd4j.trueScalar(fa[0]);
if (arrayShape.length == 1)
return Nd4j.trueVector(fa);
val array = Nd4j.create(fa, arrayShape, 'c', 0);
//log.debug("SUM1: {}", array.sumNumber());
//log.debug("Data: {}", Arrays.toString(array.data().asFloat()));
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_FLOAT) {
if (tfTensor.getFloatValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getFloatValCount() < 1)
return Nd4j.scalar(0.0);
float val = tfTensor.getFloatVal(0);
if (arrayShape == null || arrayShape.length == 0)
arrayShape = new int[]{};
INDArray array = Nd4j.valueArrayOf(arrayShape, (double) val);
return array;
} else if (tfTensor.getFloatValCount() > 0) {
float[] jArray = new float[tfTensor.getFloatValCount()];
for (int e = 0; e < tfTensor.getFloatValCount(); e++) {
jArray[e] = tfTensor.getFloatVal(e);
}
// FIXME: we're missing float[] signature
INDArray array = Nd4j.create(Nd4j.createBuffer(jArray), arrayShape, 'c');
return array;
} else if (tfTensor.getTensorContent().size() > 0){
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val fb = bb.order(ByteOrder.nativeOrder()).asFloatBuffer();
val fa = new float[fb.capacity()];
for (int e = 0; e < fb.capacity(); e++)
fa[e] = fb.get(e);
if (fa.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
if (fa.length == 1)
return Nd4j.trueScalar(fa[0]);
if (arrayShape.length == 1)
return Nd4j.trueVector(fa);
val array = Nd4j.create(fa, arrayShape, 'c', 0);
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_DOUBLE) {
if (tfTensor.getDoubleValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getDoubleValCount() < 1)
return Nd4j.trueScalar(0.0);
double val = tfTensor.getDoubleVal(0);
INDArray array = Nd4j.trueScalar(val);
return array;
} else if (tfTensor.getDoubleValCount() > 0) {
double[] jArray = new double[tfTensor.getDoubleValCount()];
for (int e = 0; e < tfTensor.getDoubleValCount(); e++) {
jArray[e] = tfTensor.getDoubleVal(e);
}
// TF arrays are always C
INDArray array = Nd4j.create(jArray, arrayShape, 0, 'c');
return array;
} else if (tfTensor.getTensorContent().size() > 0) {
// binary representation
//DataBuffer buffer = Nd4j.createBuffer(tfTensor.getTensorContent().asReadOnlyByteBuffer(), DataBuffer.Type.FLOAT, (int) length);
//INDArray array = Nd4j.createArrayFromShapeBuffer(buffer, Nd4j.getShapeInfoProvider().createShapeInformation(arrayShape, 'c'));
// binary representation
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val fb = bb.order(ByteOrder.nativeOrder()).asDoubleBuffer();
val da = new double[fb.capacity()];
for (int e = 0; e < fb.capacity(); e++)
da[e] = fb.get(e);
if (da.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
if (da.length == 1)
return Nd4j.trueScalar(da[0]);
if (arrayShape.length == 1)
return Nd4j.trueVector(da);
val array = Nd4j.create(da, arrayShape, 0, 'c');
return array;
}
} else if (tfTensor.getDtype() == DataType.DT_INT64) {
if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) {
//straight zero case
if(tfTensor.getDoubleValCount() < 1)
return Nd4j.trueScalar(0.0);
double val = (double) tfTensor.getInt64Val(0);
INDArray array = Nd4j.trueScalar(val);
return array;
} else if (tfTensor.getInt64ValCount() > 0) {
double[] jArray = new double[tfTensor.getInt64ValCount()];
for (int e = 0; e < tfTensor.getInt64ValCount(); e++) {
jArray[e] = (double) tfTensor.getInt64Val(e);
}
// TF arrays are always C
INDArray array = Nd4j.create(jArray, arrayShape, 0, 'c');
return array;
} else if (tfTensor.getTensorContent().size() > 0){
//throw new UnsupportedOperationException("To be implemented yet");
//Mapping INT bytebuffers should be converted to floating point
val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
val lb = bb.order(ByteOrder.nativeOrder()).asLongBuffer();
val fa = new float[lb.capacity()];
for (int e = 0; e < lb.capacity(); e++)
fa[e] = (float) lb.get(e);
if (fa.length == 0)
throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
if (fa.length == 1)
return Nd4j.trueScalar(fa[0]);
if (arrayShape.length == 1)
return Nd4j.trueVector(fa);
val array = Nd4j.create(fa, arrayShape, 'c', 0);
//log.debug("SUM1: {}", array.sumNumber());
//log.debug("Data: {}", Arrays.toString(array.data().asFloat()));
return array;
}
} else {
throw new UnsupportedOperationException("Unknown dataType found: [" + tfTensor.getDtype() + "]");
}
throw new ND4JIllegalStateException("Invalid method state");
}
@Override
public int[] getShapeFromTensor(NodeDef tensorProto) {
if(tensorProto.containsAttr("shape")) {
return shapeFromShapeProto(tensorProto.getAttrOrThrow("shape").getShape());
}
//yet to be determined shape, or tied to an op where output shape is dynamic
else if(!tensorProto.containsAttr("value")) {
return null;
}
else
return shapeFromShapeProto(tensorProto.getAttrOrThrow("value").getTensor().getTensorShape());
}
@Override
public Set opsToIgnore() {
return graphMapper;
}
@Override
public String getInputFromNode(NodeDef node, int index) {
return node.getInput(index);
}
@Override
public int numInputsFor(NodeDef nodeDef) {
return nodeDef.getInputCount();
}
private int[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
int[] shape = new int[tensorShapeProto.getDimList().size()];
for(int i = 0; i < shape.length; i++) {
shape[i] = (int) tensorShapeProto.getDim(i).getSize();
}
//shape should be mapped to a row vector
if(shape.length < 2) {
if(shape.length == 1)
shape = new int[]{1,shape[0]};
else
shape = new int[]{1,1};
}
return shape;
}
/**
* Returns the node for an if statement
* @param from the starting node (a merge node that represents a conditional)
* @param graph the graph to search
* @return an import state representing the nodes for each scope
*/
public IfImportState nodesForIf(NodeDef from, GraphDef graph) {
//Assume we start with a switch statement
int currNodeIndex = graph.getNodeList().indexOf(from);
val trueDefName = from.getInput(1);
val falseDefName = from.getInput(0);
val scopeId = UUID.randomUUID().toString();
val scopeName = scopeId + "-" + trueDefName.substring(0,trueDefName.indexOf("/"));
val trueDefScopeName = scopeName + "-true-scope";
val falseDefScopeName = scopeName + "-false-scope";
boolean onFalseDefinition = true;
//start with the true
boolean onTrueDefinition = false;
List falseBodyNodes = new ArrayList<>();
List trueBodyNodes = new ArrayList<>();
List conditionNodes = new ArrayList<>();
Set seenNames = new LinkedHashSet<>();
/**
* Accumulate a list backwards to get proper ordering.
*
*/
for(int i = currNodeIndex; i >= 0; i--) {
//switch to false names
if(graph.getNode(i).getName().equals(trueDefName)) {
onFalseDefinition = false;
onTrueDefinition = true;
}
//on predicate now
if(graph.getNode(i).getName().contains("pred_id")) {
onTrueDefinition = false;
}
//don't readd the same node, this causes a stackoverflow
if(onTrueDefinition && !graph.getNode(i).equals(from)) {
trueBodyNodes.add(graph.getNode(i));
}
else if(onFalseDefinition && !graph.getNode(i).equals(from)) {
falseBodyNodes.add(graph.getNode(i));
}
//condition scope now
else {
val currNode = graph.getNode(i);
if(currNode.equals(from))
continue;
//break only after bootstrapping the first node (the predicate id node)
if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) {
break;
}
/**
* Continuously add inputs seen for each node in the sub graph that occurs.
* Starting from the predicate id, any node that has inputs in the condition scope
* are by definition within the scope. Any node not encountered after that is considered out of scope.
* This means we break.
*/
for(int inputIdx = 0; inputIdx < currNode.getInputCount(); inputIdx++) {
seenNames.add(currNode.getInput(inputIdx));
}
//ensure the "current node" is added as well
seenNames.add(graph.getNode(i).getName());
conditionNodes.add(graph.getNode(i));
}
}
/**
* Since we are going over the graph backwards,
* we need to reverse the nodes to ensure proper ordering.
*/
Collections.reverse(falseBodyNodes);
Collections.reverse(trueBodyNodes);
Collections.reverse(conditionNodes);
return IfImportState.builder()
.condNodes(conditionNodes)
.falseNodes(falseBodyNodes)
.trueNodes(trueBodyNodes)
.conditionBodyScopeName(falseDefScopeName)
.falseBodyScopeName(falseDefScopeName)
.trueBodyScopeName(trueDefScopeName)
.conditionBodyScopeName(scopeName)
.build();
}
}