All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.arbiter.MultiLayerSpace Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta7
Show newest version
package org.deeplearning4j.arbiter;

import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

//public class MultiLayerSpace implements ModelParameterSpace {
public class MultiLayerSpace extends BaseNetworkSpace {

    /**
     * Legacy CNN input size space; values are {height, width, depth} arrays (see Builder#cnnInputSize).
     * May be null. NOTE(review): presumably superseded by {@link #inputType} - confirm before removal.
     */
    @Deprecated
    private ParameterSpace<int[]> cnnInputSize;
    /** Layer entries (layer space + number of layers), in the order they were added to the builder. */
    private List<LayerConf> layerSpaces = new ArrayList<>();
    /** Optional search space over the network input type; may be null. */
    private ParameterSpace<InputType> inputType;

    //Early stopping configuration / (fixed) number of epochs:
    private EarlyStoppingConfiguration earlyStoppingConfiguration;

    /** Total parameter count over the unique leaf parameter spaces; computed once in the constructor. */
    private int numParameters;

    /**
     * Creates the space from a configured builder.
     *
     * @param builder source of the base network settings, the layer entries, the input size/type spaces
     *                and the early-stopping configuration
     */
    private MultiLayerSpace(Builder builder) {
        super(builder);
        this.cnnInputSize = builder.cnnInputSize;
        this.inputType = builder.inputType;

        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;

        //Copy rather than alias the builder's list: otherwise calling addLayer() on the builder after build()
        //would silently change this already-built space and leave numParameters (computed below) stale
        this.layerSpaces = new ArrayList<>(builder.layerSpaces);

        //Determine total number of parameters. collectLeaves() reads the fields assigned above, so this must
        //come after them. getUnique presumably drops duplicate leaves so a shared space is counted once - TODO confirm
        List<ParameterSpace> list = CollectionUtils.getUnique(collectLeaves());
        for (ParameterSpace ps : list) {
            numParameters += ps.numParameters();
        }

        //TODO inputs
    }

    /**
     * Builds a {@link DL4JConfiguration} by evaluating every layer entry's search spaces at {@code values}.
     *
     * NOTE(review): this region is damaged in this copy. The commented-out "independent configs" loop below
     * runs straight into what looks like the header of collectLeaves(), so the rest of getValue() (turning
     * the layer list into a DL4JConfiguration) and the signature of collectLeaves() are missing. Generic
     * type arguments (e.g. on the raw List locals) also appear to have been stripped. As shown, the braces
     * no longer balance and the file does not compile; restore this region from the upstream source.
     *
     * @param values parameter values passed to each parameter space's getValue
     */
    @Override
    public DL4JConfiguration getValue(double[] values) {

        //First: create layer configs
        List layers = new ArrayList<>();
        for (LayerConf c : layerSpaces) {
            //Number of layers to generate for this entry, taken from its numLayers space
            int n = c.numLayers.getValue(values);
            if (c.duplicateConfig) {
                //Generate N identical configs
                //The layer config is obtained once; each of the N layers receives its own clone of it
                org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
                for (int i = 0; i < n; i++) {
                    layers.add(l.clone());
                }
            } else {
                //duplicateConfig == false (N independent configs) is not supported yet
                throw new UnsupportedOperationException("Not yet implemented");
//                //Generate N indepedent configs
//                for( int i=0; i collectLeaves() {
        //NOTE(review): from here on this appears to be the body of collectLeaves(). It gathers the leaf
        //parameter spaces of the base class, of every layer entry (numLayers and layerSpace), and of the
        //optional cnnInputSize / inputType spaces.
        List list = super.collectLeaves();
        for (LayerConf lc : layerSpaces) {
            list.addAll(lc.numLayers.collectLeaves());
            list.addAll(lc.layerSpace.collectLeaves());
        }
        if (cnnInputSize != null) list.addAll(cnnInputSize.collectLeaves());
        if (inputType != null) list.addAll(inputType.collectLeaves());
        return list;
    }


    /**
     * Returns a human-readable summary. It contains the base network space description, one line per layer
     * entry, the optional input size/type spaces, and either the early-stopping configuration or, when
     * none is set, the fixed number of training epochs.
     */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder(super.toString());

        for (int idx = 0; idx < layerSpaces.size(); idx++) {
            LayerConf entry = layerSpaces.get(idx);
            out.append("Layer config ").append(idx)
                    .append(": (Number layers:").append(entry.numLayers)
                    .append(", duplicate: ").append(entry.duplicateConfig)
                    .append("), ").append(entry.layerSpace.toString())
                    .append("\n");
        }

        if (cnnInputSize != null) {
            out.append("cnnInputSize: ").append(cnnInputSize).append("\n");
        }
        if (inputType != null) {
            out.append("inputType: ").append(inputType).append("\n");
        }

        if (earlyStoppingConfiguration == null) {
            out.append("Training # epochs:").append(numEpochs).append("\n");
        } else {
            out.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
        }

        return out.toString();
    }

    /**
     * One entry added through the builder's addLayer methods: a layer search space plus the number of
     * layers to generate from it.
     */
    @AllArgsConstructor
    private static class LayerConf {
        /** Search space for the layer configuration. */
        private final LayerSpace<?> layerSpace;
        /** Search space for how many layers to generate from {@link #layerSpace}. */
        private final ParameterSpace<Integer> numLayers;
        /**
         * If true, getValue obtains one layer config and adds a clone of it for each generated layer.
         * False (independent configs) is not yet supported: getValue throws UnsupportedOperationException.
         */
        private final boolean duplicateConfig;
    }

    /**
     * Builder for {@link MultiLayerSpace}. Layer entries are kept in the order they are added.
     */
    public static class Builder extends BaseNetworkSpace.Builder {

        @Deprecated
        private ParameterSpace<int[]> cnnInputSize;
        private List<LayerConf> layerSpaces = new ArrayList<>();
        private ParameterSpace<InputType> inputType;

        //Early stopping configuration
        private EarlyStoppingConfiguration earlyStoppingConfiguration;


        /**
         * Sets a fixed CNN input size.
         *
         * @param height input height
         * @param width  input width
         * @param depth  input depth
         * @return this builder
         * @deprecated NOTE(review): presumably superseded by {@link #setInputType(InputType)} - confirm
         */
        @Deprecated
        public Builder cnnInputSize(int height, int width, int depth) {
            return cnnInputSize(new FixedValue<>(new int[]{height, width, depth}));
        }

        /**
         * Sets the search space for the CNN input size; values are {height, width, depth} arrays.
         *
         * @param cnnInputSize input size search space
         * @return this builder
         * @deprecated NOTE(review): presumably superseded by {@link #setInputType(ParameterSpace)} - confirm
         */
        @Deprecated
        public Builder cnnInputSize(ParameterSpace<int[]> cnnInputSize) {
            this.cnnInputSize = cnnInputSize;
            return this;
        }

        /**
         * Fixes the network input type to a single value.
         *
         * @param inputType the input type to use for every candidate
         * @return this builder
         */
        public Builder setInputType(InputType inputType) {
            return setInputType(new FixedValue<>(inputType));
        }

        /**
         * Sets the search space for the network input type.
         *
         * @param inputType input type search space
         * @return this builder
         */
        public Builder setInputType(ParameterSpace<InputType> inputType) {
            this.inputType = inputType;
            return this;
        }


        /**
         * Adds exactly one layer generated from {@code layerSpace}.
         *
         * @param layerSpace search space for the layer configuration; must not be null
         * @return this builder
         * @throws NullPointerException if {@code layerSpace} is null
         */
        public Builder addLayer(LayerSpace<?> layerSpace) {
            return addLayer(layerSpace, new FixedValue<>(1), true);
        }

        /**
         * Adds a group of layers generated from a single layer search space.
         *
         * @param layerSpace            search space for the layer configuration; must not be null
         * @param numLayersDistribution distribution for the number of layers to generate; must not be null
         * @param duplicateConfig       only relevant if more than 1 layer can be generated. If true: generate N
         *                              identical (stacked) layers. If false: generate N independent layers -
         *                              NOTE: not yet implemented; {@link MultiLayerSpace#getValue(double[])}
         *                              currently throws {@link UnsupportedOperationException} for such entries.
         * @return this builder
         * @throws NullPointerException if {@code layerSpace} or {@code numLayersDistribution} is null
         */
        public Builder addLayer(LayerSpace<?> layerSpace,
                                ParameterSpace<Integer> numLayersDistribution, boolean duplicateConfig) {
            //Fail fast here instead of with an anonymous NPE from collectLeaves() when build() is called
            Objects.requireNonNull(layerSpace, "layerSpace");
            Objects.requireNonNull(numLayersDistribution, "numLayersDistribution");
            layerSpaces.add(new LayerConf(layerSpace, numLayersDistribution, duplicateConfig));
            return this;
        }

        /**
         * Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is
         * present, early stopping will be used in preference.
         *
         * @param earlyStoppingConfiguration early stopping configuration, or null to train for a fixed number of epochs
         * @return this builder
         */
        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /**
         * Builds a {@link MultiLayerSpace} from this builder's current settings.
         *
         * @return the new space
         */
        @SuppressWarnings("unchecked")
        public MultiLayerSpace build() {
            return new MultiLayerSpace(this);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy