All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.samediff.transform.SubGraph Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff.transform;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.Variable;

import java.util.*;

/**
 * A subset of the ops in a {@link SameDiff} graph, made up of a single root op plus zero or more child ops.
 * <p>
 * Provides helpers to determine which variables flow into the subgraph ({@link #inputs()}) and which
 * variables it produces that are used elsewhere in the graph ({@link #outputs()}).
 */
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Data
public class SubGraph {

    /** The graph that owns the ops and variables of this subgraph. */
    protected SameDiff sameDiff;
    /** The root op of the subgraph. {@link #outputs()} and {@link #inputs()} dereference it, so it must be non-null. */
    protected DifferentialFunction rootNode;
    /** Additional ops belonging to the subgraph; may be null or empty. */
    protected List<DifferentialFunction> childNodes;


    /**
     * Returns the variables produced by ops in this subgraph that are consumed by at least one op outside it.
     * <p>
     * Example: for the subgraph {@code x -> y -> z}, {@code y} is not an output when it is consumed only by ops
     * inside the subgraph. If an additional connection {@code y -> a} to an op outside the subgraph exists,
     * {@code y} must be an output, because it is used somewhere else.
     * <p>
     * NOTE(review): a variable with no consuming ops at all is treated the same as one consumed only inside the
     * subgraph, and is therefore NOT returned. Confirm this is intended for graph-terminal variables.
     *
     * @return output variables, with the root op's outputs first, followed by each child op's outputs in
     *         {@link #childNodes} order
     * @throws IllegalStateException if a produced variable is not registered in {@link #sameDiff}
     */
    public List<SDVariable> outputs() {
        List<SDVariable> filteredOutputs = new ArrayList<>();
        for (DifferentialFunction df : allFunctionsInSubgraph()) {
            SDVariable[] produced = df.outputVariables();
            if (produced == null) {
                continue;
            }
            for (SDVariable v : produced) {
                if (isConsumedOutsideSubgraph(v)) {
                    filteredOutputs.add(v);
                }
            }
        }
        return filteredOutputs;
    }

    /**
     * Returns true if at least one op that takes {@code v} as an input is not part of this subgraph.
     * Returns false when {@code v} has no recorded consumers.
     */
    private boolean isConsumedOutsideSubgraph(SDVariable v) {
        Variable var = sameDiff.getVariables().get(v.name());
        if (var == null) {
            throw new IllegalStateException("Variable \"" + v.name()
                    + "\" produced by a subgraph op is not present in the SameDiff instance");
        }
        List<String> consumers = var.getInputsForOp();
        if (consumers == null) {
            return false;
        }
        for (String opName : consumers) {
            // The op lookup may fail to match any subgraph op (e.g. null); that counts as "outside"
            if (!inSubgraph(sameDiff.getOpById(opName))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the variables used as arguments by any op in this subgraph (root or child) that are not
     * produced by an op in this subgraph.
     * <p>
     * The result is not de-duplicated. A variable used several times (by one op or by several ops) appears
     * once per use, in op order and then argument order. Callers may index into this list positionally.
     *
     * @return the subgraph's input variables, one entry per use
     */
    public List<SDVariable> inputs() {
        List<DifferentialFunction> functions = allFunctionsInSubgraph();

        Set<SDVariable> producedInside = new HashSet<>();
        for (DifferentialFunction df : functions) {
            SDVariable[] outputVars = df.outputVariables();
            if (outputVars != null) {
                Collections.addAll(producedInside, outputVars);
            }
        }

        List<SDVariable> inputs = new ArrayList<>();
        for (DifferentialFunction df : functions) {
            SDVariable[] args = df.args();
            if (args == null) {
                continue;
            }
            for (SDVariable arg : args) {
                if (!producedInside.contains(arg)) {
                    inputs.add(arg);
                }
            }
        }
        return inputs;
    }

    /**
     * Returns true if {@code df} is the root op or one of the child ops of this subgraph.
     * Ops are compared by reference identity ({@code ==}), not {@code equals}.
     *
     * @param df the op to test; may be null
     * @return whether {@code df} belongs to this subgraph
     */
    public boolean inSubgraph(DifferentialFunction df) {
        if (rootNode == df) {
            return true;
        }
        if (childNodes != null) {
            for (DifferentialFunction d : childNodes) {
                if (d == df) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns a new mutable list containing the root op followed by all child ops (if any).
     *
     * @return all ops in this subgraph, root first
     */
    public List<DifferentialFunction> allFunctionsInSubgraph() {
        List<DifferentialFunction> out = new ArrayList<>();
        out.add(rootNode);
        if (childNodes != null) {
            out.addAll(childNodes);
        }
        return out;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy