/*
 * All Downloads are FREE. Search and download functionality uses the official Maven repository.
 *
 * org.deeplearning4j.arbiter.MultiLayerSpace Maven / Gradle / Ivy
 *
 * There is a newer version: 1.0.0-beta7
 * Show newest version
 */
package org.deeplearning4j.arbiter;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

@Data
@EqualsAndHashCode(callSuper = true)
public class MultiLayerSpace extends BaseNetworkSpace {
    @JsonProperty
    protected ParameterSpace inputType;
    @JsonProperty
    protected ParameterSpace> inputPreProcessors;

    //Early stopping configuration / (fixed) number of epochs:
    @JsonProperty
    protected EarlyStoppingConfiguration earlyStoppingConfiguration;
    @JsonProperty
    protected int numParameters;


    protected MultiLayerSpace(Builder builder) {
        super(builder);
        this.inputType = builder.inputType;
        this.inputPreProcessors = builder.inputPreProcessors;

        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;

        this.layerSpaces = builder.layerSpaces;

        //Determine total number of parameters:
        //Collect the leaves, and make sure they are unique.
        //Note that the *object instances* must be unique - and consequently we don't want to use .equals(), as
        // this would incorrectly filter out equal range parameter spaces
        List allLeaves = collectLeaves();
        List list = LeafUtils.getUniqueObjects(allLeaves);

        for (ParameterSpace ps : list)
            numParameters += ps.numParameters();


    }

    protected MultiLayerSpace() {
        //Default constructor for Jackson json/yaml serialization
    }

    @Override
    public DL4JConfiguration getValue(double[] values) {
        //First: create layer configs
        List layers = new ArrayList<>();
        for (LayerConf c : layerSpaces) {
            int n = c.numLayers.getValue(values);
            if (c.duplicateConfig) {
                //Generate N identical configs
                org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
                for (int i = 0; i < n; i++) {
                    layers.add(l.clone());
                }
            } else {
                throw new UnsupportedOperationException("Not yet implemented");
            }
        }

        //Create MultiLayerConfiguration...
        NeuralNetConfiguration.Builder builder = randomGlobalConf(values);

        NeuralNetConfiguration.ListBuilder listBuilder = builder.list();
        for (int i = 0; i < layers.size(); i++) {
            listBuilder.layer(i, layers.get(i));
        }

        if (backprop != null)
            listBuilder.backprop(backprop.getValue(values));
        if (pretrain != null)
            listBuilder.pretrain(pretrain.getValue(values));
        if (backpropType != null)
            listBuilder.backpropType(backpropType.getValue(values));
        if (tbpttFwdLength != null)
            listBuilder.tBPTTForwardLength(tbpttFwdLength.getValue(values));
        if (tbpttBwdLength != null)
            listBuilder.tBPTTBackwardLength(tbpttBwdLength.getValue(values));
        if (inputType != null)
            listBuilder.setInputType(inputType.getValue(values));
        if (inputPreProcessors != null)
            listBuilder.setInputPreProcessors(inputPreProcessors.getValue(values));

        MultiLayerConfiguration configuration = listBuilder.build();
        return new DL4JConfiguration(configuration, earlyStoppingConfiguration, numEpochs);
    }

    @Override
    public int numParameters() {
        return numParameters;
    }

    @Override
    public List collectLeaves() {
        List list = super.collectLeaves();
        for (LayerConf lc : layerSpaces) {
            list.addAll(lc.numLayers.collectLeaves());
            list.addAll(lc.layerSpace.collectLeaves());
        }
        if (inputType != null)
            list.addAll(inputType.collectLeaves());
        if (inputPreProcessors != null)
            list.addAll(inputPreProcessors.collectLeaves());
        return list;
    }


    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(super.toString());

        int i = 0;
        for (LayerConf conf : layerSpaces) {

            sb.append("Layer config ").append(i++).append(": (Number layers:").append(conf.numLayers)
                            .append(", duplicate: ").append(conf.duplicateConfig).append("), ")
                            .append(conf.layerSpace.toString()).append("\n");
        }

        if (inputType != null)
            sb.append("inputType: ").append(inputType).append("\n");
        if (inputPreProcessors != null)
            sb.append("inputPreProcessors: ").append(inputPreProcessors).append("\n");

        if (earlyStoppingConfiguration != null) {
            sb.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
        } else {
            sb.append("Training # epochs:").append(numEpochs).append("\n");
        }

        return sb.toString();
    }

    public LayerSpace getLayerSpace(int layerNumber) {
        return layerSpaces.get(layerNumber).getLayerSpace();
    }

    public static class Builder extends BaseNetworkSpace.Builder {
        protected List layerSpaces = new ArrayList<>();
        protected ParameterSpace inputType;
        protected ParameterSpace> inputPreProcessors;

        //Early stopping configuration
        protected EarlyStoppingConfiguration earlyStoppingConfiguration;



        public Builder setInputType(InputType inputType) {
            return setInputType(new FixedValue<>(inputType));
        }

        public Builder setInputType(ParameterSpace inputType) {
            this.inputType = inputType;
            return this;
        }


        public Builder addLayer(LayerSpace layerSpace) {
            return addLayer(layerSpace, new FixedValue<>(1), true);
        }

        /**
         * @param layerSpace
         * @param numLayersDistribution Distribution for number of layers to generate
         * @param duplicateConfig       Only used if more than 1 layer can be generated. If true: generate N identical (stacked) layers.
         *                              If false: generate N independent layers
         */
        public Builder addLayer(LayerSpace layerSpace, ParameterSpace numLayersDistribution,
                        boolean duplicateConfig) {
            String layerName = "layer_" + layerSpaces.size();
            layerSpaces.add(new LayerConf(layerSpace, layerName, null, numLayersDistribution, duplicateConfig));
            return this;
        }

        /**
         * Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is
         * present, early stopping will be used in preference.
         */
        public Builder earlyStoppingConfiguration(
                        EarlyStoppingConfiguration earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /**
         * @param inputPreProcessors Input preprocessors to set for the model
         */
        public Builder setInputPreProcessors(Map inputPreProcessors) {
            return setInputPreProcessors(new FixedValue<>(inputPreProcessors));
        }

        /**
         * @param inputPreProcessors Input preprocessors to set for the model
         */
        public Builder setInputPreProcessors(ParameterSpace> inputPreProcessors) {
            this.inputPreProcessors = inputPreProcessors;
            return this;
        }

        @SuppressWarnings("unchecked")
        public MultiLayerSpace build() {
            return new MultiLayerSpace(this);
        }
    }

    public static MultiLayerSpace fromJson(String json) {
        try {
            return JsonMapper.getMapper().readValue(json, MultiLayerSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static MultiLayerSpace fromYaml(String yaml) {
        try {
            return YamlMapper.getMapper().readValue(yaml, MultiLayerSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy