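/*
 * NOTE(review): The notice below appears to have been prepended by a download site and is not
 * part of the original source; as bare text it prevented this file from compiling.
 * "Many resources are needed to download a project. Please understand that we have to cover our
 * server costs. Thank you in advance. Project price: only $1. You can buy this project and
 * download/modify it as often as you want."
 */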
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Map graph proto types to
*
* {@link SameDiff} instances
* @param the proto type for the graph
* @param the proto type for the node
* @param the proto type for the attribute
* @param the proto type for the tensor
*@author Adam Gibson
*/
public interface GraphMapper {
/**
* Import a graph as SameDiff from the given file
* @param graphFile Input stream pointing to graph file to import
* @return Imported graph
*/
SameDiff importGraph(InputStream graphFile);
SameDiff importGraph(InputStream graphFile, Map> opImportOverrides,
OpImportFilter opFilter);
/**
* Import a graph as SameDiff from the given file
* @param graphFile Graph file to import
* @return Imported graph
* @see #importGraph(File, Map)
*/
SameDiff importGraph(File graphFile);
/**
* Import a graph as SameDiff from the given file, with optional op import overrides.
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
* that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
*
* @param graphFile Graph file to import
* @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the
* operation to import, value is the object used to import it
* @return Imported graph
*/
SameDiff importGraph(File graphFile, Map> opImportOverrides,
OpImportFilter opFilter);
/**
* This method converts given graph type (in its native format) to SameDiff
* @param graph Graph to import
* @return Imported graph
*/
SameDiff importGraph(GRAPH_TYPE graph);
/**
* This method converts given graph type (in its native format) to SameDiff
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
* that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
* @param graph Graph to import
* @return Imported graph
*/
SameDiff importGraph(GRAPH_TYPE graph, Map> opImportOverrides,
OpImportFilter opFilter);
/**
* Returns true if this node is a special case
* (maybe because of name or other scenarios)
* that should override {@link #opsToIgnore()}
* in certain circumstances
* @param node the node to check
* @return true if this node is an exception false otherwise
*/
boolean isOpIgnoreException(NODE_TYPE node);
/**
* Get the nodes sorted by n ame
* from a given graph
* @param graph the graph to get the nodes for
* @return the map of the nodes by name
* for a given graph
*/
Map nodesByName(GRAPH_TYPE graph);
/**
* Get the target mapping key (usually based on the node name)
* for the given function
* @param function the function
* @param node the node to derive the target mapping from
* @return
*/
String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node);
/**
*
* @param on
* @param node
* @param graph
* @param sameDiff
* @param propertyMappings
*/
void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappings);
/**
*
* @param name
* @param on
* @param node
* @param graph
* @param sameDiff
* @param propertyMappingsForFunction
*/
void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappingsForFunction);
/**
* Get the node from the graph
* @param graph the graph to get the node from
* @param name the name of the node to get from the graph
* @return
*/
NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name);
/**
* Returns true if the given node is a place holder
* @param node the node to check
* @return true if the node is a place holder or not
*/
boolean isPlaceHolderNode(TENSOR_TYPE node);
/**
* Get the list of control dependencies for the current node (or null if none exist)
*
* @param node Node to get the control dependencies (if any) for
* @return
*/
List getControlDependencies(NODE_TYPE node);
/**
* Dump a binary proto file representation as a
* plain string in to the target text file
* @param inputFile
* @param outputFile
*/
void dumpBinaryProtoAsText(File inputFile,File outputFile);
/**
* Dump a binary proto file representation as a
* plain string in to the target text file
* @param inputFile
* @param outputFile
*/
void dumpBinaryProtoAsText(InputStream inputFile,File outputFile);
/**
* Get the mapped op name
* for a given op
* relative to the type of node being mapped.
* The input name should be based on a tensorflow
* type or onnx type, not the nd4j name
* @param name the tensorflow or onnx name
* @return the function based on the values in
* {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder}
*/
DifferentialFunction getMappedOp(String name);
/**
* Get the variables for the given graph
* @param graphType the graph to get the variables for
* @return a map of variable name to tensor
*/
Map variablesForGraph(GRAPH_TYPE graphType);
/**
*
* @param name
* @param node
* @return
*/
String translateToSameDiffName(String name, NODE_TYPE node);
/**
*
* @param graph
* @return
*/
Map nameIndexForGraph(GRAPH_TYPE graph);
/**
* Returns an op type for the given input node
* @param nodeType the node to use
* @return the optype for the given node
*/
Op.Type opTypeForNode(NODE_TYPE nodeType);
/**
* Returns a graph builder for initial definition and parsing.
* @return
*/
Message.Builder getNewGraphBuilder();
/**
* Parse a graph from an input stream
* @param inputStream the input stream to load from
* @return
*/
GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException;
/**
* Parse a graph from an input stream
* @param inputStream the input stream to load from
* @return
*/
GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException;
/**
* Map a node in to the import state covering the {@link SameDiff} instance
* @param tfNode the node to map
* @param importState the current import state
* @param opFilter Optional filter for skipping operations
*/
void mapNodeType(NODE_TYPE tfNode, ImportState importState,
OpImportOverride opImportOverride,
OpImportFilter opFilter);
/**
*
* @param tensorType
* @param outputNum
* @return
*/
DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum);
boolean isStringType(TENSOR_TYPE tensor);
/**
*
* @param nodeType
* @param key
* @return
*/
String getAttrValueFromNode(NODE_TYPE nodeType,String key);
/**
*
* @param attrType
* @return
*/
long[] getShapeFromAttribute(ATTR_TYPE attrType);
/**
* Returns true if the given node is a place holder type
* (think a yet to be determined shape)_
* @param nodeType
* @return
*/
boolean isPlaceHolder(TENSOR_TYPE nodeType);
/**
* Returns true if the given node is a constant
* @param nodeType
* @return
*/
boolean isConstant(TENSOR_TYPE nodeType);
/**
*
*
* @param tensorName
* @param tensorType
* @param graph
* @return
*/
INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph);
/**
* Get the shape for the given tensor type
* @param tensorType
* @return
*/
long[] getShapeFromTensor(TENSOR_TYPE tensorType);
/**
* Ops to ignore for mapping
* @return
*/
Set opsToIgnore();
/**
* Get the input node for the given node
* @param node the node
* @param index hte index
* @return
*/
String getInputFromNode(NODE_TYPE node, int index);
/**
* Get the number of inputs for a node.
* @param nodeType the node to get the number of inputs for
* @return
*/
int numInputsFor(NODE_TYPE nodeType);
/**
* Whether the data type for the tensor is valid
* for creating an {@link INDArray}
* @param tensorType the tensor proto to test
* @return
*/
boolean validTensorDataType(TENSOR_TYPE tensorType);
/**
* Get the shape of the attribute value
* @param attr the attribute value
* @return the shape of the attribute if any or null
*/
long[] getShapeFromAttr(ATTR_TYPE attr);
/**
* Get the attribute
* map for given node
* @param nodeType the node
* @return the attribute map for the attribute
*/
Map getAttrMap(NODE_TYPE nodeType);
/**
* Get the name of the node
* @param nodeType the node
* to get the name for
* @return
*/
String getName(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @return
*/
boolean alreadySeen(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @return
*/
boolean isVariableNode(NODE_TYPE nodeType);
/**
*
*
* @param opType
* @return
*/
boolean shouldSkip(NODE_TYPE opType);
/**
*
* @param nodeType
* @return
*/
boolean hasShape(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @return
*/
long[] getShape(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @param graph
* @return
*/
INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph);
String getOpType(NODE_TYPE nodeType);
/**
*
* @param graphType
* @return
*/
List getNodeList(GRAPH_TYPE graphType);
}