All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.imports.graphmapper.GraphMapper Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.imports.graphmapper;

import org.nd4j.shade.protobuf.Message;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Map graph proto types to
 *
 * {@link SameDiff} instances
 * @param  the proto type for the graph
 * @param  the proto type for the node
 * @param  the proto type for the attribute
 * @param  the proto type for the tensor
 *@author Adam Gibson
 */
public interface GraphMapper {

    /**
     * Import a graph as SameDiff from the given file
     * @param graphFile Input stream pointing to graph file to import
     * @return Imported graph
     */
    SameDiff importGraph(InputStream graphFile);

    SameDiff importGraph(InputStream graphFile, Map> opImportOverrides,
                         OpImportFilter opFilter);

    /**
     * Import a graph as SameDiff from the given file
     * @param graphFile Graph file to import
     * @return Imported graph
     * @see #importGraph(File, Map)
     */
    SameDiff importGraph(File graphFile);

    /**
     * Import a graph as SameDiff from the given file, with optional op import overrides.
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions. * * @param graphFile Graph file to import * @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the * operation to import, value is the object used to import it * @return Imported graph */ SameDiff importGraph(File graphFile, Map> opImportOverrides, OpImportFilter opFilter); /** * This method converts given graph type (in its native format) to SameDiff * @param graph Graph to import * @return Imported graph */ SameDiff importGraph(GRAPH_TYPE graph); /** * This method converts given graph type (in its native format) to SameDiff
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops * that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions. * @param graph Graph to import * @return Imported graph */ SameDiff importGraph(GRAPH_TYPE graph, Map> opImportOverrides, OpImportFilter opFilter); /** * Returns true if this node is a special case * (maybe because of name or other scenarios) * that should override {@link #opsToIgnore()} * in certain circumstances * @param node the node to check * @return true if this node is an exception false otherwise */ boolean isOpIgnoreException(NODE_TYPE node); /** * Get the nodes sorted by n ame * from a given graph * @param graph the graph to get the nodes for * @return the map of the nodes by name * for a given graph */ Map nodesByName(GRAPH_TYPE graph); /** * Get the target mapping key (usually based on the node name) * for the given function * @param function the function * @param node the node to derive the target mapping from * @return */ String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node); /** * * @param on * @param node * @param graph * @param sameDiff * @param propertyMappings */ void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappings); /** * * @param name * @param on * @param node * @param graph * @param sameDiff * @param propertyMappingsForFunction */ void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map> propertyMappingsForFunction); /** * Get the node from the graph * @param graph the graph to get the node from * @param name the name of the node to get from the graph * @return */ NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name); /** * Returns true if the given node is a place holder * @param node the node to check * @return true if the node is a place holder or not */ boolean isPlaceHolderNode(TENSOR_TYPE node); /** * Get the list of control dependencies for the current node (or null if none exist) * * @param node Node to get the control dependencies (if any) for * @return */ List getControlDependencies(NODE_TYPE node); /** * Dump a binary proto file representation as a * plain string in to the target text file * @param inputFile * @param outputFile */ void dumpBinaryProtoAsText(File inputFile,File outputFile); /** * Dump a binary proto file representation as a * plain string in to the target text file * @param inputFile * @param outputFile */ void dumpBinaryProtoAsText(InputStream inputFile,File outputFile); /** * Get the mapped op name * for a given op * relative to the type of node being mapped. * The input name should be based on a tensorflow * type or onnx type, not the nd4j name * @param name the tensorflow or onnx name * @return the function based on the values in * {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder} */ DifferentialFunction getMappedOp(String name); /** * Get the variables for the given graph * @param graphType the graph to get the variables for * @return a map of variable name to tensor */ Map variablesForGraph(GRAPH_TYPE graphType); /** * * @param name * @param node * @return */ String translateToSameDiffName(String name, NODE_TYPE node); /** * * @param graph * @return */ Map nameIndexForGraph(GRAPH_TYPE graph); /** * Returns an op type for the given input node * @param nodeType the node to use * @return the optype for the given node */ Op.Type opTypeForNode(NODE_TYPE nodeType); /** * Returns a graph builder for initial definition and parsing. * @return */ Message.Builder getNewGraphBuilder(); /** * Parse a graph from an input stream * @param inputStream the input stream to load from * @return */ GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException; /** * Parse a graph from an input stream * @param inputStream the input stream to load from * @return */ GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException; /** * Map a node in to the import state covering the {@link SameDiff} instance * @param tfNode the node to map * @param importState the current import state * @param opFilter Optional filter for skipping operations */ void mapNodeType(NODE_TYPE tfNode, ImportState importState, OpImportOverride opImportOverride, OpImportFilter opFilter); /** * * @param tensorType * @param outputNum * @return */ DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum); boolean isStringType(TENSOR_TYPE tensor); /** * * @param nodeType * @param key * @return */ String getAttrValueFromNode(NODE_TYPE nodeType,String key); /** * * @param attrType * @return */ long[] getShapeFromAttribute(ATTR_TYPE attrType); /** * Returns true if the given node is a place holder type * (think a yet to be determined shape)_ * @param nodeType * @return */ boolean isPlaceHolder(TENSOR_TYPE nodeType); /** * Returns true if the given node is a constant * @param nodeType * @return */ boolean isConstant(TENSOR_TYPE nodeType); /** * * * @param tensorName * @param tensorType * @param graph * @return */ INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph); /** * Get the shape for the given tensor type * @param tensorType * @return */ long[] getShapeFromTensor(TENSOR_TYPE tensorType); /** * Ops to ignore for mapping * @return */ Set opsToIgnore(); /** * Get the input node for the given node * @param node the node * @param index hte index * @return */ String getInputFromNode(NODE_TYPE node, int index); /** * Get the number of inputs for a node. * @param nodeType the node to get the number of inputs for * @return */ int numInputsFor(NODE_TYPE nodeType); /** * Whether the data type for the tensor is valid * for creating an {@link INDArray} * @param tensorType the tensor proto to test * @return */ boolean validTensorDataType(TENSOR_TYPE tensorType); /** * Get the shape of the attribute value * @param attr the attribute value * @return the shape of the attribute if any or null */ long[] getShapeFromAttr(ATTR_TYPE attr); /** * Get the attribute * map for given node * @param nodeType the node * @return the attribute map for the attribute */ Map getAttrMap(NODE_TYPE nodeType); /** * Get the name of the node * @param nodeType the node * to get the name for * @return */ String getName(NODE_TYPE nodeType); /** * * @param nodeType * @return */ boolean alreadySeen(NODE_TYPE nodeType); /** * * @param nodeType * @return */ boolean isVariableNode(NODE_TYPE nodeType); /** * * * @param opType * @return */ boolean shouldSkip(NODE_TYPE opType); /** * * @param nodeType * @return */ boolean hasShape(NODE_TYPE nodeType); /** * * @param nodeType * @return */ long[] getShape(NODE_TYPE nodeType); /** * * @param nodeType * @param graph * @return */ INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph); String getOpType(NODE_TYPE nodeType); /** * * @param graphType * @return */ List getNodeList(GRAPH_TYPE graphType); }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy