All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.tensorflow.conversion.ProtoBufToFlatBufConversion Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.tensorflow.conversion;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.transform.*;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.imports.tensorflow.TFImportOverride;
import org.nd4j.imports.tensorflow.TFOpImportFilter;
import org.nd4j.linalg.api.buffer.DataType;

import java.io.File;
import java.io.IOException;
import java.util.*;

/**
 * Conversion of models saved with Google's Protocol Buffer format
 * (https://github.com/protocolbuffers/protobuf) to the FlatBuffers format
 * (https://github.com/google/flatbuffers).
 *
 * This is especially useful for executing models using only the C++ libnd4j
 * library, as the protobuf loader is only available through the Java API.
 *
 * It simply loads a file as a SameDiff graph and saves it as a flat file.
 *
 * There is a special case for BERT models, for which some pre-processing is necessary:
 * see nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java
 * for details.
 *
 * @author Yves Quemener
 */
public class ProtoBufToFlatBufConversion {

    /** Minibatch size used for the BERT input placeholders by {@link #convertBERT(String, String)}. */
    private static final int DEFAULT_BERT_MINIBATCH_SIZE = 4;

    /** Sequence length used for the BERT input placeholders by {@link #convertBERT(String, String)}. */
    private static final int DEFAULT_BERT_SEQUENCE_LENGTH = 128;

    /**
     * Name suffixes identifying the variable that feeds a dropout block, i.e. the activation
     * that should replace the dropout output once dropout is removed.
     */
    private static final List<String> PRE_DROPOUT_SUFFIXES =
                    Collections.unmodifiableList(Arrays.asList("/BiasAdd", "/Softmax", "/add_1", "/Tanh"));

    /**
     * Converts a file containing a model from the Protocol Buffer format to the Flat
     * Buffer format.
     *
     * @param inFile  input file (.pb format)
     * @param outFile output file (.fb format)
     * @throws IOException propagated from the TensorFlow import or the flat-file export
     * @throws org.nd4j.linalg.exception.ND4JIllegalStateException propagated from the ND4J import/export calls
     */
    public static void convert(String inFile, String outFile)
                    throws IOException, org.nd4j.linalg.exception.ND4JIllegalStateException {
        SameDiff graph = TFGraphMapper.importGraph(new File(inFile));
        graph.asFlatFile(new File(outFile));
    }

    /**
     * Converts a BERT model from the Protocol Buffer format to the Flat Buffer format, using
     * input placeholders of shape [4, 128].
     *
     * @param inFile  input file (.pb format)
     * @param outFile output file (.fb format)
     * @throws IOException propagated from the TensorFlow import or the flat-file export
     * @throws org.nd4j.linalg.exception.ND4JIllegalStateException propagated from the ND4J import/export calls
     * @see #convertBERT(String, String, int, int)
     */
    public static void convertBERT(String inFile, String outFile)
                    throws IOException, org.nd4j.linalg.exception.ND4JIllegalStateException {
        convertBERT(inFile, outFile, DEFAULT_BERT_MINIBATCH_SIZE, DEFAULT_BERT_SEQUENCE_LENGTH);
    }

    /**
     * Converts a BERT model from the Protocol Buffer format to the Flat Buffer format.
     *
     * This works around some issues in the BERT model's execution (see
     * nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java):
     * <ul>
     *   <li>the "IteratorGetNext" node is replaced by three INT placeholders of shape
     *       [minibatchSize, sequenceLength];</li>
     *   <li>the "IteratorV2" op is skipped during import;</li>
     *   <li>every matched dropout subgraph is replaced by its pre-dropout input.</li>
     * </ul>
     *
     * @param inFile         input file (.pb format)
     * @param outFile        output file (.fb format)
     * @param minibatchSize  first dimension of the input placeholders (passed as-is to the placeholder shape)
     * @param sequenceLength second dimension of the input placeholders (passed as-is to the placeholder shape)
     * @throws IOException propagated from the TensorFlow import or the flat-file export
     * @throws org.nd4j.linalg.exception.ND4JIllegalStateException propagated from the ND4J import/export calls
     */
    public static void convertBERT(String inFile, String outFile, int minibatchSize, int sequenceLength)
                    throws IOException, org.nd4j.linalg.exception.ND4JIllegalStateException {
        Map<String, TFImportOverride> overrides = new HashMap<>();
        overrides.put("IteratorGetNext", iteratorPlaceholderOverride(minibatchSize, sequenceLength));

        // Skip the "IteratorV2" op - we don't want or need this
        TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) ->
                        "IteratorV2".equals(nodeDef.getName());

        SameDiff sd = TFGraphMapper.importGraph(new File(inFile), overrides, filter);
        sd = removeDropout(sd);

        System.out.println("Exporting file " + outFile);
        sd.asFlatFile(new File(outFile));
    }

    /**
     * Builds the import override that replaces the training iterator with three INT placeholders
     * named "IteratorGetNext" (TensorFlow's implicit output 0), "IteratorGetNext:1" and
     * "IteratorGetNext:4".
     * NOTE(review): these output indices are assumed to be the ones the BERT graph consumes —
     * confirm against the model if it changes.
     */
    private static TFImportOverride iteratorPlaceholderOverride(final int minibatchSize, final int sequenceLength) {
        return (inputs, controlDepInputs, nodeDef, initWith, attributesForNode, graph) -> Arrays.asList(
                        initWith.placeHolder("IteratorGetNext", DataType.INT, minibatchSize, sequenceLength),
                        initWith.placeHolder("IteratorGetNext:1", DataType.INT, minibatchSize, sequenceLength),
                        initWith.placeHolder("IteratorGetNext:4", DataType.INT, minibatchSize, sequenceLength));
    }

    /**
     * Builds the predicate matching a dropout block. The block is rooted at ".../dropout/mul" (the
     * post-dropout output) and has exactly two inputs:
     * <ol>
     *   <li>".../dropout/div"</li>
     *   <li>".../dropout/Floor", fed by ".../dropout/add", whose input 1 is ".../dropout/random_uniform".
     *       That node is fed by ".../dropout/random_uniform/mul", which in turn takes
     *       ".../random_uniform/RandomUniform" and ".../random_uniform/sub".</li>
     * </ol>
     */
    private static SubGraphPredicate dropoutSubgraphPredicate() {
        // Built bottom-up: each named predicate below is the input subgraph of the next one.
        SubGraphPredicate randomUniformMul = SubGraphPredicate
                        .withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/mul"))
                        .withInputSubgraph(0, SubGraphPredicate.withRoot(
                                        OpPredicate.nameMatches(".*/dropout/random_uniform/RandomUniform")))
                        .withInputSubgraph(1, SubGraphPredicate.withRoot(
                                        OpPredicate.nameMatches(".*/dropout/random_uniform/sub")));

        SubGraphPredicate randomUniform = SubGraphPredicate
                        .withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform"))
                        .withInputSubgraph(0, randomUniformMul);

        SubGraphPredicate add = SubGraphPredicate
                        .withRoot(OpPredicate.nameMatches(".*/dropout/add"))
                        .withInputSubgraph(1, randomUniform);

        SubGraphPredicate floor = SubGraphPredicate
                        .withRoot(OpPredicate.nameMatches(".*/dropout/Floor"))
                        .withInputSubgraph(0, add);

        return SubGraphPredicate
                        .withRoot(OpPredicate.nameMatches(".*/dropout/mul"))
                        .withInputCount(2)
                        .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/div")))
                        .withInputSubgraph(1, floor);
    }

    /**
     * Replaces every subgraph matching {@link #dropoutSubgraphPredicate()} with its pre-dropout
     * input variable: the first subgraph input whose name ends with one of
     * {@link #PRE_DROPOUT_SUFFIXES}.
     *
     * @param graph graph to transform
     * @return the graph returned by {@link GraphTransformUtil#replaceSubgraphsMatching}
     * @throws IllegalStateException if a matched subgraph has no recognizable pre-dropout input
     */
    private static SameDiff removeDropout(SameDiff graph) {
        return GraphTransformUtil.replaceSubgraphsMatching(graph, dropoutSubgraphPredicate(), new SubGraphProcessor() {
            @Override
            public List<SDVariable> processSubgraph(SameDiff sd, SubGraph subGraph) {
                List<SDVariable> inputs = subGraph.inputs();
                for (SDVariable v : inputs) {
                    if (isPreDropoutActivation(v.getVarName())) {
                        // Pass this input variable as the new output
                        return Collections.singletonList(v);
                    }
                }

                List<String> inputNames = new ArrayList<>(inputs.size());
                for (SDVariable v : inputs) {
                    inputNames.add(v.getVarName());
                }
                throw new IllegalStateException(
                                "No pre-dropout input variable found among subgraph inputs " + inputNames);
            }
        });
    }

    /** Returns true if {@code varName} ends with one of {@link #PRE_DROPOUT_SUFFIXES}. */
    private static boolean isPreDropoutActivation(String varName) {
        for (String suffix : PRE_DROPOUT_SUFFIXES) {
            if (varName.endsWith(suffix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Main function.
     * The conversion tool can be called from the command line with the following syntax:
     * {@code mvn exec:java -Dexec.mainClass="org.nd4j.tensorflow.conversion.ProtoBufToFlatBufConversion" -Dexec.args="<input_file.pb> <output_file.fb>"}
     *
     * If fewer than two arguments are given, a usage message is printed to stderr and nothing
     * is converted.
     *
     * @param args the first argument is the input filename (protocol buffer format),
     *             the second one is the output filename (flat buffer format)
     * @throws IOException propagated from {@link #convert(String, String)}
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 2) {
            System.err.println("Usage:\n"
                            + "mvn exec:java -Dexec.mainClass=\"org.nd4j.tensorflow.conversion.ProtoBufToFlatBufConversion\""
                            + " -Dexec.args=\"<input_file.pb> <output_file.fb>\"\n");
            return;
        }
        convert(args[0], args[1]);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy