All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.functions.DifferentialFunction Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.autodiff.functions;

import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.lang.reflect.Field;
import java.util.*;


@Data
@Slf4j
public abstract class DifferentialFunction {

    /**
     * SameDiff instance this function is attached to. May be null, in which case
     * {@link #setInstanceId()} falls back to a random UUID and {@link #args()} returns null.
     */
    @Getter
    @Setter
    @JsonIgnore
    protected SameDiff sameDiff;

    /**
     * Flag supplied through the constructors; participates in equals/hashCode.
     * NOTE(review): presumably marks in-place execution of the op - confirm.
     */
    @Getter
    @Setter
    @JsonIgnore
    protected boolean inPlace;



    /**
     * Participates in equals/hashCode.
     * NOTE(review): presumably the scalar operand for scalar ops - confirm.
     */
    @Getter
    @Setter
    @JsonIgnore
    protected INDArray scalarValue;


    /**
     * Participates in equals/hashCode (compared element-wise).
     * NOTE(review): presumably the dimensions the op operates along - confirm.
     */
    @Getter
    @Setter
    @JsonIgnore
    protected int[] dimensions;

    /** Extra arguments supplied through the constructors; not part of equals/hashCode. */
    @JsonIgnore
    protected Object[] extraArgs;


    /**
     * Name identifying this function instance. Assigned by {@link #setInstanceId()}
     * when still null: a random UUID without a SameDiff instance, otherwise the name
     * returned by {@code sameDiff.getOpName(opName())}.
     */
    @Getter
    @Setter
    @JsonIgnore
    private String ownName;

    /**
     * Creates a function and assigns it an instance id: delegates to
     * {@link #DifferentialFunction(boolean)} with {@code true}.
     */
    public DifferentialFunction() {
        this(true);
    }

    /**
     * Creates a function, optionally assigning it an instance id.
     *
     * @param sameDiff if true, {@link #setInstanceId()} is invoked; if false the
     *                 function is left without an own name
     */
    public DifferentialFunction(boolean sameDiff) {
        // Standard ND4J usage with INDArray args does not need an instance id;
        // only functions used in the context of SameDiff do.
        if (!sameDiff) {
            return;
        }
        setInstanceId();
    }

    /**
     * Initialize the function from the given TensorFlow {@link NodeDef}.
     * Stores the SameDiff instance, assigns the instance id via
     * {@link #setInstanceId()} and then delegates to
     * {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)}.
     *
     * @param sameDiff          the SameDiff instance this function is attached to
     * @param nodeDef           the node definition to initialize from
     * @param attributesForNode attributes passed through to initFromTensorFlow
     * @param graph             graph definition passed through to initFromTensorFlow
     */
    public DifferentialFunction(SameDiff sameDiff,NodeDef nodeDef, Map attributesForNode, GraphDef graph) {
        this.sameDiff = sameDiff;
        setInstanceId();
        // NOTE(review): initFromTensorFlow is abstract, so subclass code runs here
        // before the subclass constructor body has executed.
        initFromTensorFlow(nodeDef, sameDiff,attributesForNode ,graph);
    }

    /**
     * Initialize the function from the given ONNX
     * {@link onnx.Onnx.NodeProto}.
     * Stores the SameDiff instance, assigns the instance id via
     * {@link #setInstanceId()} and then delegates to
     * {@link #initFromOnnx(Onnx.NodeProto, SameDiff, Map, Onnx.GraphProto)}.
     *
     * @param sameDiff          the SameDiff instance this function is attached to
     * @param node              the node to initialize from
     * @param attributesForNode attributes passed through to initFromOnnx
     * @param graph             graph passed through to initFromOnnx
     */
    public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map attributesForNode, Onnx.GraphProto graph) {
        this.sameDiff = sameDiff;
        setInstanceId();
        // NOTE(review): initFromOnnx is abstract, so subclass code runs here
        // before the subclass constructor body has executed.
        initFromOnnx(node, sameDiff, attributesForNode, graph);
    }


    /**
     * Returns the {@link AttributeAdapter} s for each of the
     * possible ops for import (typically tensorflow and onnx)
     *
     * See {@link AttributeAdapter} for more information on what the
     * adapter does.
     *
     * Similar to {@link #mappingsForFunction()}, the returned map
     * contains a {@link AttributeAdapter} for each field name
     * when one is present. (It is optional for one to exist)_
     * @return
     */
    public Map> attributeAdaptersForFunction() {
        return Collections.emptyMap();
    }

    /**
     * Returns the mappings for a given function (
     * for tensorflow and onnx import mapping properties
     * of this function). The mapping is indexed by field name.
     * If the function has no properties, this returned map
     * will be empty.
     *
     * Note that some functions have multiple names.
     * This function returns a map indexed by each
     * alias it has for a given name.
     * These names include both onnx and tensorflow names (which might be 1 or more)
     *
     * @return
     */
    public Map> mappingsForFunction() {
        return Collections.emptyMap();
    }

    /**
     * Returns the properties for a given function: the current value of every field
     * that {@link DifferentialFunctionClassHolder} reports for this function, read
     * reflectively from this instance.
     *
     * @return map of field name to current field value, in the iteration order of the
     *         field map returned by the class holder
     * @throws RuntimeException if a field cannot be read reflectively
     */
    public Map<String, Object> propertiesForFunction() {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        // Fail early with a descriptive message before doing any work
        Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass());

        Map<String, Object> ret = new LinkedHashMap<>();
        for (Map.Entry<String, Field> entry : fields.entrySet()) {
            try {
                ret.put(entry.getKey(), entry.getValue().get(this));
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Unable to get property for field: " + entry.getKey(), e);
            }
        }

        return ret;
    }

    /**
     * Sets the given properties on this function. Each key is looked up in the field
     * map reported by {@link DifferentialFunctionClassHolder}; known fields are set
     * via {@link #setValueFor(Field, Object)}, unknown property names are logged at
     * warn level and skipped.
     *
     * @param properties map of field name to the value to set
     */
    public void setPropertiesForFunction(Map<String, Object> properties) {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        // Same guard as propertiesForFunction(): a null field map means the op is not registered
        Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass());

        for (Map.Entry<String, Object> property : properties.entrySet()) {
            String name = property.getKey();
            Field f = fields.get(name);
            if (f == null) {
                log.warn("No fields found for property name {} for class {}", name, this.getClass().getName());
                continue;
            }
            setValueFor(f, property.getValue());
        }
    }


    /**
     * Get the value for a given property
     * for this function.
     *
     * @param property the property (field) to read from this instance
     * @return the value of the field, or null if the field could not be accessed
     *         (the failure is logged, not rethrown)
     */
    public Object getValue(Field property) {
        try {
            return property.get(this);
        } catch (IllegalAccessException e) {
            // Deliberately best-effort: report which field failed, then fall back to null
            log.error("Unable to read field {} for function {}", property.getName(), getClass().getName(), e);
            return null;
        }
    }

    /**
     * Set the value of a field for this function.
     * <p>
     * Non-null values are first coerced to the field's declared type (see
     * {@code ensureProperType}). If {@link #isConfigProperties()} is true, the value is
     * written to the configuration object held by the field named
     * {@link #configFieldName()} (default {@code "config"}), creating that object via its
     * no-arg constructor when it is null; otherwise the value is written to this instance.
     *
     * @param target the target field
     * @param value the value to set; may be null unless the target field is primitive
     * @throws ND4JIllegalStateException if value is null and the target field is primitive
     * @throws IllegalStateException if the configuration field cannot be found
     * @throws RuntimeException if the field cannot be set or the configuration object
     *                          cannot be instantiated
     */
    public void setValueFor(Field target, Object value) {
        if (value == null && target.getType().isPrimitive()) {
            // Report the declared field type (target.getClass() would always be java.lang.reflect.Field)
            throw new ND4JIllegalStateException("Unable to set primitive field " + target + " of type " + target.getType()
                    + " using null value!");
        }

        if (value != null) {
            value = ensureProperType(target, value);
        }

        //Edge case: we store float fields as doubles, rather than introduce an extra property.
        //Applied for both branches below: Field.set cannot narrow a Double into a float field.
        if (target.getType() == float.class && value instanceof Double) {
            value = ((Double) value).floatValue();
        }

        if (isConfigProperties()) {
            String propertyName = configFieldName();
            if (propertyName == null) {
                propertyName = "config";
            }

            Field configField = findDeclaredField(propertyName);
            if (configField == null) {
                throw new IllegalStateException("Could not find field \"" + propertyName + "\" for class " + getClass().getName());
            }

            try {
                configField.setAccessible(true);
                Object config = configField.get(this);
                if (config == null) {
                    //Null config class - try to create one...
                    Class<?> configClass = configField.getType();
                    try {
                        config = configClass.newInstance();
                    } catch (InstantiationException e) {
                        throw new RuntimeException("Error creating new instance of configuration object type " + configClass.getName(), e);
                    }
                    configField.set(this, config);
                }
                target.set(config, value);
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Error setting configuration field \"" + target.getName() + "\" for config field \"" + propertyName
                        + "\" on class " + getClass().getName(), e);
            }
        } else {
            try {
                target.set(this, value);
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Error setting property for function " + getClass().getName(), e);
            }
        }
    }

    /**
     * Finds a declared field by name on this object's class or the nearest superclass declaring it.
     *
     * @param fieldName name of the field to look for
     * @return the field, or null if no class in the hierarchy declares it
     */
    private Field findDeclaredField(String fieldName) {
        for (Class<?> c = getClass(); c != null; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                //Not declared at this level - try the superclass
            }
        }
        return null;
    }


    /**
     * Coerces a value towards the declared type of the target field.
     * <ul>
     *     <li>If the value already has exactly the field's type, it is returned unchanged.</li>
     *     <li>For enum fields, a String value is matched case-insensitively against the
     *     {@code toString()} of the enum constants; other value types are returned unchanged.</li>
     *     <li>For one-dimensional int/long/double/float array fields (primitive or boxed),
     *     the value is converted via {@link Number} where possible and wrapped in a
     *     single-element array.</li>
     *     <li>Anything else is returned unchanged.</li>
     * </ul>
     *
     * @param targetType the field whose declared type drives the coercion
     * @param value the (non-null) value to coerce
     * @return the coerced value
     * @throws IllegalStateException if a String does not match any constant of an enum field
     */
    private Object ensureProperType(Field targetType, Object value) {
        Class<?> fieldClass = targetType.getType();
        Class<?> actualClass = value.getClass();

        // Exact type match: nothing to do
        if (fieldClass.equals(actualClass)) {
            return value;
        }

        if (fieldClass.isEnum()) {
            if (actualClass.equals(String.class)) {
                for (Object constant : fieldClass.getEnumConstants()) {
                    if (constant.toString().equalsIgnoreCase((String) value)) {
                        return constant;
                    }
                }
                throw new IllegalStateException("Could not find enum constant value for value \"" + value
                        + "\" for enum class " + fieldClass.getName());
            }
            // Non-String value for an enum field is passed through untouched
            return value;
        }

        // Non-null only when the value can be converted numerically
        Number number = value instanceof Number ? (Number) value : null;

        if (fieldClass.equals(int[].class)) {
            Object boxed = number != null ? (Object) number.intValue() : value;
            return new int[] {(int) boxed};
        }
        if (fieldClass.equals(Integer[].class)) {
            Object boxed = number != null ? (Object) number.intValue() : value;
            return new Integer[] {(Integer) boxed};
        }
        if (fieldClass.equals(long[].class)) {
            Object boxed = number != null ? (Object) number.longValue() : value;
            return new long[] {(long) boxed};
        }
        if (fieldClass.equals(Long[].class)) {
            Object boxed = number != null ? (Object) number.longValue() : value;
            return new Long[] {(Long) boxed};
        }
        if (fieldClass.equals(double[].class)) {
            Object boxed = number != null ? (Object) number.doubleValue() : value;
            return new double[] {(double) boxed};
        }
        if (fieldClass.equals(Double[].class)) {
            Object boxed = number != null ? (Object) number.doubleValue() : value;
            return new Double[] {(Double) boxed};
        }
        if (fieldClass.equals(float[].class)) {
            Object boxed = number != null ? (Object) number.floatValue() : value;
            return new float[] {(float) boxed};
        }
        if (fieldClass.equals(Float[].class)) {
            Object boxed = number != null ? (Object) number.floatValue() : value;
            return new Float[] {(Float) boxed};
        }

        return value;
    }


    /**
     * Returns true if the fields for this class should be looked up from a configuration class.
     * When true, {@link #setValueFor(Field, Object)} writes values into the object held by the
     * field named {@link #configFieldName()} instead of into this instance.
     * The base implementation returns false.
     *
     * @return whether property fields live on a configuration object
     */
    public boolean isConfigProperties() {
        return false;
    }

    /**
     * Returns the name of the field to be used for looking up field names.
     * This should be used in conjunction with {@link #isConfigProperties()}
     *  to facilitate mapping fields for model import.
     * The base implementation returns null, in which case
     * {@link #setValueFor(Field, Object)} falls back to the name {@code "config"}.
     *
     * @return the name of the configuration field, or null to use the default
     */
    public String configFieldName() {
        return null;
    }


    /**
     * Creates a function attached to the given SameDiff instance, with the given
     * in-place flag and extra arguments. The instance id is assigned via
     * {@link #setInstanceId()} before the extra arguments are stored.
     *
     * @param sameDiff  the SameDiff instance this function is attached to
     * @param inPlace   value for the {@code inPlace} field
     * @param extraArgs value for the {@code extraArgs} field
     */
    public DifferentialFunction(SameDiff sameDiff,boolean inPlace, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        setInstanceId();
        this.extraArgs = extraArgs;
    }


    /**
     * Creates a function attached to the given SameDiff instance with the given
     * extra arguments; the {@code inPlace} field keeps its default value (false).
     * The instance id is assigned via {@link #setInstanceId()}.
     *
     * @param sameDiff  the SameDiff instance this function is attached to
     * @param extraArgs value for the {@code extraArgs} field
     */
    public DifferentialFunction(SameDiff sameDiff, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        setInstanceId();
        this.extraArgs = extraArgs;
    }

    /**
     * Creates a function attached to the given SameDiff instance with the given
     * input variables and {@code inPlace} set to false. Delegates to
     * {@link #DifferentialFunction(SameDiff, boolean, SDVariable[])}.
     *
     * @param sameDiff the SameDiff instance this function is attached to
     * @param args     the input variables for this function
     */
    public DifferentialFunction(SameDiff sameDiff, SDVariable[] args) {
        this(sameDiff,false, args);
    }


    /**
     * Creates a function and registers its arguments with SameDiff.
     * The instance id is assigned via {@link #setInstanceId()} first; the arguments
     * are then registered with {@code sameDiff.addArgsFor(args, this)}.
     *
     * @param sameDiff the SameDiff instance this function is attached to; may be null
     * @param inPlace  value for the {@code inPlace} field
     * @param args     the input variables for this function; may be null
     */
    public DifferentialFunction(SameDiff sameDiff, boolean inPlace, SDVariable[] args) {
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        setInstanceId();
        // Arguments can only be registered when both a SameDiff instance and args are present
        if(sameDiff != null && args != null) {
            sameDiff.addArgsFor(args, this);
        }
    }

    /**
     * Replaces the i-th argument of this function with the given variable by
     * delegating to {@code sameDiff.replaceArgFor}. Does nothing when this
     * function has no SameDiff instance.
     *
     * @param i      index of the argument to replace
     * @param newArg the replacement variable
     */
    public void replaceArg(int i, SDVariable newArg) {
        if (sameDiff == null) {
            return;
        }
        sameDiff.replaceArgFor(i, newArg, this);
    }


    /**
     * Return the output variables for this differential function.
     * Note that this op *may* dynamically generate variable outputs.
     * The own name of this function is used as the base name when it is set;
     * otherwise {@link #opName()} is used.
     *
     * @return the output variables
     */
    public SDVariable[] outputVariables() {
        if (getOwnName() != null) {
            return outputVariables(getOwnName());
        }
        return outputVariables(opName());
    }

    /**
     * Returns element 0 of {@link #outputVariables()}. No null or empty check is
     * performed on that array.
     *
     * @return The output variable, or the first output variable, if multiple outputs exist
     */
    public SDVariable outputVariable(){
        return outputVariables()[0];
    }

    /**
     * Returns the output variables of this function as a list.
     *
     * @return a fixed-size list view over the array returned by {@link #outputVariables()},
     *         or null if that array is null
     */
    public List<SDVariable> outputs() {
        SDVariable[] out = outputVariables();
        return out == null ? null : Arrays.asList(out);
    }


    public String[] outputVariablesNames(){
        SDVariable[] outputVars = outputVariables();
        String[] out = new String[outputVars.length];
        for( int i=0; i doDiff(List f1);


    /**
     * Return the arguments for a given function, as registered with SameDiff.
     *
     * @return the input variables SameDiff holds for this op, or null when this
     *         function has no SameDiff instance
     */
    public SDVariable[] args() {
        if (sameDiff == null) {
            return null;
        }
        return sameDiff.getInputVariablesForOp(this);
    }

    /**
     * Return the specified argument for this function.
     * Fails the precondition checks if {@link #args()} returns null or if the
     * index is out of range.
     *
     * @param num Number of the argument. Must be in range 0 to numArgs - 1 inclusive
     * @return Specified argument
     */
    public SDVariable arg(int num){
        SDVariable[] args = args();
        Preconditions.checkNotNull(args, "Arguments are null for function %s", this.getOwnName());
        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s), got %s", args.length, num);
        return args[num];
    }

    public String[] argNames(){
        SDVariable[] args = args();
        String[] out = new String[args.length];
        for( int i=0; i diff(List i_v1) {
        List vals = doDiff(i_v1);
        if(vals == null){
            throw new IllegalStateException("Error executing diff operation: doDiff returned null for op: " + this.opName());
        }

        val outputVars = args();
        boolean copied = false;
        for(int i = 0; i < vals.size(); i++) {
            SDVariable var = outputVars[i];
            SDVariable grad = var.hasGradient() ? var.getGradient() : null;
            if(grad != null) {
                if(!copied){
                    //Don't mutate the original - this could mess with the original op's state!
                    vals = new ArrayList<>(vals);
                    copied = true;
                }

                SDVariable gradVar =  var.getSameDiff().math.add(grad, vals.get(i));
                vals.set(i, gradVar);
                sameDiff.setGradientForVariableName(var.name(), gradVar);
            } else {
                SDVariable gradVar = vals.get(i);

                sameDiff.updateVariableNameAndReference(gradVar,var.name() + "-grad");
                sameDiff.setGradientForVariableName(var.name(), gradVar);

            }
        }

        return vals;
    }


    /**
     * Assigns the own name of this function if it has not been set yet.
     * Without a SameDiff instance a random UUID is used. Otherwise the name is
     * obtained from {@code sameDiff.getOpName(opName())} and this function is
     * registered under that name via {@code sameDiff.putOpForId}.
     */
    protected void setInstanceId() {
        if (ownName != null) {
            // Already has an id: nothing to do
            return;
        }

        if (sameDiff == null) {
            this.ownName = UUID.randomUUID().toString();
            return;
        }

        this.ownName = sameDiff.getOpName(opName());
        sameDiff.putOpForId(ownName, this);
    }


    /**
     * The name of the op.
     * The base implementation always throws; subclasses need to override it, since
     * {@link #setInstanceId()} calls it whenever a SameDiff instance is present.
     *
     * @return the op name
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public String opName() {
        throw new UnsupportedOperationException();
    }


    /**
     * The type of the op.
     * The base implementation always throws.
     *
     * @return the op type
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public Op.Type opType() {
        throw new UnsupportedOperationException();
    }


    /**
     * The number of the op (mainly for old legacy XYZ ops
     * like {@link Op}).
     * The base implementation always throws.
     *
     * @return the op number
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public int opNum() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the input array at the given index.
     * The base implementation always throws; subclasses should implement this.
     *
     * @param index index of the input argument
     * @return the input array at that index
     * @throws UnsupportedOperationException always, in this base implementation
     */
    @JsonIgnore
    public INDArray getInputArgument(int index){
        //Subclasses should implement this
        throw new UnsupportedOperationException("Not implemented");
    }



    /**
     * Initialize the function from the given TensorFlow
     * {@link NodeDef}. Invoked from the TensorFlow import constructor of this class.
     *
     * @param nodeDef           the node definition to initialize from
     * @param initWith          the SameDiff instance to initialize with
     * @param attributesForNode the attributes for the node
     * @param graph             the graph definition
     */
    public abstract void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph);

    /**
     * Initialize the function from the given ONNX
     * {@link onnx.Onnx.NodeProto}. Invoked from the ONNX import constructor of this class.
     *
     * @param node              the node to initialize from
     * @param initWith          the SameDiff instance to initialize with
     * @param attributesForNode the attributes for the node
     * @param graph             the graph
     */
    public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph);



    /**
     * The left (first) argument for this function.
     *
     * @return the first argument
     * @throws ND4JIllegalStateException if this function has no arguments
     */
    public SDVariable larg() {
        SDVariable[] args = args();
        if (args == null || args.length == 0) {
            throw new ND4JIllegalStateException("No arguments found.");
        }
        // Use the array that was just validated rather than looking the arguments up again
        return args[0];
    }

    /**
     * The right (second) argument for this function.
     * This requires the function to have exactly 2 arguments.
     *
     * @return the second argument
     * @throws ND4JIllegalStateException if the number of arguments is not exactly 2
     */
    public SDVariable rarg() {
        SDVariable[] arguments = args();
        boolean hasExactlyTwo = arguments != null && arguments.length == 2;
        if (!hasExactlyTwo) {
            throw new ND4JIllegalStateException("In order to use this function, the number of arguments for this function must be 2.");
        }
        return arguments[1];
    }


    /**
     * Duplicate this function by delegating to
     * {@link FlatBuffersMapper#cloneViaSerialize}, passing this function and its
     * SameDiff instance.
     *
     * @return the duplicated function
     */
    public  DifferentialFunction dup() {
        return FlatBuffersMapper.cloneViaSerialize(sameDiff, this);
    }





    /**
     * Calculate the output shape for this op.
     * The base implementation always throws; subclasses that support shape
     * calculation need to override it.
     *
     * @return List of output shape descriptors
     * @throws ND4JIllegalStateException always, in this base implementation
     */
    public List calculateOutputShape() {
        throw new ND4JIllegalStateException("calculateOutputShape() method leaked out for [" + this.opName() + "]");
    }

    /**
     * Calculate the output shape for this op using the given op context.
     * The base implementation always throws; subclasses that support shape
     * calculation need to override it.
     *
     * @param oc the op context
     * @return List of output shape descriptors
     * @throws ND4JIllegalStateException always, in this base implementation
     */
    public List calculateOutputShape(OpContext oc){
        throw new ND4JIllegalStateException("calculateOutputShape(OpContext) method leaked out for [" + this.opName() + "]");
    }

    /**
     * Calculate the data types for the output arrays.
     * Though datatypes can also be inferred from {@link #calculateOutputShape()}, this method differs in that it does not
     * require the input arrays to be populated.
     * This is important as it allows us to do greedy datatype inference for the entire net - even if arrays are not
     * available.
     * The base implementation always throws.
     *
     * @param dataTypes The data types of the inputs
     * @return The data types of the outputs
     * @throws UnsupportedOperationException always, in this base implementation
     */
    public List calculateOutputDataTypes(List dataTypes){
        throw new UnsupportedOperationException("calculateOutputDataTypes() has not been implemented for " + getClass().getName());
    }


    /**
     * Two functions are equal when they are of exactly the same class and agree on
     * {@code inPlace}, {@code scalarValue}, {@code dimensions} (element-wise) and
     * {@code ownName}. The SameDiff instance and the extra arguments are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }

        DifferentialFunction other = (DifferentialFunction) o;
        return inPlace == other.inPlace
                && Objects.equals(scalarValue, other.scalarValue)
                && Arrays.equals(dimensions, other.dimensions)
                && Objects.equals(ownName, other.ownName);
    }

    /**
     * Hash code over the same fields used by {@link #equals(Object)}:
     * {@code inPlace}, {@code scalarValue}, {@code dimensions} and {@code ownName}.
     */
    @Override
    public int hashCode() {
        int hash = 31;
        hash = 31 * hash + (inPlace ? 1 : 0);
        hash = 31 * hash + Objects.hashCode(scalarValue);
        hash = 31 * hash + Arrays.hashCode(dimensions);
        hash = 31 * hash + Objects.hashCode(ownName);
        return hash;
    }

    /**
     * The op names of this function in onnx.
     * The base implementation returns a single-element array containing {@link #onnxName()}.
     *
     * @return the onnx names of this function
     */
    public  String[] onnxNames() {
        return new String[] {onnxName()};
    }

    /**
     * The op names of this function in tensorflow.
     * The base implementation returns a single-element array containing {@link #tensorflowName()}.
     *
     * @return the tensorflow names of this function
     */
    public  String[] tensorflowNames() {
        return new String[] {tensorflowName()};
    }

    /**
     * The opName of this function in onnx.
     * Used by the default implementation of {@link #onnxNames()}.
     *
     * @return the onnx name of this function
     */
    public abstract String onnxName();

    /**
     * The opName of this function in tensorflow.
     * Used by the default implementation of {@link #tensorflowNames()}.
     *
     * @return the tensorflow name of this function
     */
    public abstract String tensorflowName();

    /**
     * The number of outputs of this function. The base implementation returns -1.
     * NOTE(review): -1 presumably means "not known statically" - confirm against subclasses.
     *
     * @return the number of outputs, or -1 in this base implementation
     */
    public int getNumOutputs(){return -1;}

    /**
     * Clear the input and output INDArrays, if any are set.
     * Implemented by subclasses.
     */
    public abstract void clearArrays();
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy