All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.arbiter.MultiLayerSpace Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
package org.deeplearning4j.arbiter;

import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.ArrayList;
import java.util.List;

//public class MultiLayerSpace implements ModelParameterSpace {
public class MultiLayerSpace extends BaseNetworkSpace {

    /** Optional search space for the CNN input size, as {height, width, depth}; null when not set. */
    private ParameterSpace<int[]> cnnInputSize;
    /** Layer entries added through the builder, in the order they were added. */
    private List<LayerConf> layerSpaces = new ArrayList<>();

    //Early stopping configuration / (fixed) number of epochs:
    private EarlyStoppingConfiguration earlyStoppingConfiguration;

    /** Sum of numParameters() over the unique leaf parameter spaces; computed once in the constructor. */
    private int numParameters;

    /**
     * Creates the space from a populated builder.
     *
     * <p>The builder's layer list is copied rather than shared. The builder remains usable after
     * {@code build()}, and {@code numParameters} is computed only once, here. Sharing the list would
     * let later {@code addLayer(...)} calls change this space behind that cached count.
     *
     * @param builder builder holding the base network options, layer entries, optional CNN input
     *                size and optional early stopping configuration
     */
    private MultiLayerSpace(Builder builder){
        super(builder);
        this.cnnInputSize = builder.cnnInputSize;

        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;

        // Defensive copy: decouple this (built) space from the still-mutable builder.
        this.layerSpaces = new ArrayList<>(builder.layerSpaces);

        //Determine total number of parameters:
        // NOTE(review): getUnique presumably removes leaves reachable more than once
        // (e.g. a space shared between layers) so they are not double-counted - confirm.
        List<ParameterSpace> list = CollectionUtils.getUnique(collectLeaves());
        for(ParameterSpace ps : list) numParameters += ps.numParameters();

        //TODO inputs
    }

    @Override
    public DL4JConfiguration getValue(double[] values) {

        //First: create layer configs
        List layers = new ArrayList<>();
        for(LayerConf c : layerSpaces){
            int n = c.numLayers.getValue(values);
            if(c.duplicateConfig){
                //Generate N identical configs
                org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
                for( int i=0; i collectLeaves() {
        List list = super.collectLeaves();
        for(LayerConf lc : layerSpaces){
            list.addAll(lc.numLayers.collectLeaves());
            list.addAll(lc.layerSpace.collectLeaves());
        }
        if(cnnInputSize != null) list.addAll(cnnInputSize.collectLeaves());
        return list;
    }



    /**
     * Builds a human-readable description of this space.
     *
     * <p>The output has three parts:
     * <ul>
     *   <li>the base description;</li>
     *   <li>one line per layer entry (number-of-layers space, duplicate flag, layer space);</li>
     *   <li>either the early stopping configuration or, when none is set, the number of epochs.</li>
     * </ul>
     */
    @Override
    public String toString(){
        final StringBuilder out = new StringBuilder(super.toString());

        int layerIdx = 0;
        for(LayerConf entry : layerSpaces){
            out.append("Layer config ").append(layerIdx).append(": (Number layers:").append(entry.numLayers);
            out.append(", duplicate: ").append(entry.duplicateConfig).append("), ");
            out.append(entry.layerSpace.toString()).append("\n");
            layerIdx++;
        }

        if(earlyStoppingConfiguration == null){
            out.append("Training # epochs:").append(numEpochs).append("\n");
        } else {
            out.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
        }

        return out.toString();
    }

    /**
     * One entry registered through {@code Builder.addLayer(...)}. It holds three things:
     * <ul>
     *   <li>a layer search space;</li>
     *   <li>a space for how many layers to generate from it;</li>
     *   <li>whether those layers share a single sampled configuration.</li>
     * </ul>
     * The constructor is generated by Lombok, with arguments in field declaration order.
     */
    @AllArgsConstructor
    private static class LayerConf {
        /** Search space from which the layer configuration(s) are sampled. */
        private final LayerSpace layerSpace;
        /** Number of layers to generate; must yield an Integer, because it is unboxed to int when sampled. */
        private final ParameterSpace<Integer> numLayers;
        /** If true, the generated layers are N identical copies of one sample; if false, N independent layers. */
        private final boolean duplicateConfig;
    }

    /**
     * Fluent builder for {@link MultiLayerSpace}.
     *
     * <p>Options defined by {@code BaseNetworkSpace.Builder} are inherited. This builder adds
     * three things:
     * <ul>
     *   <li>an ordered list of layer entries;</li>
     *   <li>an optional CNN input size space;</li>
     *   <li>an optional early stopping configuration.</li>
     * </ul>
     */
    public static class Builder extends BaseNetworkSpace.Builder {

        private ParameterSpace<int[]> cnnInputSize;

        private List<LayerConf> layerSpaces = new ArrayList<>();

        //Early stopping configuration
        private EarlyStoppingConfiguration earlyStoppingConfiguration;


        /**
         * Fixes the CNN input size to a single value.
         *
         * @param height input height
         * @param width  input width
         * @param depth  input depth
         * @return this builder
         */
        public Builder cnnInputSize(int height, int width, int depth){
            return cnnInputSize(new FixedValue<>(new int[]{height, width, depth}));
        }

        /**
         * Sets the search space for the CNN input size.
         *
         * @param cnnInputSize space whose values are {@code int[]{height, width, depth}};
         *                     {@code null} clears any previously set value
         * @return this builder
         */
        public Builder cnnInputSize(ParameterSpace<int[]> cnnInputSize){
            this.cnnInputSize = cnnInputSize;
            return this;
        }


        /**
         * Adds exactly one layer sampled from {@code layerSpace}.
         * Same as {@code addLayer(layerSpace, new FixedValue<>(1), true)}.
         *
         * @param layerSpace layer search space; must not be null
         * @return this builder
         */
        public Builder addLayer(LayerSpace layerSpace){
            return addLayer(layerSpace,new FixedValue<>(1),true);
        }

        /**
         * Adds a layer entry. Entries are kept in the order they are added.
         *
         * @param layerSpace            layer search space; must not be null
         * @param numLayersDistribution Distribution for number of layers to generate; must not be null
         * @param duplicateConfig       Only used if more than 1 layer can be generated.
         *                              If true: generate N identical (stacked) layers.
         *                              If false: generate N independent layers.
         * @return this builder
         * @throws IllegalArgumentException if {@code layerSpace} or {@code numLayersDistribution} is null
         */
        public Builder addLayer(LayerSpace layerSpace,
                                ParameterSpace<Integer> numLayersDistribution, boolean duplicateConfig){
            // Fail fast: otherwise a null would only surface later, as an NPE while build() collects leaves.
            if(layerSpace == null){
                throw new IllegalArgumentException("layerSpace must not be null");
            }
            if(numLayersDistribution == null){
                throw new IllegalArgumentException("numLayersDistribution must not be null");
            }
            layerSpaces.add(new LayerConf(layerSpace,numLayersDistribution,duplicateConfig));
            return this;
        }

        /**
         * Sets the optional early stopping configuration.
         * If both an EarlyStoppingConfiguration and a number of epochs are set, early stopping takes precedence.
         *
         * @param earlyStoppingConfiguration early stopping configuration, or null for none
         * @return this builder
         */
        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration earlyStoppingConfiguration){
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /**
         * Builds the {@link MultiLayerSpace}. The layer entries are copied, so this builder can
         * still be modified afterwards without affecting the result.
         *
         * @return the new space
         */
        @SuppressWarnings("unchecked")
        public MultiLayerSpace build(){
            return new MultiLayerSpace(this);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy