org.nd4j.autodiff.samediff.optimize.GraphOptimizer (Maven / Gradle / Ivy)
/*
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.autodiff.samediff.optimize;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger;
import org.nd4j.autodiff.samediff.optimize.optimizations.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
 * Applies optimization passes to a SameDiff graph, producing an optimized copy.
 * The input graph is duplicated first, so it is not modified.
 *
 * @author Alex Black
 */
@Slf4j
public class GraphOptimizer {

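    /**
     * Returns the default optimizer sets: unused function removal, constant function folding,
     * identity function removal, shape function optimizations, and CuDNN-specific optimizations.
     */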
    public static List<OptimizerSet> defaultOptimizations() {
        return Arrays.asList(
                new UnusedFunctionOptimizations(),
                new ConstantFunctionOptimizations(),
                new IdentityFunctionOptimizations(),
                new ShapeFunctionOptimizations(),
                new CuDNNFunctionOptimizations()
        );
    }

    public static SameDiff optimize(SameDiff graph, String... requiredOutputs){
        return optimize(graph, Arrays.asList(requiredOutputs));
    }

    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs) {
        return optimize(graph, requiredOutputs, defaultOptimizations());
    }

    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations) {
        return optimize(graph, requiredOutputs, optimizations, null);
    }

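    /**
     * Optimizes a copy of the given graph and returns it; the input graph is not modified.
     *
     * @param graph           The graph to optimize
     * @param requiredOutputs Output variables that must be preserved by the optimized graph
     * @param optimizations   The optimizer sets to apply
     * @param debugger        Optional debugger, notified before and after each optimizer check; may be null
     * @return The optimized copy of the graph
     */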
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations, OptimizationDebugger debugger) {
        //TODO Use required outputs - strip unnecessary graph components

        SameDiff sd = graph.dup();

        ArrayHolder cArr = sd.getConstantArrays();
        ArrayHolder vArr = sd.getVariablesArrays();

        OptimizationHelper h = new OptimizationHelper(graph, new OptimizationConfig());    //TODO defaults for config

        for (int i = 0; i < 3; i++) {   //Run multiple passes: a single pass isn't enough, as some optimizations only become applicable after earlier ones have rewritten the graph
            for (OptimizerSet s : optimizations) {
                List<Optimizer> l = s.getOptimizers();
                for(Optimizer o : l ){
                    Collection<SameDiffOp> startingOps = new ArrayList<>(sd.getOps().values()); //Copy to a new list to avoid a ConcurrentModificationException when optimizers add or remove ops
                    for(SameDiffOp op : startingOps) {
                        //Ops may have been removed from the graph by an earlier optimization
                        // step, so check that this op still exists before processing it
                        if(!sd.getOps().containsKey(op.getName()))
                            continue;

                        if(debugger != null)
                            debugger.beforeOptimizationCheck(sd, op, o);

                        boolean applied = o.checkAndApply(sd, h, op, cArr, vArr);
                        if(applied) {
                            log.info("Operation was applied: {}", o);
                        }

                        if(debugger != null)
                            debugger.afterOptimizationsCheck(sd, op, o, applied);
                    }
                }
            }
        }

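        //Count variables by type before and after optimization, for the summary logging below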
        int constBefore = 0;
        int constAfter = 0;
        int varBefore = 0;
        int varAfter = 0;
        int arrBefore = 0;
        int arrAfter = 0;

        for(SDVariable v : graph.variables()){
            switch(v.getVariableType()){
                case VARIABLE:
                    varBefore++;
                    break;
                case CONSTANT:
                    constBefore++;
                    break;
                case ARRAY:
                    arrBefore++;
                    break;
                case PLACEHOLDER:
                    break;
            }
        }

        for(SDVariable v : sd.variables()){
            switch(v.getVariableType()){
                case VARIABLE:
                    varAfter++;
                    break;
                case CONSTANT:
                    constAfter++;
                    break;
                case ARRAY:
                    arrAfter++;
                    break;
                case PLACEHOLDER:
                    break;
            }
        }


        log.info("Total variables: {} before, {} after", graph.getVariables().size(), sd.getVariables().size());
        log.info("Constant variables: {} before, {} after", constBefore, constAfter);
        log.info("Array type variables: {} before, {} after", arrBefore, arrAfter);
        log.info("Variable type variables: {} before, {} after", varBefore, varAfter);
        log.info("Ops: {} before, {} after", graph.getOps().size(), sd.getOps().size());

        return sd;
    }

}
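For context, here is a minimal sketch of how GraphOptimizer might be invoked. The graph-construction calls below (SameDiff.create, placeHolder, var, mmul, add, nn().softmax, summary) follow the usual SameDiff builder API, but exact signatures vary between ND4J versions, so treat this as an illustrative sketch rather than canonical usage.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.optimize.GraphOptimizer;
import org.nd4j.linalg.api.buffer.DataType;

public class GraphOptimizerExample {
    public static void main(String[] args) {
        //Build a small graph: out = softmax(in * W + b)
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", DataType.FLOAT, 4, 3);
        SDVariable b = sd.var("b", DataType.FLOAT, 3);
        SDVariable out = sd.nn().softmax("out", in.mmul(w).add(b));

        //Optimize, keeping only what is needed to compute "out"
        SameDiff optimized = GraphOptimizer.optimize(sd, "out");
        System.out.println(optimized.summary());
    }
}

Custom passes can be plugged in via the three-argument optimize overload. The sketch below assumes the Optimizer and OptimizerSet interfaces expose exactly the members GraphOptimizer calls above (OptimizerSet.getOptimizers() and Optimizer.checkAndApply(SameDiff, OptimizationHelper, SameDiffOp, ArrayHolder, ArrayHolder)); parameter names and any additional interface methods are assumptions, and LoggingOptimizerSet itself is hypothetical.

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.OptimizationHelper;
import org.nd4j.autodiff.samediff.optimize.Optimizer;
import org.nd4j.autodiff.samediff.optimize.OptimizerSet;

//Hypothetical pass that only logs each op it visits and never rewrites the graph
public class LoggingOptimizerSet implements OptimizerSet {
    @Override
    public List<Optimizer> getOptimizers() {
        return Collections.singletonList(new Optimizer() {
            @Override
            public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op,
                                         ArrayHolder constantArrays, ArrayHolder variablesArrays) {
                System.out.println("Visited op: " + op.getName());
                return false;   //false: no change was made to the graph
            }
        });
    }
}

Returning false from checkAndApply signals that no change was made; as the loop above shows, GraphOptimizer only logs when an optimizer returns true.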



