/*
 * All downloads are free. Search and download functionality uses the official Maven repository.
 *
 * org.nd4j.imports.converters.DifferentialFunctionClassHolder Maven / Gradle / Ivy
 *
 * There is a newer version: 1.0.0-M2.1
 * Show newest version
 */
package org.nd4j.imports.converters;

import com.google.common.collect.ImmutableSet;
import com.google.common.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.OpDef;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;

@Slf4j
public class DifferentialFunctionClassHolder {
    private Map nodeConverters = new HashMap<>();
    private static DifferentialFunctionClassHolder INSTANCE = new DifferentialFunctionClassHolder();
    private Map tensorFlowNames = new HashMap<>();
    private Map onnxNames = new HashMap<>();
    private List missingOps = new ArrayList<>();

    private Map onnxOpDescriptors;
    private Map tensorflowOpDescriptors;
    private Map> fieldsForFunction;

    private  Set  fieldNamesOpsIgnore;
    private Set classesWithConfig = new LinkedHashSet(){{
        add(AvgPooling2D.class.getName());
        add(Conv2D.class.getName());
        add(Conv3D.class.getName());
        add(FullConv3D.class.getName());
        add(LocalResponseNormalization.class.getName());
        add(MaxPooling2D.class.getName());
        add(Pooling2D.class.getName());
        add(Pooling3D.class.getName());
        add(DepthwiseConv2D.class.getName());
    }};
    /**
     * Get the fields for a given {@link DifferentialFunction}
     * @param function the function to get the fields for
     * @return the fields for a given function
     */
    public Map getFieldsForFunction(DifferentialFunction function) {
        return fieldsForFunction.get(function.opName());
    }

    /**
     * Get the op definition of a given
     * tensorflow op.
     *
     * Note that if the name does not exist,
     * an {@link ND4JIllegalStateException} will be thrown
     * @param name the name of the op
     * @return the op definition for a given op
     */
    public OpDef getOpDefByTensorflowName(String name) {
        if(!tensorflowOpDescriptors.containsKey(name)) {
            throw new ND4JIllegalStateException("No op found with name " + name);
        }

        return tensorflowOpDescriptors.get(name);
    }

    /**
     * Get the op definition of a given
     * onnx op
     * Note that if the name does not exist,
     * an {@link ND4JIllegalStateException}
     * will be thrown.
     * @param name the name of the op
     * @return the op definition for a given op
     */
    public OpDescriptor getOpDescriptorForOnnx(String name) {
        if(!onnxOpDescriptors.containsKey(name)) {
            throw new ND4JIllegalStateException("No op found with name " + name);
        }

        return onnxOpDescriptors.get(name);
    }

    /**
     * Get the
     * @param tensorflowName
     * @return
     */
    public DifferentialFunction getOpWithTensorflowName(String tensorflowName) {
        return tensorFlowNames.get(tensorflowName);
    }

    public DifferentialFunction getOpWithOnnxName(String onnxName) {
        return onnxNames.get(onnxName);
    }

    private DifferentialFunctionClassHolder() {
        fieldNamesOpsIgnore = new LinkedHashSet(){{
            add("extraArgs");
            add("arrayInitialized");
            add("log");
            add("inputArguments");
            add("outputArguments");
            add("outputShapes");
            add("outputVariables");
            add("tArguments");
            add("iArguments");
            add("hash");
            add("opName");
            add("sameDiff");
            add("ownName");
        }};


        //Scan classpath to find all DifferentialFunction instances, so tensorflow/onnx mappings can be made
        //We're assuming here that all instances with such mappings are defined in ND4J
        //As of 04/2018 all DifferentialFunction classes are defined in org.nd4j.linalg.api.ops - with the exception
        // of ILossFunction instances, which don't have TF/Onnx import working anyway
        ImmutableSet info;
        try {
            //Dependency note: this ClassPath class was added in Guava 14
            info = com.google.common.reflect.ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader())
                    .getTopLevelClassesRecursive("org.nd4j.linalg.api.ops");
        } catch (IOException e){
            //Should never happen
            throw new RuntimeException(e);
        }

        fieldsForFunction = new LinkedHashMap<>();

        int count = 0;
        for(ClassPath.ClassInfo c : info){
            //Load method: Loads (but doesn't link or initialize) the class.
            Class clazz;
            try{
                clazz = Class.forName(c.getName());
            } catch (ClassNotFoundException e){
                //Should never happen as  this was found on the classpath
                throw new RuntimeException(e);
            }


            if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || !DifferentialFunction.class.isAssignableFrom(clazz))
                continue;

            try {
                DifferentialFunction node = (DifferentialFunction)clazz.newInstance();
                val name = node.opName();
                if(name == null)
                    continue;

                if(name.endsWith("_bp")) {
                    //log.warn("Skipping derivative " + name);
                }
                if (nodeConverters.containsKey(name)) {
                    throw new ND4JIllegalStateException("OpName duplicate found: " + name);
                } else {
                    //log.info("Adding converter for [" + name + "]");
                    nodeConverters.put(name, node);
                    try {
                        for(String s : node.tensorflowNames())
                            tensorFlowNames.put(s,node);
                    }catch (NoOpNameFoundException e) {
                        log.trace("Skipping op " + name + " for tensorflow.");
                    }

                    try {
                        onnxNames.put(node.onnxName(),node);
                    }catch (NoOpNameFoundException e) {
                        log.trace("Skipping op " + name + " for onnx.");
                    }

                    //accumulate the field names for a given function
                    //this is mainly used in import
                    Map fieldNames = new LinkedHashMap<>();
                    Class current = node.getClass();
                    val fields = new ArrayList();
                    while(current.getSuperclass() != null) {
                        if(classesWithConfig.contains(current.getName())) {

                            val fieldName = "config";

                            val configField = current.getDeclaredField(fieldName);
                            if(configField ==  null) {
                                continue;
                            }

                            val configFieldClass = configField.getType();

                            for(val field : configFieldClass.getDeclaredFields()) {
                                if(!fieldNamesOpsIgnore.contains(field.getName())) {
                                    fields.add(field);
                                    field.setAccessible(true);
                                    fieldNames.put(field.getName(),field);
                                }
                            }
                        }
                        else {
                            for(val field : current.getDeclaredFields()) {
                                if(!fieldNamesOpsIgnore.contains(field.getName())) {
                                    fields.add(field);
                                    field.setAccessible(true);
                                    fieldNames.put(field.getName(),field);
                                }
                            }
                        }

                        // do something with current's fields
                        current = (Class) current.getSuperclass();

                    }

                    fieldsForFunction.put(node.opName(),fieldNames);

                }
            } catch (NoOpNameFoundException e) {
                log.trace("Skipping function  " + clazz);
            } catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }



        //get the op descriptors for onnx and tensorflow
        //this is used when validating operations
        try {
            tensorflowOpDescriptors = TensorflowDescriptorParser.opDescs();
            onnxOpDescriptors = OnnxDescriptorParser.onnxOpDescriptors();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }


        val map = new HashMap<>(Nd4j.getExecutioner().getCustomOperations());
        val set = map.keySet();
        set.removeAll(nodeConverters.keySet());
        missingOps.addAll(set);
        Collections.sort(missingOps);
        log.warn("Missing " + set.size() + " ops!");

    }


    /***
     * Returns the missing onnx ops
     * @return
     */
    public Set missingOnnxOps() {
        Set copy = new HashSet<>(onnxOpDescriptors.keySet());
        copy.removeAll(onnxNames.keySet());
        return copy;
    }


    /***
     * Returns the missing tensorflow ops
     * @return
     */
    public Set missingTensorflowOps() {
        Set copy = new HashSet<>(tensorflowOpDescriptors.keySet());
        copy.removeAll(tensorFlowNames.keySet());
        return copy;
    }

    /**
     * Returns the missing ops
     * for c++ vs java.
     * @return
     */
    public List missingOps() {
        return missingOps;
    }

    /**
     *
     * @param name
     * @return
     */
    public boolean hasName(String name) {
        return nodeConverters.containsKey(name);
    }


    public Set opNames() {
        return nodeConverters.keySet();
    }

    /**
     *
     * @param name
     * @return
     */
    public DifferentialFunction getInstance(String name) {
        return nodeConverters.get(name);
    }

    public static DifferentialFunctionClassHolder getInstance() {
        return INSTANCE;
    }

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy