All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.api.ops.custom.Invoke Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.linalg.api.ops.custom;

import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.autodiff.samediff.config.SDValueType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

/**
 * Invoke is an op
 */
public class Invoke extends DynamicCustomOp {

    @Getter
    private String functionName;
    @Getter
    private String[] inputVarNames;
    @Getter
    private String[] outputVarNames;
    @Getter
    private String[] subGraphInputVarNames;
    @Getter
    private String[] subGraphOutputVarNames;

    public Invoke() {
    }

    @Data
    @Builder
    public static class InvokeParams {
        private String functionName;
        private SDVariable[] inputs;
        private String[] inputVarNames;
        private String[] outputVarNames;
        private String[] subGraphInputVarNames;
        private String[] subGraphOutputVarNames;
    }


    public Invoke(SameDiff sameDiff,InvokeParams invokeParams) {
        super(sameDiff,invokeParams.inputs);
        this.sameDiff = sameDiff;
        this.outputVarNames = invokeParams.outputVarNames;
        this.functionName = invokeParams.functionName;
        this.inputVarNames = invokeParams.inputVarNames;
        this.subGraphInputVarNames = invokeParams.subGraphInputVarNames;
        this.subGraphOutputVarNames = invokeParams.subGraphOutputVarNames;
    }

    /**
     * Perform the invoke method.
     * @param op the {@link Invoke} instance to use
     * @param placeHolders the singular placeholders to pass in to the function
     * @param valuePlaceHolders the value placeholders to pass in to the function
     * @return the {@link ExecutionResult} from the sub function
     */
    public static ExecutionResult doInvoke(DifferentialFunction op, Map placeHolders, Map valuePlaceHolders) {
        Invoke invoke = (Invoke) op;
        String funcName = invoke.getFunctionName();
        SameDiff instance = op.getSameDiff().getFunction(funcName);
        //invoke can have state bugs and should not free arrays on its own
        instance.setEnableCache(false);
        SDVariable[] args = op.args();
        String[] inputVarNameMappings = invoke.getInputVarNames();

        String[] subGraphInputNames = invoke.subGraphInputVarNames;
        if(subGraphInputNames == null)
            subGraphInputNames = inputVarNameMappings;

        SDVariable[] outputs = op.outputVariables();

        if(inputVarNameMappings == null) {
            inputVarNameMappings = new String[args.length];
            //default to input names of op unless specified
            for(int i = 0; i < inputVarNameMappings.length; i++) {
                inputVarNameMappings[i] = args[i].name();
            }
        }

        String[] outputVarNameMappings = invoke.getOutputVarNames();
        if(outputVarNameMappings == null) {
            outputVarNameMappings = new String[outputs.length];
            for(int i = 0; i < outputs.length; i++) {
                outputVarNameMappings[i] = outputs[i].name();
            }
        }


        String[] subGraphOutputNames = invoke.subGraphOutputVarNames;
        if(subGraphOutputNames == null)
            subGraphOutputNames = outputVarNameMappings;



        List relevantOutputNames = Arrays.asList(subGraphOutputNames);
        if(valuePlaceHolders.isEmpty()) {
            INDArray[] retOutput = new INDArray[subGraphOutputNames.length];
            Map inputMap = new LinkedHashMap<>();
            for(int i = 0; i < inputVarNameMappings.length; i++) {
                //note that we use the inputs in numerical order ignoring the names
                //this is because the input names aren't aligned with what's passed in
                inputMap.put(subGraphInputNames[i],placeHolders.get(op.argNames()[i]));
            }

            Map output = instance.output(inputMap, relevantOutputNames);
            //note not all keys maybe the same as what we expect so we only add the keys we care about
            int numAdded = 0;
            for(Map.Entry result : output.entrySet()) {
                if(relevantOutputNames.contains(result.getKey())) {
                    retOutput[numAdded] = output.get(result.getKey());
                    numAdded++;
                }
            }

            return ExecutionResult.builder()
                    .outputs(ExecutionResult.pack(output))
                    .build();
        } else {
            Map valueInputs = new LinkedHashMap<>();
            for(int i = 0; i < inputVarNameMappings.length; i++) {
                //note that we use the inputs in numerical order ignoring the names
                //this is because the input names aren't aligned with what's passed in
                valueInputs.put(subGraphInputNames[i],valuePlaceHolders.get(op.argNames()[i]));
            }

            Map valueOutputs = instance.outputValues(valueInputs,relevantOutputNames);
            //rearrange to be in right order for return, this is critical
            Map result = new LinkedHashMap<>();
            for(int i = 0; i < outputVarNameMappings.length; i++) {
                result.put(outputs[i].name(), valueOutputs.get(subGraphOutputNames[i]));
            }

            return ExecutionResult.builder()
                    .valueOutputs(result)
                    .build();

        }

    }

    @Override
    public SDVariable[] outputVariables() {
        if(outputVariables == null) {
            SameDiff func = sameDiff.getFunction(this.functionName);
            if (func == null) {
                throw new IllegalArgumentException("Unable to determine output data types for variables. No function of " + this.functionName + " found!");
            }

            if (subGraphOutputVarNames == null) {
                throw new IllegalStateException("Invalid InvokeConfiguration found. Please specify sub graph output names.");
            }

            SDVariable[] outputs = new SDVariable[subGraphOutputVarNames.length];
            for (int i = 0; i < subGraphOutputVarNames.length; i++) {
                String subGraphVarName = subGraphOutputVarNames[i];
                SDVariable variable = func.getVariable(subGraphVarName);
                if(variable == null) {
                    throw new IllegalStateException("No variable found in sub graph named " + subGraphVarName);
                }
                switch(variable.getVariableType()) {
                    case VARIABLE:
                    case ARRAY:
                    case PLACEHOLDER:
                    case SEQUENCE:
                        if(variable.getShape() != null) {
                            SDVariable clone2 = sameDiff.var(subGraphVarName + "_" + functionName, variable.dataType(), variable.getShape());
                            clone2.setVariableType(VariableType.ARRAY);
                            outputs[i] = clone2;
                        } else { //placeholder shape
                            SDVariable clone2 = sameDiff.var(subGraphVarName + "_" + functionName, variable.dataType());
                            clone2.setVariableType(VariableType.ARRAY);
                            outputs[i] = clone2;
                        }
                        break;
                    case CONSTANT:
                        SDVariable clone2 = sameDiff.var(subGraphVarName + "_" + functionName, variable.dataType());
                        clone2.setVariableType(VariableType.ARRAY);
                        outputs[i] = clone2;
                        break;

                }

            }

            this.outputVariables = outputs;

            if (outputVarNames != null && outputVarNames.length == outputs.length)
                for (int i = 0; i < outputs.length; i++) {
                    if (!outputs[i].name().equals(outputVarNames[i])) {
                        sameDiff.updateVariableNameAndReference(outputs[i], outputVarNames[i], true);
                    }
                }
            else if (this.outputVariables == null) {
                throw new IllegalArgumentException("Invalid configuration for output variable names. Must be equal to the number of outputs.");
            }

            //add outgoing ops after generating output variables
            addOutputsToOp();

            return outputs;
        }
        return outputVariables;
    }

    @Override
    public int getNumOutputs() {
        if(subGraphOutputVarNames != null)
            return subGraphOutputVarNames.length;
        else if(outputVarNames != null)
            return outputVarNames.length;
        return 1;
    }

    @Override
    public String opName() {
        return "invoke";
    }

    @Override
    public void configureFromArguments() {
        super.configureFromArguments();
    }

    @Override
    public void configureWithSameDiff(SameDiff sameDiff) {
        super.configureWithSameDiff(sameDiff);
    }

    @Override
    public List calculateOutputDataTypes(List dataTypes) {
        List ret = new ArrayList<>();
        for(int i = 0; i < getNumOutputs(); i++)
            ret.add(DataType.FLOAT);
        return ret;
    }

    @Override
    public List calculateOutputShape() {
        return Collections.emptyList();
    }


    @Override
    public List calculateOutputShape(OpContext oc) {
        /**
         * TODO: Figure out how to invoke calculate output shape
         * for a graph. This may involve adding a new function
         * to a samediff graph that just calls compute shape for everything.
         */
        List ret = new ArrayList<>();
        for(int i = 0; i < getNumOutputs(); i++) {
            ret.add(LongShapeDescriptor.fromShape(new int[]{1},DataType.DOUBLE));
        }


        return ret;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy