All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.arbiter.ComputationGraphSpace Maven / Gradle / Ivy

/*
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.deeplearning4j.arbiter;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** ComputationGraphSpace: Defines the space of valid hyperparameters for a ComputationGraph.
 * Note that this is for essentially fixed graph structures only
 */
public class ComputationGraphSpace extends BaseNetworkSpace {
    // NOTE(review): BaseNetworkSpace and BaseNetworkSpace.Builder are extended without type arguments.
    // If they are declared generic, their type arguments were probably lost the same way the collection
    // element types in this class were - confirm against BaseNetworkSpace before changing these headers.

    /** Layer definitions (hyperparameter space, layer name, input names) in the order they were added. */
    private final List<LayerConf> layerSpaces;
    /** Fixed (non-optimized) graph vertices (vertex, vertex name, input names) in the order they were added. */
    private final List<VertexConf> vertices;

    /** Names of the network inputs, passed to {@code GraphBuilder.addInputs}. */
    private final String[] networkInputs;
    /** Names of the network outputs, passed to {@code GraphBuilder.setOutputs}. */
    private final String[] networkOutputs;
    /** Optional input types; only applied when non-null and non-empty. */
    private final InputType[] inputTypes;

    /** Sum of {@code numParameters()} over the leaf parameter spaces, computed once in the constructor. */
    private final int numParameters;

    //Early stopping configuration / (fixed) number of epochs:
    private final EarlyStoppingConfiguration earlyStoppingConfiguration;

    /**
     * Creates the space from a populated builder. Use {@link Builder#build()}.
     *
     * @param builder builder holding the layer/vertex definitions and network inputs/outputs
     */
    private ComputationGraphSpace(Builder builder) {
        super(builder);

        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
        // Copy the lists so that further addLayer/addVertex calls on the builder cannot
        // modify a space that has already been built.
        this.layerSpaces = new ArrayList<>(builder.layerList);
        this.vertices = new ArrayList<>(builder.vertexList);

        this.networkInputs = builder.networkInputs;
        this.networkOutputs = builder.networkOutputs;
        this.inputTypes = builder.inputTypes;

        //Determine total number of parameters:
        // getUnique presumably drops duplicate leaves so a shared space is counted once - TODO confirm
        List<ParameterSpace> uniqueLeaves = CollectionUtils.getUnique(collectLeaves());
        int total = 0;
        for (ParameterSpace ps : uniqueLeaves) {
            total += ps.numParameters();
        }
        this.numParameters = total;
    }


    /**
     * Builds a concrete graph configuration for one candidate point in the hyperparameter space.
     * <p>
     * The global configuration comes from {@code randomGlobalConf(values)}. Every layer space is
     * evaluated with the same {@code values} array, while vertices are added unchanged. The
     * backprop, pretrain, backprop type and truncated BPTT length settings are only applied when
     * their parameter spaces are non-null.
     *
     * @param values candidate values, forwarded to each parameter space's {@code getValue}
     * @return the built {@code ComputationGraphConfiguration} wrapped together with the early stopping
     *         configuration and the number of epochs
     */
    @Override
    public GraphConfiguration getValue(double[] values) {
        //Create ComputationGraphConfiguration...
        NeuralNetConfiguration.Builder builder = randomGlobalConf(values);

        ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder();
        graphBuilder.addInputs(this.networkInputs);
        graphBuilder.setOutputs(this.networkOutputs);
        if (inputTypes != null && inputTypes.length > 0) {
            graphBuilder.setInputTypes(inputTypes);
        }

        //Build/add our layers and vertices:
        for (LayerConf c : layerSpaces) {
            org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
            graphBuilder.addLayer(c.getLayerName(), l, c.getInputs());
        }
        for (VertexConf gv : vertices) {
            graphBuilder.addVertex(gv.getVertexName(), gv.getGraphVertex(), gv.getInputs());
        }

        if (backprop != null) {
            graphBuilder.backprop(backprop.getValue(values));
        }
        if (pretrain != null) {
            graphBuilder.pretrain(pretrain.getValue(values));
        }
        if (backpropType != null) {
            graphBuilder.backpropType(backpropType.getValue(values));
        }
        if (tbpttFwdLength != null) {
            graphBuilder.tBPTTForwardLength(tbpttFwdLength.getValue(values));
        }
        if (tbpttBwdLength != null) {
            graphBuilder.tBPTTBackwardLength(tbpttBwdLength.getValue(values));
        }

        ComputationGraphConfiguration configuration = graphBuilder.build();
        return new GraphConfiguration(configuration, earlyStoppingConfiguration, numEpochs);
    }

    /**
     * @return the total number of parameters, as computed when this space was constructed
     */
    @Override
    public int numParameters() {
        return numParameters;
    }

    /**
     * Collects the leaf parameter spaces: those reported by the superclass, followed by the leaves
     * of every layer space and, when set, the leaves of {@code cnnInputSize}.
     *
     * @return the list obtained from {@code super.collectLeaves()} with the additional leaves appended
     */
    @Override
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> list = super.collectLeaves();
        for (LayerConf lc : layerSpaces) {
            list.addAll(lc.layerSpace.collectLeaves());
        }
        if (cnnInputSize != null) {
            list.addAll(cnnInputSize.collectLeaves());
        }
        return list;
    }


    /**
     * Describes the space: the superclass description, one line per layer, one line per vertex,
     * then either the early stopping configuration or the fixed number of epochs.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(super.toString());

        for (LayerConf conf : layerSpaces) {
            sb.append("Layer config: \"").append(conf.layerName).append("\", ").append(conf.layerSpace)
                    .append(", inputs: ").append(inputsToString(conf.inputs))
                    .append("\n");
        }

        for (VertexConf conf : vertices) {
            sb.append("GraphVertex: \"").append(conf.vertexName).append("\", ").append(conf.graphVertex)
                    .append(", inputs: ").append(inputsToString(conf.inputs))
                    .append("\n");
        }

        if (earlyStoppingConfiguration != null) {
            sb.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
        } else {
            sb.append("Training # epochs:").append(numEpochs).append("\n");
        }

        return sb.toString();
    }

    /** Formats input names for {@link #toString()}; a null array is rendered as {@code []} rather than {@code null}. */
    private static String inputsToString(String[] inputs) {
        return inputs == null ? "[]" : Arrays.toString(inputs);
    }

    /** A named layer: its hyperparameter space plus the names of the inputs feeding it. */
    @AllArgsConstructor @Data
    private static class LayerConf {
        private final LayerSpace layerSpace;
        private final String layerName;
        private final String[] inputs;
    }

    /** A named graph vertex (added to the graph unchanged) plus the names of the inputs feeding it. */
    @AllArgsConstructor @Data
    private static class VertexConf {
        private final GraphVertex graphVertex;
        private final String vertexName;
        private final String[] inputs;
    }

    /** Fluent builder for {@link ComputationGraphSpace}. */
    public static class Builder extends BaseNetworkSpace.Builder {

        protected List<LayerConf> layerList = new ArrayList<>();
        protected List<VertexConf> vertexList = new ArrayList<>();
        protected EarlyStoppingConfiguration earlyStoppingConfiguration;
        protected String[] networkInputs;
        protected String[] networkOutputs;
        protected InputType[] inputTypes;

        //Need: input types
        //Early stopping configuration
        //Graph nodes

        /** Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is
         * present, early stopping will be used in preference.
         *
         * @param earlyStoppingConfiguration the early stopping configuration to use
         * @return this builder
         */
        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /**
         * Adds a layer whose configuration is drawn from the given layer space.
         *
         * @param layerName   name of the layer in the graph
         * @param layerSpace  hyperparameter space for the layer
         * @param layerInputs names of the inputs feeding this layer
         * @return this builder
         */
        public Builder addLayer(String layerName, LayerSpace layerSpace,
                                String... layerInputs) {
            layerList.add(new LayerConf(layerSpace, layerName, layerInputs));
            return this;
        }

        /**
         * Adds a fixed graph vertex (no hyperparameters are optimized for it).
         *
         * @param vertexName   name of the vertex in the graph
         * @param vertex       the vertex to add
         * @param vertexInputs names of the inputs feeding this vertex
         * @return this builder
         */
        public Builder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
            vertexList.add(new VertexConf(vertex, vertexName, vertexInputs));
            return this;
        }

        /**
         * Sets the network input names. Each call replaces any previously set inputs.
         *
         * @param networkInputs names of the network inputs
         * @return this builder
         */
        public Builder addInputs(String... networkInputs) {
            this.networkInputs = networkInputs;
            return this;
        }

        /**
         * Sets the network output names. Each call replaces any previously set outputs.
         *
         * @param networkOutputs names of the network outputs
         * @return this builder
         */
        public Builder setOutputs(String... networkOutputs) {
            this.networkOutputs = networkOutputs;
            return this;
        }

        /**
         * Sets the (optional) input types. Each call replaces any previously set input types.
         *
         * @param inputTypes input types for the network inputs
         * @return this builder
         */
        public Builder setInputTypes(InputType... inputTypes) {
            this.inputTypes = inputTypes;
            return this;
        }

        /**
         * @return a new {@link ComputationGraphSpace} built from the current state of this builder
         */
        public ComputationGraphSpace build() {
            return new ComputationGraphSpace(this);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy