/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.nn.conf;

import com.google.common.collect.Sets;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ClassUtils;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.conf.serde.ComputationGraphConfigurationDeserializer;
import org.deeplearning4j.nn.conf.serde.MultiLayerConfigurationDeserializer;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.reflections.DL4JSubTypesScanner;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.databind.*;
import org.nd4j.shade.jackson.databind.deser.BeanDeserializerModifier;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
import org.nd4j.shade.jackson.databind.jsontype.NamedType;
import org.nd4j.shade.jackson.databind.module.SimpleModule;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.reflections.ReflectionUtils;
import org.reflections.Reflections;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import org.reflections.util.FilterBuilder;

import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.*;


/**
 * A Serializable configuration
 * for neural nets that covers per-layer parameters
 *
 * @author Adam Gibson
 */
@Data
@NoArgsConstructor
@Slf4j
public class NeuralNetConfiguration implements Serializable, Cloneable {


    /**
     * System property for custom layers, preprocessors, graph vertices etc. Enabled by default.
     * Run the JVM with "-Dorg.deeplearning4j.config.custom.enabled=false" to disable classpath scanning for
     * custom functionality (layers, preprocessors, etc.).
     * Overriding the default (i.e., disabling scanning) is only useful if (a) no custom layers/preprocessors etc.
     * will be used, and (b) minimizing startup/initialization time for new JVMs is very important.
     * Results are cached, so there is no cost to custom layers after the first network has been constructed.
     */
    public static final String CUSTOM_FUNCTIONALITY = "org.deeplearning4j.config.custom.enabled";
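
    // Illustrative usage, not part of the original source: classpath scanning for custom
    // types can be disabled at JVM startup when no custom layers/preprocessors are used,
    // which shortens initialization time:
    //
    //   java -Dorg.deeplearning4j.config.custom.enabled=false -cp app.jar com.example.Main
    //
    // or programmatically, before the first network is constructed:
    //
    //   System.setProperty(NeuralNetConfiguration.CUSTOM_FUNCTIONALITY, "false");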

    protected Layer layer;
    @Deprecated
    protected double leakyreluAlpha;
    //batch size: primarily used for conv nets. Will be reinforced if set.
    protected boolean miniBatch = true;
    protected int numIterations;
    //number of line search iterations
    protected int maxNumLineSearchIterations;
    protected long seed;
    protected OptimizationAlgorithm optimizationAlgo;
    //gradient keys used for ensuring order when getting and setting the gradient
    protected List<String> variables = new ArrayList<>();
    //whether to constrain the gradient to unit norm or not
    //adadelta - weight for how much to consider previous history
    protected StepFunction stepFunction;
    protected boolean useRegularization = false;
    protected boolean useDropConnect = false;
    //minimize or maximize objective
    protected boolean minimize = true;
    // Graves LSTM & RNN
    protected Map<String, Double> learningRateByParam = new HashMap<>();
    protected Map<String, Double> l1ByParam = new HashMap<>();
    protected Map<String, Double> l2ByParam = new HashMap<>();
    protected LearningRatePolicy learningRatePolicy = LearningRatePolicy.None;
    protected double lrPolicyDecayRate;
    protected double lrPolicySteps;
    protected double lrPolicyPower;
    protected boolean pretrain;

    // this field defines preOutput cache
    protected CacheMode cacheMode;

    //Counter for the number of parameter updates so far for this layer.
    //Note that this is only used for pretrain layers (RBM, VAE) - MultiLayerConfiguration and ComputationGraphConfiguration
    //contain counters for standard backprop training.
    // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
    // for Spark and model serialization
    protected int iterationCount = 0;

    private static ObjectMapper mapper = initMapper();
    private static final ObjectMapper mapperYaml = initMapperYaml();
    private static Set<Class<?>> subtypesClassCache = null;


    /**
     * Creates and returns a deep copy of the configuration.
     */
    @Override
    public NeuralNetConfiguration clone() {
        try {
            NeuralNetConfiguration clone = (NeuralNetConfiguration) super.clone();
            if (clone.layer != null)
                clone.layer = clone.layer.clone();
            if (clone.stepFunction != null)
                clone.stepFunction = clone.stepFunction.clone();
            if (clone.variables != null)
                clone.variables = new ArrayList<>(clone.variables);
            if (clone.learningRateByParam != null)
                clone.learningRateByParam = new HashMap<>(clone.learningRateByParam);
            if (clone.l1ByParam != null)
                clone.l1ByParam = new HashMap<>(clone.l1ByParam);
            if (clone.l2ByParam != null)
                clone.l2ByParam = new HashMap<>(clone.l2ByParam);
            return clone;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    public List<String> variables() {
        return new ArrayList<>(variables);
    }

    public List<String> variables(boolean copy) {
        if (copy)
            return variables();
        return variables;
    }

    public void addVariable(String variable) {
        if (!variables.contains(variable)) {
            variables.add(variable);
            setLayerParamLR(variable);
        }
    }

    public void clearVariables() {
        variables.clear();
        l1ByParam.clear();
        l2ByParam.clear();
        learningRateByParam.clear();
    }

    public void resetVariables() {
        for (String s : variables) {
            setLayerParamLR(s);
        }
    }

    public void setLayerParamLR(String variable) {
        double lr = layer.getLearningRateByParam(variable);
        double l1 = layer.getL1ByParam(variable);
        if (Double.isNaN(l1))
            l1 = 0.0; //Not set
        double l2 = layer.getL2ByParam(variable);
        if (Double.isNaN(l2))
            l2 = 0.0; //Not set
        learningRateByParam.put(variable, lr);
        l1ByParam.put(variable, l1);
        l2ByParam.put(variable, l2);
    }

    public double getLearningRateByParam(String variable) {
        return learningRateByParam.get(variable);
    }

    public void setLearningRateByParam(String variable, double rate) {
        learningRateByParam.put(variable, rate);
    }

    public double getL1ByParam(String variable) {
        return l1ByParam.get(variable);
    }

    public double getL2ByParam(String variable) {
        return l2ByParam.get(variable);
    }

    /**
     * Fluent interface for building a list of configurations
     */
    public static class ListBuilder extends MultiLayerConfiguration.Builder {
        private Map<Integer, Builder> layerwise;
        private Builder globalConfig;

        // Constructor
        public ListBuilder(Builder globalConfig, Map<Integer, Builder> layerMap) {
            this.globalConfig = globalConfig;
            this.layerwise = layerMap;
        }

        public ListBuilder(Builder globalConfig) {
            this(globalConfig, new HashMap<Integer, Builder>());
        }

        public ListBuilder backprop(boolean backprop) {
            this.backprop = backprop;

            return this;
        }

        public ListBuilder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        public ListBuilder layer(int ind, Layer layer) {
            if (layerwise.containsKey(ind)) {
                layerwise.get(ind).layer(layer);
            } else {
                layerwise.put(ind, globalConfig.clone().layer(layer));
            }
            return this;
        }

        public Map<Integer, Builder> getLayerwise() {
            return layerwise;
        }

        /**
         * Build the multi layer network
         * based on this neural network and
         * overridden parameters
         *
         * @return the configuration to build
         */
        public MultiLayerConfiguration build() {
            List<NeuralNetConfiguration> list = new ArrayList<>();
            if (layerwise.isEmpty())
                throw new IllegalStateException("Invalid configuration: no layers defined");
            for (int i = 0; i < layerwise.size(); i++) {
                if (layerwise.get(i) == null) {
                    throw new IllegalStateException("Invalid configuration: layer number " + i
                                    + " not specified. Expect layer " + "numbers to be 0 to " + (layerwise.size() - 1)
                                    + " inclusive (number of layers defined: " + layerwise.size() + ")");
                }
                if (layerwise.get(i).getLayer() == null)
                    throw new IllegalStateException("Cannot construct network: Layer config for "
                                    + "layer with index " + i + " is not defined");

                //Layer names: set to default, if not set
                if (layerwise.get(i).getLayer().getLayerName() == null) {
                    layerwise.get(i).getLayer().setLayerName("layer" + i);
                }

                list.add(layerwise.get(i).build());
            }
            return new MultiLayerConfiguration.Builder().backprop(backprop).inputPreProcessors(inputPreProcessors)
                            .pretrain(pretrain).backpropType(backpropType).tBPTTForwardLength(tbpttFwdLength)
                            .tBPTTBackwardLength(tbpttBackLength).setInputType(this.inputType)
                            .trainingWorkspaceMode(globalConfig.trainingWorkspaceMode).cacheMode(globalConfig.cacheMode)
                            .inferenceWorkspaceMode(globalConfig.inferenceWorkspaceMode).confs(list).build();
        }

    }
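
    // Illustrative sketch, not part of the original source: typical ListBuilder usage for a
    // simple two-layer network (layer sizes are arbitrary example values):
    //
    //   MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    //                   .seed(12345)
    //                   .learningRate(0.1)
    //                   .list()
    //                   .layer(0, new DenseLayer.Builder().nIn(784).nOut(100).build())
    //                   .layer(1, new OutputLayer.Builder().nIn(100).nOut(10).build())
    //                   .pretrain(false).backprop(true)
    //                   .build();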

    /**
     * Return this configuration as YAML
     *
     * @return this configuration represented as YAML
     */
    public String toYaml() {
        ObjectMapper mapper = mapperYaml();

        try {
            String ret = mapper.writeValueAsString(this);
            return ret;

        } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create a neural net configuration from YAML
     *
     * @param json the neural net configuration as YAML
     * @return the deserialized configuration
     */
    public static NeuralNetConfiguration fromYaml(String json) {
        ObjectMapper mapper = mapperYaml();
        try {
            NeuralNetConfiguration ret = mapper.readValue(json, NeuralNetConfiguration.class);
            return ret;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Return this configuration as JSON
     *
     * @return this configuration represented as JSON
     */
    public String toJson() {
        ObjectMapper mapper = mapper();

        try {
            String ret = mapper.writeValueAsString(this);
            return ret;

        } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create a neural net configuration from JSON
     *
     * @param json the neural net configuration as JSON
     * @return the deserialized configuration
     */
    public static NeuralNetConfiguration fromJson(String json) {
        ObjectMapper mapper = mapper();
        try {
            NeuralNetConfiguration ret = mapper.readValue(json, NeuralNetConfiguration.class);
            return ret;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
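
    // Illustrative sketch, not part of the original source: JSON round-trip of a single-layer
    // configuration. fromJson(toJson(conf)) is expected to reproduce an equal configuration:
    //
    //   NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
    //                   .layer(new DenseLayer.Builder().nIn(10).nOut(5).build())
    //                   .build();
    //   String json = conf.toJson();
    //   NeuralNetConfiguration restored = NeuralNetConfiguration.fromJson(json);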

    /**
     * Object mapper for serialization of configurations
     *
     * @return the YAML object mapper
     */
    public static ObjectMapper mapperYaml() {
        return mapperYaml;
    }

    private static ObjectMapper initMapperYaml() {
        ObjectMapper ret = new ObjectMapper(new YAMLFactory());
        configureMapper(ret);
        return ret;
    }

    /**
     * Object mapper for serialization of configurations
     *
     * @return the JSON object mapper
     */
    public static ObjectMapper mapper() {
        return mapper;
    }

    /**
     * Reinitialize and return the Jackson/json ObjectMapper with additional named types.
     * This can be used to add additional subtypes at runtime (i.e., for JSON mapping with
     * types defined outside of the main DL4J codebase)
     */
    public static ObjectMapper reinitMapperWithSubtypes(Collection<NamedType> additionalTypes) {
        mapper.registerSubtypes(additionalTypes.toArray(new NamedType[additionalTypes.size()]));
        //Recreate the mapper (via copy), as mapper won't use registered subtypes after first use
        mapper = mapper.copy();
        return mapper;
    }
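
    // Illustrative sketch, not part of the original source: registering a hypothetical custom
    // layer class (MyCustomLayer, defined outside DL4J) so it can be mapped to/from JSON:
    //
    //   NeuralNetConfiguration.reinitMapperWithSubtypes(
    //                   Collections.<NamedType>singletonList(new NamedType(MyCustomLayer.class)));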

    private static ObjectMapper initMapper() {
        ObjectMapper ret = new ObjectMapper();
        configureMapper(ret);
        return ret;
    }

    private static void configureMapper(ObjectMapper ret) {
        ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
        ret.enable(SerializationFeature.INDENT_OUTPUT);

        SimpleModule customDeserializerModule = new SimpleModule();
        customDeserializerModule.setDeserializerModifier(new BeanDeserializerModifier() {
            @Override
            public JsonDeserializer<?> modifyDeserializer(DeserializationConfig config, BeanDescription beanDesc,
                            JsonDeserializer<?> deserializer) {
                //Use our custom deserializers to handle backward compatibility for updaters -> IUpdater
                if (beanDesc.getBeanClass() == MultiLayerConfiguration.class) {
                    return new MultiLayerConfigurationDeserializer(deserializer);
                } else if (beanDesc.getBeanClass() == ComputationGraphConfiguration.class) {
                    return new ComputationGraphConfigurationDeserializer(deserializer);
                }
                return deserializer;
            }
        });

        ret.registerModule(customDeserializerModule);

        registerSubtypes(ret);
    }

    private static synchronized void registerSubtypes(ObjectMapper mapper) {
        //Register concrete subtypes for JSON serialization

        List<Class<?>> classes = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class,
                        IActivation.class, Layer.class, GraphVertex.class, ReconstructionDistribution.class);
        List<String> classNames = new ArrayList<>(6);
        for (Class<?> c : classes)
            classNames.add(c.getName());

        // First: scan the classpath and find all instances of the 'baseClasses' classes
        if (subtypesClassCache == null) {

            //Check system property:
            String prop = System.getProperty(CUSTOM_FUNCTIONALITY);
            if (prop != null && !Boolean.parseBoolean(prop)) {

                subtypesClassCache = Collections.emptySet();
            } else {

                List<Class<?>> interfaces = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class,
                                IActivation.class, ReconstructionDistribution.class);
                List<Class<?>> classesList = Arrays.<Class<?>>asList(Layer.class, GraphVertex.class);

                Collection<URL> urls = ClasspathHelper.forClassLoader();
                List<URL> scanUrls = new ArrayList<>();
                for (URL u : urls) {
                    String path = u.getPath();
                    if (!path.matches(".*/jre/lib/.*jar")) { //Skip JRE/JDK JARs
                        scanUrls.add(u);
                    }
                }

                Reflections reflections = new Reflections(new ConfigurationBuilder().filterInputsBy(new FilterBuilder()
                                .exclude("^(?!.*\\.class$).*$") //Consider only .class files (to avoid debug messages etc. on .dlls, etc
                                //Exclude the following: the assumption here is that no custom functionality will ever be present
                                // under these package name prefixes. These are all common dependencies for DL4J
                                .exclude("^org.nd4j.*").exclude("^org.datavec.*").exclude("^org.bytedeco.*") //JavaCPP
                                .exclude("^com.fasterxml.*")//Jackson
                                .exclude("^org.apache.*") //Apache commons, Spark, log4j etc
                                .exclude("^org.projectlombok.*").exclude("^com.twelvemonkeys.*").exclude("^org.joda.*")
                                .exclude("^org.slf4j.*").exclude("^com.google.*").exclude("^org.reflections.*")
                                .exclude("^ch.qos.*") //Logback
                ).addUrls(scanUrls).setScanners(new DL4JSubTypesScanner(interfaces, classesList)));
                org.reflections.Store store = reflections.getStore();

                Iterable<String> subtypesByName = store.getAll(DL4JSubTypesScanner.class.getSimpleName(), classNames);

                Set<? extends Class<?>> subtypeClasses = Sets.newHashSet(ReflectionUtils.forNames(subtypesByName));
                subtypesClassCache = new HashSet<>();
                for (Class<?> c : subtypeClasses) {
                    if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
                        //log.info("Skipping abstract/interface: {}",c);
                        continue;
                    }
                    subtypesClassCache.add(c);
                }
            }
        }

        //Second: get all currently registered subtypes for this mapper
        Set<Class<?>> registeredSubtypes = new HashSet<>();
        for (Class<?> c : classes) {
            AnnotatedClass ac = AnnotatedClass.construct(c, mapper.getSerializationConfig().getAnnotationIntrospector(),
                            null);
            Collection<NamedType> types =
                            mapper.getSubtypeResolver().collectAndResolveSubtypes(ac, mapper.getSerializationConfig(),
                                            mapper.getSerializationConfig().getAnnotationIntrospector());
            for (NamedType nt : types) {
                registeredSubtypes.add(nt.getType());
            }
        }

        //Third: register all _concrete_ subtypes that are not already registered
        List<NamedType> toRegister = new ArrayList<>();
        for (Class<?> c : subtypesClassCache) {
            //Check if it's concrete or abstract...
            if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
                //log.info("Skipping abstract/interface: {}",c);
                continue;
            }

            if (!registeredSubtypes.contains(c)) {
                String name;
                if (ClassUtils.isInnerClass(c)) {
                    Class<?> c2 = c.getDeclaringClass();
                    name = c2.getSimpleName() + "$" + c.getSimpleName();
                } else {
                    name = c.getSimpleName();
                }
                toRegister.add(new NamedType(c, name));
                if (log.isDebugEnabled()) {
                    for (Class<?> baseClass : classes) {
                        if (baseClass.isAssignableFrom(c)) {
                            log.debug("Registering class for JSON serialization: {} as subtype of {}", c.getName(),
                                            baseClass.getName());
                            break;
                        }
                    }
                }
            }
        }

        mapper.registerSubtypes(toRegister.toArray(new NamedType[toRegister.size()]));
    }

    @Data
    public static class Builder implements Cloneable {
        protected IActivation activationFn = new ActivationSigmoid();
        protected WeightInit weightInit = WeightInit.XAVIER;
        protected double biasInit = 0.0;
        protected Distribution dist = null;
        protected double learningRate = 1e-1;
        protected double biasLearningRate = Double.NaN;
        protected Map<Integer, Double> learningRateSchedule = null;
        protected double lrScoreBasedDecay;
        protected double l1 = Double.NaN;
        protected double l2 = Double.NaN;
        protected double l1Bias = Double.NaN;
        protected double l2Bias = Double.NaN;
        protected double dropOut = 0;
        @Deprecated
        protected Updater updater = Updater.SGD;
        protected IUpdater iUpdater = new Sgd();
        @Deprecated
        protected double momentum = Double.NaN;
        @Deprecated
        protected Map<Integer, Double> momentumSchedule = null;
        @Deprecated
        protected double epsilon = Double.NaN;
        @Deprecated
        protected double rho = Double.NaN;
        @Deprecated
        protected double rmsDecay = Double.NaN;
        @Deprecated
        protected double adamMeanDecay = Double.NaN;
        @Deprecated
        protected double adamVarDecay = Double.NaN;
        protected Layer layer;
        @Deprecated
        protected double leakyreluAlpha = 0.01;
        protected boolean miniBatch = true;
        protected int numIterations = 1;
        protected int maxNumLineSearchIterations = 5;
        protected long seed = System.currentTimeMillis();
        protected boolean useRegularization = false;
        protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
        protected StepFunction stepFunction = null;
        protected boolean useDropConnect = false;
        protected boolean minimize = true;
        protected GradientNormalization gradientNormalization = GradientNormalization.None;
        protected double gradientNormalizationThreshold = 1.0;
        protected LearningRatePolicy learningRatePolicy = LearningRatePolicy.None;
        protected double lrPolicyDecayRate = Double.NaN;
        protected double lrPolicySteps = Double.NaN;
        protected double lrPolicyPower = Double.NaN;
        protected boolean pretrain = false;

        protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.NONE;
        protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.SEPARATE;
        protected CacheMode cacheMode = CacheMode.NONE;

        protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;

        public Builder() {
            //
        }

        public Builder(NeuralNetConfiguration newConf) {
            if (newConf != null) {
                minimize = newConf.minimize;
                maxNumLineSearchIterations = newConf.maxNumLineSearchIterations;
                layer = newConf.layer;
                numIterations = newConf.numIterations;
                useRegularization = newConf.useRegularization;
                optimizationAlgo = newConf.optimizationAlgo;
                seed = newConf.seed;
                stepFunction = newConf.stepFunction;
                useDropConnect = newConf.useDropConnect;
                miniBatch = newConf.miniBatch;
                learningRatePolicy = newConf.learningRatePolicy;
                lrPolicyDecayRate = newConf.lrPolicyDecayRate;
                lrPolicySteps = newConf.lrPolicySteps;
                lrPolicyPower = newConf.lrPolicyPower;
                pretrain = newConf.pretrain;
            }
        }

        /**
         * Process input as minibatch vs full dataset.
         * Default set to true.
         */
        public Builder miniBatch(boolean miniBatch) {
            this.miniBatch = miniBatch;
            return this;
        }

        /**
         * This method defines the workspace mode to be used during training:<br>
         * NONE: workspace won't be used<br>
         * SINGLE: one workspace will be used during the whole iteration loop<br>
         * SEPARATE: separate workspaces will be used for the feedforward and backprop iteration loops
         *
         * @param workspaceMode Workspace mode to use for training
         * @return builder
         */
        public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
            this.trainingWorkspaceMode = workspaceMode;
            return this;
        }

        /**
         * This method defines the workspace mode to be used during inference:<br>
         * NONE: workspace won't be used<br>
         * SINGLE: one workspace will be used during the whole iteration loop<br>
         * SEPARATE: separate workspaces will be used for the feedforward and backprop iteration loops
         *
         * @param workspaceMode Workspace mode to use for inference
         * @return builder
         */
        public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
            this.inferenceWorkspaceMode = workspaceMode;
            return this;
        }

        /**
         * This method defines how/if the preOutput cache is handled:<br>
         * NONE: cache disabled (default value)<br>
         * HOST: host memory will be used<br>
         * DEVICE: GPU memory will be used (on CPU backends the effect is the same as for HOST)
         *
         * @param cacheMode Cache mode to use
         * @return builder
         */
        public Builder cacheMode(@NonNull CacheMode cacheMode) {
            this.cacheMode = cacheMode;
            return this;
        }
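
        // Illustrative sketch, not part of the original source: workspace and cache modes are
        // configured on the same builder as the rest of the global settings, e.g.:
        //
        //   new NeuralNetConfiguration.Builder()
        //                   .trainingWorkspaceMode(WorkspaceMode.SINGLE)
        //                   .inferenceWorkspaceMode(WorkspaceMode.SEPARATE)
        //                   .cacheMode(CacheMode.HOST)
        //                   ...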

        /**
         * Use drop connect: multiply the weight by a binomial sampling wrt the dropout probability.
         * Dropconnect probability is set using {@link #dropOut(double)}; this is the probability of retaining a weight
         *
         * @param useDropConnect whether to use drop connect or not
         * @return builder
         */
        public Builder useDropConnect(boolean useDropConnect) {
            this.useDropConnect = useDropConnect;
            return this;
        }

        /**
         * Whether to minimize or maximize the objective (cost) function.
         * Default: minimize (true).
         */
        public Builder minimize(boolean minimize) {
            this.minimize = minimize;
            return this;
        }

        /**
         * Maximum number of line search iterations.
         * Only applies for line search optimizers: Line Search SGD, Conjugate Gradient, LBFGS;
         * it is NOT applicable for standard SGD
         *
         * @param maxNumLineSearchIterations > 0
         * @return builder
         */
        public Builder maxNumLineSearchIterations(int maxNumLineSearchIterations) {
            this.maxNumLineSearchIterations = maxNumLineSearchIterations;
            return this;
        }


        /**
         * Layer class.
         */
        public Builder layer(Layer layer) {
            this.layer = layer;
            return this;
        }

        /**
         * Step function to apply for back track line search.
         * Only applies for line search optimizers: Line Search SGD, Conjugate Gradient, LBFGS
         * Options: DefaultStepFunction (default), NegativeDefaultStepFunction,
         * GradientStepFunction (for SGD), NegativeGradientStepFunction
         */
        public Builder stepFunction(StepFunction stepFunction) {
            this.stepFunction = stepFunction;
            return this;
        }

        /**
         * Create a ListBuilder (for creating a MultiLayerConfiguration)<br>
         * Usage:<br>
         * <pre>
         * {@code .list()
         * .layer(0,new DenseLayer.Builder()...build())
         * ...
         * .layer(n,new OutputLayer.Builder()...build())
         * }
         * </pre>
         */
        public ListBuilder list() {
            return new ListBuilder(this);
        }

        /**
         * Create a ListBuilder (for creating a MultiLayerConfiguration) with the specified layers<br>
         * Usage:<br>
         * <pre>
         * {@code .list(
         *      new DenseLayer.Builder()...build(),
         *      ...,
         *      new OutputLayer.Builder()...build())
         * }
         * </pre>
         *
         * @param layers The layer configurations for the network
         */
        public ListBuilder list(Layer... layers) {
            if (layers == null || layers.length == 0)
                throw new IllegalArgumentException("Cannot create network with no layers");
            Map<Integer, Builder> layerMap = new HashMap<>();
            for (int i = 0; i < layers.length; i++) {
                Builder b = this.clone();
                b.layer(layers[i]);
                layerMap.put(i, b);
            }
            return new ListBuilder(this, layerMap);
        }

        /**
         * Create a GraphBuilder (for creating a ComputationGraphConfiguration).
         */
        public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
            return new ComputationGraphConfiguration.GraphBuilder(this);
        }

        /**
         * Number of optimization iterations.
         */
        public Builder iterations(int numIterations) {
            this.numIterations = numIterations;
            return this;
        }

        /**
         * Random number generator seed. Used for reproducibility between runs
         */
        public Builder seed(int seed) {
            this.seed = (long) seed;
            Nd4j.getRandom().setSeed(seed);
            return this;
        }

        /**
         * Random number generator seed. Used for reproducibility between runs
         */
        public Builder seed(long seed) {
            this.seed = seed;
            Nd4j.getRandom().setSeed(seed);
            return this;
        }

        /**
         * Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT
         *
         * @param optimizationAlgo Optimization algorithm to use when training
         */
        public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
            this.optimizationAlgo = optimizationAlgo;
            return this;
        }

        /**
         * Whether to use regularization (l1, l2, dropout, etc.)
         */
        public Builder regularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        @Override
        public Builder clone() {
            try {
                Builder clone = (Builder) super.clone();
                if (clone.layer != null)
                    clone.layer = clone.layer.clone();
                if (clone.stepFunction != null)
                    clone.stepFunction = clone.stepFunction.clone();
                return clone;
            } catch (CloneNotSupportedException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Activation function / neuron non-linearity
         * Typical values include:<br>
* "relu" (rectified linear), "tanh", "sigmoid", "softmax", * "hardtanh", "leakyrelu", "maxout", "softsign", "softplus" * * @deprecated Use {@link #activation(Activation)} or * {@link @activation(IActivation)} */ @Deprecated public Builder activation(String activationFunction) { return activation(Activation.fromString(activationFunction).getActivationFunction()); } /** * Activation function / neuron non-linearity * * @see #activation(Activation) */ public Builder activation(IActivation activationFunction) { this.activationFn = activationFunction; return this; } /** * Activation function / neuron non-linearity */ public Builder activation(Activation activation) { return activation(activation.getActivationFunction()); } /** * @deprecated Use {@link #activation(IActivation)} with leaky relu, setting alpha value directly in constructor. */ @Deprecated public Builder leakyreluAlpha(double leakyreluAlpha) { this.leakyreluAlpha = leakyreluAlpha; return this; } /** * Weight initialization scheme. * * @see org.deeplearning4j.nn.weights.WeightInit */ public Builder weightInit(WeightInit weightInit) { this.weightInit = weightInit; return this; } /** * Constant for bias initialization. Default: 0.0 * * @param biasInit Constant for bias initialization */ public Builder biasInit(double biasInit) { this.biasInit = biasInit; return this; } /** * Distribution to sample initial weights from. Used in conjunction with * .weightInit(WeightInit.DISTRIBUTION). */ public Builder dist(Distribution dist) { this.dist = dist; return this; } /** * Learning rate. Defaults to 1e-1 */ public Builder learningRate(double learningRate) { this.learningRate = learningRate; return this; } /** * Bias learning rate. Set this to apply a different learning rate to the bias */ public Builder biasLearningRate(double biasLearningRate) { this.biasLearningRate = biasLearningRate; return this; } /** * Learning rate schedule. Map of the iteration to the learning rate to apply at that iteration. */ public Builder learningRateSchedule(Map learningRateSchedule) { this.learningRateSchedule = learningRateSchedule; return this; } /** * Rate to decrease learningRate by when the score stops improving. * Learning rate is multiplied by this rate so ideally keep between 0 and 1. */ public Builder learningRateScoreBasedDecayRate(double lrScoreBasedDecay) { this.lrScoreBasedDecay = lrScoreBasedDecay; return this; } /** * L1 regularization coefficient for the weights. * Use with .regularization(true) */ public Builder l1(double l1) { this.l1 = l1; return this; } /** * L2 regularization coefficient for the weights. * Use with .regularization(true) */ public Builder l2(double l2) { this.l2 = l2; return this; } /** * L1 regularization coefficient for the bias. * Use with .regularization(true) */ public Builder l1Bias(double l1Bias) { this.l1Bias = l1Bias; return this; } /** * L2 regularization coefficient for the bias. * Use with .regularization(true) */ public Builder l2Bias(double l2Bias) { this.l2Bias = l2Bias; return this; } /** * Dropout probability. This is the probability of retaining an activation. So dropOut(x) will keep an * activation with probability x, and set to 0 with probability 1-x.
* dropOut(0.0) is disabled (default). *

* Note: This sets the probability per-layer. Care should be taken when setting lower values for complex networks. *

* * @param dropOut Dropout probability (probability of retaining an activation) */ public Builder dropOut(double dropOut) { this.dropOut = dropOut; return this; } /** * Momentum rate * Used only when Updater is set to {@link Updater#NESTEROVS} * * @deprecated Use {@code .updater(new Nesterov(momentum))} instead */ @Deprecated public Builder momentum(double momentum) { this.momentum = momentum; return this; } /** * Momentum schedule. Map of the iteration to the momentum rate to apply at that iteration * Used only when Updater is set to {@link Updater#NESTEROVS} * * @deprecated Use {@code .updater(Nesterov.builder().momentumSchedule(schedule).build())} instead */ @Deprecated public Builder momentumAfter(Map momentumAfter) { this.momentumSchedule = momentumAfter; return this; } /** * Gradient updater. For example, Updater.SGD for standard stochastic gradient descent, * Updater.NESTEROV for Nesterov momentum, Updater.RSMPROP for RMSProp, etc.
         * Note: default hyperparameters are used with this method. Use {@link #updater(IUpdater)} to configure
         * the updater-specific hyperparameters.
         *
         * @see Updater
         */
        public Builder updater(Updater updater) {
            this.updater = updater;
            return updater(updater.getIUpdaterWithDefaultConfig());
        }

        /**
         * Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam}
         * or {@link org.nd4j.linalg.learning.config.Nesterovs}
         *
         * @param updater Updater to use
         */
        public Builder updater(IUpdater updater) {
            //Ensure legacy field is set...
            if (updater instanceof Sgd)
                this.updater = Updater.SGD;
            else if (updater instanceof Adam)
                this.updater = Updater.ADAM;
            else if (updater instanceof AdaMax)
                this.updater = Updater.ADAMAX;
            else if (updater instanceof AdaDelta)
                this.updater = Updater.ADADELTA;
            else if (updater instanceof Nesterovs)
                this.updater = Updater.NESTEROVS;
            else if (updater instanceof Nadam)
                this.updater = Updater.NADAM;
            else if (updater instanceof AdaGrad)
                this.updater = Updater.ADAGRAD;
            else if (updater instanceof RmsProp)
                this.updater = Updater.RMSPROP;
            else if (updater instanceof NoOp)
                this.updater = Updater.NONE;
            this.iUpdater = updater;
            return this;
        }

        /**
         * AdaDelta coefficient
         *
         * @param rho rho value
         * @deprecated use {@code .updater(new AdaDelta(rho, epsilon))} instead
         */
        @Deprecated
        public Builder rho(double rho) {
            this.rho = rho;
            return this;
        }

        /**
         * Epsilon value for updaters: Adam, RMSProp, AdaGrad, AdaDelta
         *
         * @param epsilon Epsilon value to use
         * @deprecated Use {@code .updater(Adam.builder().epsilon(epsilon).build())} or similar instead
         */
        @Deprecated
        public Builder epsilon(double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        /**
         * Decay rate for RMSProp. Only applies if using .updater(Updater.RMSPROP)
         *
         * @deprecated use {@code .updater(new RmsProp(rmsDecay))} instead
         */
        @Deprecated
        public Builder rmsDecay(double rmsDecay) {
            this.rmsDecay = rmsDecay;
            return this;
        }

        /**
         * Mean decay rate for Adam updater. Only applies if using .updater(Updater.ADAM)
         *
         * @deprecated use {@code .updater(Adam.builder().beta1(adamMeanDecay).build())} instead
         */
        @Deprecated
        public Builder adamMeanDecay(double adamMeanDecay) {
            this.adamMeanDecay = adamMeanDecay;
            return this;
        }

        /**
         * Variance decay rate for Adam updater. Only applies if using .updater(Updater.ADAM)
         *
         * @deprecated use {@code .updater(Adam.builder().beta2(adamVarDecay).build())} instead
         */
        @Deprecated
        public Builder adamVarDecay(double adamVarDecay) {
            this.adamVarDecay = adamVarDecay;
            return this;
        }

        /**
         * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
         *
         * @param gradientNormalization Type of normalization to use. Defaults to None.
         * @see GradientNormalization
         */
        public Builder gradientNormalization(GradientNormalization gradientNormalization) {
            this.gradientNormalization = gradientNormalization;
            return this;
        }

        /**
         * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
         * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue<br>
         * Not used otherwise.<br>
         * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
         */
        public Builder gradientNormalizationThreshold(double threshold) {
            this.gradientNormalizationThreshold = threshold;
            return this;
        }

        /**
         * Learning rate decay policy. Used to adapt learning rate based on policy.
         *
         * @param policy Type of policy to use. Defaults to None.
         */
        public Builder learningRateDecayPolicy(LearningRatePolicy policy) {
            this.learningRatePolicy = policy;
            return this;
        }

        /**
         * Set the decay rate for the learning rate decay policy.
         *
         * @param lrPolicyDecayRate rate
         */
        public Builder lrPolicyDecayRate(double lrPolicyDecayRate) {
            this.lrPolicyDecayRate = lrPolicyDecayRate;
            return this;
        }

        /**
         * Set the number of steps used for the learning rate step decay policy.
         *
         * @param lrPolicySteps number of steps
         */
        public Builder lrPolicySteps(double lrPolicySteps) {
            this.lrPolicySteps = lrPolicySteps;
            return this;
        }

        /**
         * Set the power used for the learning rate inverse policy.
         *
         * @param lrPolicyPower power
         */
        public Builder lrPolicyPower(double lrPolicyPower) {
            this.lrPolicyPower = lrPolicyPower;
            return this;
        }

        public Builder convolutionMode(ConvolutionMode convolutionMode) {
            this.convolutionMode = convolutionMode;
            return this;
        }

        private void learningRateValidation(String layerName) {
            if (learningRatePolicy != LearningRatePolicy.None && Double.isNaN(lrPolicyDecayRate)) {
                //LR policy, if used, should have a decay rate. 2 exceptions: Map for schedule, and Poly + power param
                if (!(learningRatePolicy == LearningRatePolicy.Schedule && learningRateSchedule != null)
                                && !(learningRatePolicy == LearningRatePolicy.Poly && !Double.isNaN(lrPolicyPower)))
                    throw new IllegalStateException("Layer \"" + layerName
                                    + "\" learning rate policy decay rate (lrPolicyDecayRate) must be set to use learningRatePolicy.");
            }
            switch (learningRatePolicy) {
                case Inverse:
                case Poly:
                    if (Double.isNaN(lrPolicyPower))
                        throw new IllegalStateException("Layer \"" + layerName
                                        + "\" learning rate policy power (lrPolicyPower) must be set to use "
                                        + learningRatePolicy);
                    break;
                case Step:
                case Sigmoid:
                    if (Double.isNaN(lrPolicySteps))
                        throw new IllegalStateException("Layer \"" + layerName
                                        + "\" learning rate policy steps (lrPolicySteps) must be set to use "
                                        + learningRatePolicy);
                    break;
                case Schedule:
                    if (learningRateSchedule == null)
                        throw new IllegalStateException("Layer \"" + layerName
                                        + "\" learning rate policy schedule (learningRateSchedule) must be set to use "
                                        + learningRatePolicy);
                    break;
            }

            if (!Double.isNaN(lrPolicyPower) && (learningRatePolicy != LearningRatePolicy.Inverse
                            && learningRatePolicy != LearningRatePolicy.Poly))
                throw new IllegalStateException("Layer \"" + layerName
                                + "\" power has been set but will not be applied unless the learning rate policy is set to Inverse or Poly.");
            if (!Double.isNaN(lrPolicySteps) && (learningRatePolicy != LearningRatePolicy.Step
                            && learningRatePolicy != LearningRatePolicy.Sigmoid
                            && learningRatePolicy != LearningRatePolicy.TorchStep))
                throw new IllegalStateException("Layer \"" + layerName
                                + "\" steps have been set but will not be applied unless the learning rate policy is set to Step or Sigmoid.");
            if ((learningRateSchedule != null) && (learningRatePolicy != LearningRatePolicy.Schedule))
                throw new IllegalStateException("Layer \"" + layerName
                                + "\" learning rate schedule has been set but will not be applied unless the learning rate policy is set to Schedule.");
        }

        ////////////////

        /**
         * Return a configuration based on this builder
         *
         * @return the built configuration
         */
        public NeuralNetConfiguration build() {
            NeuralNetConfiguration conf = new NeuralNetConfiguration();
            conf.minimize = minimize;
            conf.maxNumLineSearchIterations = maxNumLineSearchIterations;
            conf.layer = layer;
            conf.numIterations = numIterations;
            conf.useRegularization = useRegularization;
            conf.optimizationAlgo = optimizationAlgo;
            conf.seed = seed;
            conf.stepFunction = stepFunction;
            conf.useDropConnect = useDropConnect;
            conf.miniBatch = miniBatch;
            conf.learningRatePolicy = learningRatePolicy;
            conf.lrPolicyDecayRate = lrPolicyDecayRate;
            conf.lrPolicySteps = lrPolicySteps;
            conf.lrPolicyPower = lrPolicyPower;
            conf.pretrain = pretrain;
            conf.cacheMode = this.cacheMode;

            configureLayer(layer);
            if (layer instanceof FrozenLayer) {
                configureLayer(((FrozenLayer) layer).getLayer());
            }

            return conf;
        }

        private void configureLayer(Layer layer) {
            String layerName;
            if (layer == null || layer.getLayerName() == null)
                layerName = "Layer not named";
            else
                layerName = layer.getLayerName();
            learningRateValidation(layerName);

            if (layer != null) {
                copyConfigToLayer(layerName, layer);
            }

            if (layer instanceof FrozenLayer) {
                copyConfigToLayer(layerName, ((FrozenLayer) layer).getLayer());
            }

            if (layer instanceof ConvolutionLayer) {
                ConvolutionLayer cl = (ConvolutionLayer) layer;
                if (cl.getConvolutionMode() == null) {
                    cl.setConvolutionMode(convolutionMode);
                }
            }
            if (layer instanceof SubsamplingLayer) {
                SubsamplingLayer sl = (SubsamplingLayer) layer;
                if (sl.getConvolutionMode() == null) {
                    sl.setConvolutionMode(convolutionMode);
                }
            }
            LayerValidation.generalValidation(layerName, layer, useRegularization, useDropConnect, dropOut, l2,
                            l2Bias, l1, l1Bias, dist);
        }

        private void copyConfigToLayer(String layerName, Layer layer) {
            if (Double.isNaN(layer.getDropOut()))
                layer.setDropOut(dropOut);

            if (layer instanceof BaseLayer) {
                BaseLayer bLayer = (BaseLayer) layer;
                if (Double.isNaN(bLayer.getLearningRate()))
                    bLayer.setLearningRate(learningRate);
                if (Double.isNaN(bLayer.getBiasLearningRate())) {
                    //Two possibilities when bias LR isn't set for layer:
                    // (a) If global bias LR *is* set -> set it to that
                    // (b) Otherwise, set to layer LR (and, by extension, the global LR)
                    if (!Double.isNaN(biasLearningRate)) {
                        //Global bias LR is set
                        bLayer.setBiasLearningRate(biasLearningRate);
                    } else {
                        bLayer.setBiasLearningRate(bLayer.getLearningRate());
                    }
                }
                if (bLayer.getLearningRateSchedule() == null)
                    bLayer.setLearningRateSchedule(learningRateSchedule);
                if (Double.isNaN(bLayer.getL1()))
                    bLayer.setL1(l1);
                if (Double.isNaN(bLayer.getL2()))
                    bLayer.setL2(l2);
                if (bLayer.getActivationFn() == null)
                    bLayer.setActivationFn(activationFn);
                if (bLayer.getWeightInit() == null)
                    bLayer.setWeightInit(weightInit);
                if (Double.isNaN(bLayer.getBiasInit()))
                    bLayer.setBiasInit(biasInit);
                if (bLayer.getUpdater() == null)
                    bLayer.setUpdater(updater);
                if (bLayer.getIUpdater() == null) {
                    bLayer.setIUpdater(iUpdater.clone());
                }
                LayerValidation.updaterValidation(layerName, layer, learningRate, momentum, momentumSchedule,
                                adamMeanDecay, adamVarDecay, rho, rmsDecay, epsilon);

                if (bLayer.getGradientNormalization() == null)
                    bLayer.setGradientNormalization(gradientNormalization);
                if (Double.isNaN(bLayer.getGradientNormalizationThreshold()))
                    bLayer.setGradientNormalizationThreshold(gradientNormalizationThreshold);
            }
        }
    }
}
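
// Illustrative sketch, not part of the original source: building a complete per-layer
// configuration, using the IUpdater API rather than the deprecated hyperparameter setters
// (all values are arbitrary example values):
//
//   NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
//                   .seed(12345)
//                   .learningRate(1e-2)
//                   .updater(new Adam(1e-2))
//                   .regularization(true).l2(1e-4)
//                   .layer(new DenseLayer.Builder().nIn(10).nOut(5)
//                                   .activation(Activation.RELU).build())
//                   .build();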



