/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.conf;
import com.google.common.collect.Sets;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ClassUtils;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.conf.serde.ComputationGraphConfigurationDeserializer;
import org.deeplearning4j.nn.conf.serde.MultiLayerConfigurationDeserializer;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.reflections.DL4JSubTypesScanner;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.databind.*;
import org.nd4j.shade.jackson.databind.deser.BeanDeserializerModifier;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
import org.nd4j.shade.jackson.databind.jsontype.NamedType;
import org.nd4j.shade.jackson.databind.module.SimpleModule;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.reflections.ReflectionUtils;
import org.reflections.Reflections;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import org.reflections.util.FilterBuilder;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.*;
/**
* A Serializable configuration
* for neural nets that covers per layer parameters
*
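* A minimal usage sketch (the layer type and hyperparameter values below are illustrative only):
*
* {@code
* NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
*     .seed(12345)
*     .iterations(1)
*     .learningRate(1e-1)
*     .weightInit(WeightInit.XAVIER)
*     .layer(new DenseLayer.Builder().nIn(784).nOut(100).build())
*     .build();
* }
*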
* @author Adam Gibson
*/
@Data
@NoArgsConstructor
@Slf4j
public class NeuralNetConfiguration implements Serializable, Cloneable {
/**
* System property for custom layers, preprocessors, graph vertices etc. Enabled by default.
* Run JVM with "-Dorg.deeplearning4j.config.custom.enabled=false" to disable classpath scanning for custom functionality.
* Overriding the default (i.e., disabling the scan) is only useful if (a) no custom layers/preprocessors etc. will be
* used, and (b) minimizing startup/initialization time for new JVMs is very important.
* Results are cached, so there is no cost to custom layers after the first network has been constructed.
*/
public static final String CUSTOM_FUNCTIONALITY = "org.deeplearning4j.config.custom.enabled";
protected Layer layer;
@Deprecated
protected double leakyreluAlpha;
//batch size: primarily used for conv nets. Will be reinforced if set.
protected boolean miniBatch = true;
protected int numIterations;
//number of line search iterations
protected int maxNumLineSearchIterations;
protected long seed;
protected OptimizationAlgorithm optimizationAlgo;
//gradient keys used for ensuring order when getting and setting the gradient
protected List<String> variables = new ArrayList<>();
//whether to constrain the gradient to unit norm or not
//adadelta - weight for how much to consider previous history
protected StepFunction stepFunction;
protected boolean useRegularization = false;
protected boolean useDropConnect = false;
//minimize or maximize objective
protected boolean minimize = true;
// Graves LSTM & RNN
protected Map<String, Double> learningRateByParam = new HashMap<>();
protected Map<String, Double> l1ByParam = new HashMap<>();
protected Map<String, Double> l2ByParam = new HashMap<>();
protected LearningRatePolicy learningRatePolicy = LearningRatePolicy.None;
protected double lrPolicyDecayRate;
protected double lrPolicySteps;
protected double lrPolicyPower;
protected boolean pretrain;
// this field defines preOutput cache
protected CacheMode cacheMode;
//Counter for the number of parameter updates so far for this layer.
//Note that this is only used for pretrain layers (RBM, VAE) - MultiLayerConfiguration and ComputationGraphConfiguration
//contain counters for standard backprop training.
// This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
// for Spark and model serialization
protected int iterationCount = 0;
private static ObjectMapper mapper = initMapper();
private static final ObjectMapper mapperYaml = initMapperYaml();
private static Set<Class<?>> subtypesClassCache = null;
/**
* Creates and returns a deep copy of the configuration.
*/
@Override
public NeuralNetConfiguration clone() {
try {
NeuralNetConfiguration clone = (NeuralNetConfiguration) super.clone();
if (clone.layer != null)
clone.layer = clone.layer.clone();
if (clone.stepFunction != null)
clone.stepFunction = clone.stepFunction.clone();
if (clone.variables != null)
clone.variables = new ArrayList<>(clone.variables);
if (clone.learningRateByParam != null)
clone.learningRateByParam = new HashMap<>(clone.learningRateByParam);
if (clone.l1ByParam != null)
clone.l1ByParam = new HashMap<>(clone.l1ByParam);
if (clone.l2ByParam != null)
clone.l2ByParam = new HashMap<>(clone.l2ByParam);
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
}
public List<String> variables() {
return new ArrayList<>(variables);
}
public List<String> variables(boolean copy) {
if (copy)
return variables();
return variables;
}
public void addVariable(String variable) {
if (!variables.contains(variable)) {
variables.add(variable);
setLayerParamLR(variable);
}
}
public void clearVariables() {
variables.clear();
l1ByParam.clear();
l2ByParam.clear();
learningRateByParam.clear();
}
public void resetVariables() {
for (String s : variables) {
setLayerParamLR(s);
}
}
public void setLayerParamLR(String variable) {
double lr = layer.getLearningRateByParam(variable);
double l1 = layer.getL1ByParam(variable);
if (Double.isNaN(l1))
l1 = 0.0; //Not set
double l2 = layer.getL2ByParam(variable);
if (Double.isNaN(l2))
l2 = 0.0; //Not set
learningRateByParam.put(variable, lr);
l1ByParam.put(variable, l1);
l2ByParam.put(variable, l2);
}
public double getLearningRateByParam(String variable) {
return learningRateByParam.get(variable);
}
public void setLearningRateByParam(String variable, double rate) {
learningRateByParam.put(variable, rate);
}
public double getL1ByParam(String variable) {
return l1ByParam.get(variable);
}
public double getL2ByParam(String variable) {
return l2ByParam.get(variable);
}
/**
* Fluent interface for building a list of configurations
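*
* A minimal sketch of typical usage via {@link NeuralNetConfiguration.Builder#list()} (layer types and sizes are
* illustrative only):
*
* {@code
* MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
*     .list()
*     .layer(0, new DenseLayer.Builder().nIn(10).nOut(5).build())
*     .layer(1, new OutputLayer.Builder().nIn(5).nOut(3).build())
*     .build();
* }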
*/
public static class ListBuilder extends MultiLayerConfiguration.Builder {
private Map<Integer, Builder> layerwise;
private Builder globalConfig;
// Constructor
public ListBuilder(Builder globalConfig, Map<Integer, Builder> layerMap) {
this.globalConfig = globalConfig;
this.layerwise = layerMap;
}
public ListBuilder(Builder globalConfig) {
this(globalConfig, new HashMap<Integer, Builder>());
}
public ListBuilder backprop(boolean backprop) {
this.backprop = backprop;
return this;
}
public ListBuilder pretrain(boolean pretrain) {
this.pretrain = pretrain;
return this;
}
public ListBuilder layer(int ind, Layer layer) {
if (layerwise.containsKey(ind)) {
layerwise.get(ind).layer(layer);
} else {
layerwise.put(ind, globalConfig.clone().layer(layer));
}
return this;
}
public Map<Integer, Builder> getLayerwise() {
return layerwise;
}
/**
* Build the multi layer network
* based on this neural network and
* overridden parameters
*
* @return the configuration to build
*/
public MultiLayerConfiguration build() {
List<NeuralNetConfiguration> list = new ArrayList<>();
if (layerwise.isEmpty())
throw new IllegalStateException("Invalid configuration: no layers defined");
for (int i = 0; i < layerwise.size(); i++) {
if (layerwise.get(i) == null) {
throw new IllegalStateException("Invalid configuration: layer number " + i
+ " not specified. Expect layer " + "numbers to be 0 to " + (layerwise.size() - 1)
+ " inclusive (number of layers defined: " + layerwise.size() + ")");
}
if (layerwise.get(i).getLayer() == null)
throw new IllegalStateException("Cannot construct network: Layer config for" + "layer with index "
+ i + " is not defined)");
//Layer names: set to default, if not set
if (layerwise.get(i).getLayer().getLayerName() == null) {
layerwise.get(i).getLayer().setLayerName("layer" + i);
}
list.add(layerwise.get(i).build());
}
return new MultiLayerConfiguration.Builder().backprop(backprop).inputPreProcessors(inputPreProcessors)
.pretrain(pretrain).backpropType(backpropType).tBPTTForwardLength(tbpttFwdLength)
.tBPTTBackwardLength(tbpttBackLength).setInputType(this.inputType)
.trainingWorkspaceMode(globalConfig.trainingWorkspaceMode).cacheMode(globalConfig.cacheMode)
.inferenceWorkspaceMode(globalConfig.inferenceWorkspaceMode).confs(list).build();
}
}
/**
* Return this configuration as YAML
*
* @return this configuration represented as YAML
*/
public String toYaml() {
ObjectMapper mapper = mapperYaml();
try {
String ret = mapper.writeValueAsString(this);
return ret;
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Create a neural net configuration from YAML
*
* @param json the neural net configuration as YAML
* @return the deserialized configuration
*/
public static NeuralNetConfiguration fromYaml(String json) {
ObjectMapper mapper = mapperYaml();
try {
NeuralNetConfiguration ret = mapper.readValue(json, NeuralNetConfiguration.class);
return ret;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Return this configuration as json
*
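* A round-trip sketch using {@link #fromJson(String)} (where {@code conf} is an existing configuration):
*
* {@code
* String json = conf.toJson();
* NeuralNetConfiguration restored = NeuralNetConfiguration.fromJson(json);
* }
*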
* @return this configuration represented as json
*/
public String toJson() {
ObjectMapper mapper = mapper();
try {
String ret = mapper.writeValueAsString(this);
return ret;
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Create a neural net configuration from json
*
* @param json the neural net configuration as json
* @return the deserialized configuration
*/
public static NeuralNetConfiguration fromJson(String json) {
ObjectMapper mapper = mapper();
try {
NeuralNetConfiguration ret = mapper.readValue(json, NeuralNetConfiguration.class);
return ret;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Object mapper for serialization of configurations
*
* @return the object mapper for YAML serialization
*/
public static ObjectMapper mapperYaml() {
return mapperYaml;
}
private static ObjectMapper initMapperYaml() {
ObjectMapper ret = new ObjectMapper(new YAMLFactory());
configureMapper(ret);
return ret;
}
/**
* Object mapper for serialization of configurations
*
* @return the object mapper for JSON serialization
*/
public static ObjectMapper mapper() {
return mapper;
}
/**
* Reinitialize and return the Jackson/json ObjectMapper with additional named types.
* This can be used to add additional subtypes at runtime (i.e., for JSON mapping with
* types defined outside of the main DL4J codebase)
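*
* A sketch of registering a hypothetical custom layer class ({@code MyCustomLayer} is not part of DL4J and is
* used here purely for illustration):
*
* {@code
* NeuralNetConfiguration.reinitMapperWithSubtypes(
*     Collections.singletonList(new NamedType(MyCustomLayer.class, "MyCustomLayer")));
* }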
*/
public static ObjectMapper reinitMapperWithSubtypes(Collection<NamedType> additionalTypes) {
mapper.registerSubtypes(additionalTypes.toArray(new NamedType[additionalTypes.size()]));
//Recreate the mapper (via copy), as mapper won't use registered subtypes after first use
mapper = mapper.copy();
return mapper;
}
private static ObjectMapper initMapper() {
ObjectMapper ret = new ObjectMapper();
configureMapper(ret);
return ret;
}
private static void configureMapper(ObjectMapper ret) {
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
ret.enable(SerializationFeature.INDENT_OUTPUT);
SimpleModule customDeserializerModule = new SimpleModule();
customDeserializerModule.setDeserializerModifier(new BeanDeserializerModifier() {
@Override
public JsonDeserializer<?> modifyDeserializer(DeserializationConfig config, BeanDescription beanDesc,
JsonDeserializer<?> deserializer) {
//Use our custom deserializers to handle backward compatibility for updaters -> IUpdater
if (beanDesc.getBeanClass() == MultiLayerConfiguration.class) {
return new MultiLayerConfigurationDeserializer(deserializer);
} else if (beanDesc.getBeanClass() == ComputationGraphConfiguration.class) {
return new ComputationGraphConfigurationDeserializer(deserializer);
}
return deserializer;
}
});
ret.registerModule(customDeserializerModule);
registerSubtypes(ret);
}
private static synchronized void registerSubtypes(ObjectMapper mapper) {
//Register concrete subtypes for JSON serialization
List<Class<?>> classes = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class,
IActivation.class, Layer.class, GraphVertex.class, ReconstructionDistribution.class);
List<String> classNames = new ArrayList<>(6);
for (Class<?> c : classes)
classNames.add(c.getName());
// First: scan the classpath and find all instances of the 'baseClasses' classes
if (subtypesClassCache == null) {
//Check system property:
String prop = System.getProperty(CUSTOM_FUNCTIONALITY);
if (prop != null && !Boolean.parseBoolean(prop)) {
subtypesClassCache = Collections.emptySet();
} else {
List<Class<?>> interfaces = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class,
IActivation.class, ReconstructionDistribution.class);
List<Class<?>> classesList = Arrays.<Class<?>>asList(Layer.class, GraphVertex.class);
Collection<URL> urls = ClasspathHelper.forClassLoader();
List<URL> scanUrls = new ArrayList<>();
for (URL u : urls) {
String path = u.getPath();
if (!path.matches(".*/jre/lib/.*jar")) { //Skip JRE/JDK JARs
scanUrls.add(u);
}
}
Reflections reflections = new Reflections(new ConfigurationBuilder().filterInputsBy(new FilterBuilder()
.exclude("^(?!.*\\.class$).*$") //Consider only .class files (to avoid debug messages etc. on .dlls, etc
//Exclude the following: the assumption here is that no custom functionality will ever be present
// under these package name prefixes. These are all common dependencies for DL4J
.exclude("^org.nd4j.*").exclude("^org.datavec.*").exclude("^org.bytedeco.*") //JavaCPP
.exclude("^com.fasterxml.*")//Jackson
.exclude("^org.apache.*") //Apache commons, Spark, log4j etc
.exclude("^org.projectlombok.*").exclude("^com.twelvemonkeys.*").exclude("^org.joda.*")
.exclude("^org.slf4j.*").exclude("^com.google.*").exclude("^org.reflections.*")
.exclude("^ch.qos.*") //Logback
).addUrls(scanUrls).setScanners(new DL4JSubTypesScanner(interfaces, classesList)));
org.reflections.Store store = reflections.getStore();
Iterable<String> subtypesByName = store.getAll(DL4JSubTypesScanner.class.getSimpleName(), classNames);
Set<? extends Class<?>> subtypeClasses = Sets.newHashSet(ReflectionUtils.forNames(subtypesByName));
subtypesClassCache = new HashSet<>();
for (Class<?> c : subtypeClasses) {
if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
//log.info("Skipping abstract/interface: {}",c);
continue;
}
subtypesClassCache.add(c);
}
}
}
//Second: get all currently registered subtypes for this mapper
Set<Class<?>> registeredSubtypes = new HashSet<>();
for (Class<?> c : classes) {
AnnotatedClass ac = AnnotatedClass.construct(c, mapper.getSerializationConfig().getAnnotationIntrospector(),
null);
Collection<NamedType> types =
mapper.getSubtypeResolver().collectAndResolveSubtypes(ac, mapper.getSerializationConfig(),
mapper.getSerializationConfig().getAnnotationIntrospector());
for (NamedType nt : types) {
registeredSubtypes.add(nt.getType());
}
}
//Third: register all _concrete_ subtypes that are not already registered
List<NamedType> toRegister = new ArrayList<>();
for (Class<?> c : subtypesClassCache) {
//Check if it's concrete or abstract...
if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
//log.info("Skipping abstract/interface: {}",c);
continue;
}
if (!registeredSubtypes.contains(c)) {
String name;
if (ClassUtils.isInnerClass(c)) {
Class<?> c2 = c.getDeclaringClass();
name = c2.getSimpleName() + "$" + c.getSimpleName();
} else {
name = c.getSimpleName();
}
toRegister.add(new NamedType(c, name));
if (log.isDebugEnabled()) {
for (Class<?> baseClass : classes) {
if (baseClass.isAssignableFrom(c)) {
log.debug("Registering class for JSON serialization: {} as subtype of {}", c.getName(),
baseClass.getName());
break;
}
}
}
}
}
mapper.registerSubtypes(toRegister.toArray(new NamedType[toRegister.size()]));
}
@Data
public static class Builder implements Cloneable {
protected IActivation activationFn = new ActivationSigmoid();
protected WeightInit weightInit = WeightInit.XAVIER;
protected double biasInit = 0.0;
protected Distribution dist = null;
protected double learningRate = 1e-1;
protected double biasLearningRate = Double.NaN;
protected Map<Integer, Double> learningRateSchedule = null;
protected double lrScoreBasedDecay;
protected double l1 = Double.NaN;
protected double l2 = Double.NaN;
protected double l1Bias = Double.NaN;
protected double l2Bias = Double.NaN;
protected double dropOut = 0;
@Deprecated
protected Updater updater = Updater.SGD;
protected IUpdater iUpdater = new Sgd();
@Deprecated
protected double momentum = Double.NaN;
@Deprecated
protected Map<Integer, Double> momentumSchedule = null;
@Deprecated
protected double epsilon = Double.NaN;
@Deprecated
protected double rho = Double.NaN;
@Deprecated
protected double rmsDecay = Double.NaN;
@Deprecated
protected double adamMeanDecay = Double.NaN;
@Deprecated
protected double adamVarDecay = Double.NaN;
protected Layer layer;
@Deprecated
protected double leakyreluAlpha = 0.01;
protected boolean miniBatch = true;
protected int numIterations = 1;
protected int maxNumLineSearchIterations = 5;
protected long seed = System.currentTimeMillis();
protected boolean useRegularization = false;
protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
protected StepFunction stepFunction = null;
protected boolean useDropConnect = false;
protected boolean minimize = true;
protected GradientNormalization gradientNormalization = GradientNormalization.None;
protected double gradientNormalizationThreshold = 1.0;
protected LearningRatePolicy learningRatePolicy = LearningRatePolicy.None;
protected double lrPolicyDecayRate = Double.NaN;
protected double lrPolicySteps = Double.NaN;
protected double lrPolicyPower = Double.NaN;
protected boolean pretrain = false;
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.NONE;
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.SEPARATE;
protected CacheMode cacheMode = CacheMode.NONE;
protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
public Builder() {
//
}
public Builder(NeuralNetConfiguration newConf) {
if (newConf != null) {
minimize = newConf.minimize;
maxNumLineSearchIterations = newConf.maxNumLineSearchIterations;
layer = newConf.layer;
numIterations = newConf.numIterations;
useRegularization = newConf.useRegularization;
optimizationAlgo = newConf.optimizationAlgo;
seed = newConf.seed;
stepFunction = newConf.stepFunction;
useDropConnect = newConf.useDropConnect;
miniBatch = newConf.miniBatch;
learningRatePolicy = newConf.learningRatePolicy;
lrPolicyDecayRate = newConf.lrPolicyDecayRate;
lrPolicySteps = newConf.lrPolicySteps;
lrPolicyPower = newConf.lrPolicyPower;
pretrain = newConf.pretrain;
}
}
/**
* Process input as minibatch vs full dataset.
* Default set to true.
*/
public Builder miniBatch(boolean miniBatch) {
this.miniBatch = miniBatch;
return this;
}
/**
* This method defines Workspace mode being used during training:
* NONE: workspace won't be used
* SINGLE: one workspace will be used during whole iteration loop
* SEPARATE: separate workspaces will be used for feedforward and backprop iteration loops
*
* @param workspaceMode Workspace mode to use during training
* @return builder
*/
public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
this.trainingWorkspaceMode = workspaceMode;
return this;
}
/**
* This method defines Workspace mode being used during inference:
* NONE: workspace won't be used
* SINGLE: one workspace will be used during whole iteration loop
* SEPARATE: separate workspaces will be used for feedforward and backprop iteration loops
*
* @param workspaceMode Workspace mode to use during inference
* @return builder
*/
public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) {
this.inferenceWorkspaceMode = workspaceMode;
return this;
}
/**
* This method defines how/if preOutput cache is handled:
* NONE: cache disabled (default value)
* HOST: Host memory will be used
* DEVICE: GPU memory will be used (on CPU backends effect will be the same as for HOST)
*
* @param cacheMode Cache mode to use for preOutput activations
* @return builder
*/
public Builder cacheMode(@NonNull CacheMode cacheMode) {
this.cacheMode = cacheMode;
return this;
}
/**
* Use drop connect: multiply the weight by a binomial sampling wrt the dropout probability.
* Dropconnect probability is set using {@link #dropOut(double)}; this is the probability of retaining a weight
*
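* For example (a sketch; the retention probability 0.9 is illustrative only): {@code .useDropConnect(true).dropOut(0.9)}
*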
* @param useDropConnect whether to use drop connect or not
* @return builder
*/
public Builder useDropConnect(boolean useDropConnect) {
this.useDropConnect = useDropConnect;
return this;
}
/**
* Whether to minimize or maximize the objective (cost) function.
* Default: minimize = true.
*/
public Builder minimize(boolean minimize) {
this.minimize = minimize;
return this;
}
/**
* Maximum number of line search iterations.
* Only applies for line search optimizers: Line Search SGD, Conjugate Gradient, LBFGS.
* Is NOT applicable for standard SGD.
*
* @param maxNumLineSearchIterations Maximum number of line search iterations (must be > 0)
* @return builder
*/
public Builder maxNumLineSearchIterations(int maxNumLineSearchIterations) {
this.maxNumLineSearchIterations = maxNumLineSearchIterations;
return this;
}
/**
* Layer class.
*/
public Builder layer(Layer layer) {
this.layer = layer;
return this;
}
/**
* Step function to apply for back track line search.
* Only applies for line search optimizers: Line Search SGD, Conjugate Gradient, LBFGS
* Options: DefaultStepFunction (default), NegativeDefaultStepFunction
* GradientStepFunction (for SGD), NegativeGradientStepFunction
*/
public Builder stepFunction(StepFunction stepFunction) {
this.stepFunction = stepFunction;
return this;
}
/**
* Create a ListBuilder (for creating a MultiLayerConfiguration)
* Usage:
*
* {@code .list()
* .layer(0,new DenseLayer.Builder()...build())
* ...
* .layer(n,new OutputLayer.Builder()...build())
* }
*
*/
public ListBuilder list() {
return new ListBuilder(this);
}
/**
* Create a ListBuilder (for creating a MultiLayerConfiguration) with the specified layers
* Usage:
*
* {@code .list(
* new DenseLayer.Builder()...build(),
* ...,
* new OutputLayer.Builder()...build())
* }
*
*
* @param layers The layer configurations for the network
*/
public ListBuilder list(Layer... layers) {
if (layers == null || layers.length == 0)
throw new IllegalArgumentException("Cannot create network with no layers");
Map<Integer, Builder> layerMap = new HashMap<>();
for (int i = 0; i < layers.length; i++) {
Builder b = this.clone();
b.layer(layers[i]);
layerMap.put(i, b);
}
return new ListBuilder(this, layerMap);
}
/**
* Create a GraphBuilder (for creating a ComputationGraphConfiguration).
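*
* A minimal sketch (vertex names, layer types and sizes are illustrative only):
*
* {@code
* ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
*     .graphBuilder()
*     .addInputs("in")
*     .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(5).build(), "in")
*     .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(3).build(), "dense")
*     .setOutputs("out")
*     .build();
* }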
*/
public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
return new ComputationGraphConfiguration.GraphBuilder(this);
}
/**
* Number of optimization iterations.
*/
public Builder iterations(int numIterations) {
this.numIterations = numIterations;
return this;
}
/**
* Random number generator seed. Used for reproducibility between runs
*/
public Builder seed(int seed) {
this.seed = (long) seed;
Nd4j.getRandom().setSeed(seed);
return this;
}
/**
* Random number generator seed. Used for reproducibility between runs
*/
public Builder seed(long seed) {
this.seed = seed;
Nd4j.getRandom().setSeed(seed);
return this;
}
/**
* Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT
*
* @param optimizationAlgo Optimization algorithm to use when training
*/
public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) {
this.optimizationAlgo = optimizationAlgo;
return this;
}
/**
* Whether to use regularization (l1, l2, dropout, etc.)
*/
public Builder regularization(boolean useRegularization) {
this.useRegularization = useRegularization;
return this;
}
@Override
public Builder clone() {
try {
Builder clone = (Builder) super.clone();
if (clone.layer != null)
clone.layer = clone.layer.clone();
if (clone.stepFunction != null)
clone.stepFunction = clone.stepFunction.clone();
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
}
/**
* Activation function / neuron non-linearity
* Typical values include:
* "relu" (rectified linear), "tanh", "sigmoid", "softmax",
* "hardtanh", "leakyrelu", "maxout", "softsign", "softplus"
*
* @deprecated Use {@link #activation(Activation)} or
* {@link #activation(IActivation)}
*/
@Deprecated
public Builder activation(String activationFunction) {
return activation(Activation.fromString(activationFunction).getActivationFunction());
}
/**
* Activation function / neuron non-linearity
*
* @see #activation(Activation)
*/
public Builder activation(IActivation activationFunction) {
this.activationFn = activationFunction;
return this;
}
/**
* Activation function / neuron non-linearity
*/
public Builder activation(Activation activation) {
return activation(activation.getActivationFunction());
}
/**
* @deprecated Use {@link #activation(IActivation)} with leaky relu, setting alpha value directly in constructor.
*/
@Deprecated
public Builder leakyreluAlpha(double leakyreluAlpha) {
this.leakyreluAlpha = leakyreluAlpha;
return this;
}
/**
* Weight initialization scheme.
*
* @see org.deeplearning4j.nn.weights.WeightInit
*/
public Builder weightInit(WeightInit weightInit) {
this.weightInit = weightInit;
return this;
}
/**
* Constant for bias initialization. Default: 0.0
*
* @param biasInit Constant for bias initialization
*/
public Builder biasInit(double biasInit) {
this.biasInit = biasInit;
return this;
}
/**
* Distribution to sample initial weights from. Used in conjunction with
* .weightInit(WeightInit.DISTRIBUTION).
*/
public Builder dist(Distribution dist) {
this.dist = dist;
return this;
}
/**
* Learning rate. Defaults to 1e-1
*/
public Builder learningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}
/**
* Bias learning rate. Set this to apply a different learning rate to the bias
*/
public Builder biasLearningRate(double biasLearningRate) {
this.biasLearningRate = biasLearningRate;
return this;
}
/**
* Learning rate schedule. Map of the iteration to the learning rate to apply at that iteration.
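*
* For example (a sketch; the iteration counts and rates are illustrative only, and {@code builder} is an existing
* Builder), used together with {@code .learningRateDecayPolicy(LearningRatePolicy.Schedule)}:
*
* {@code
* Map<Integer, Double> lrSchedule = new HashMap<>();
* lrSchedule.put(0, 1e-1);    //Use 1e-1 from iteration 0
* lrSchedule.put(1000, 1e-2); //Use 1e-2 from iteration 1000
* builder.learningRateSchedule(lrSchedule);
* }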
*/
public Builder learningRateSchedule(Map<Integer, Double> learningRateSchedule) {
this.learningRateSchedule = learningRateSchedule;
return this;
}
/**
* Rate to decrease learningRate by when the score stops improving.
* Learning rate is multiplied by this rate so ideally keep between 0 and 1.
*/
public Builder learningRateScoreBasedDecayRate(double lrScoreBasedDecay) {
this.lrScoreBasedDecay = lrScoreBasedDecay;
return this;
}
/**
* L1 regularization coefficient for the weights.
* Use with .regularization(true)
*/
public Builder l1(double l1) {
this.l1 = l1;
return this;
}
/**
* L2 regularization coefficient for the weights.
* Use with .regularization(true)
*/
public Builder l2(double l2) {
this.l2 = l2;
return this;
}
/**
* L1 regularization coefficient for the bias.
* Use with .regularization(true)
*/
public Builder l1Bias(double l1Bias) {
this.l1Bias = l1Bias;
return this;
}
/**
* L2 regularization coefficient for the bias.
* Use with .regularization(true)
*/
public Builder l2Bias(double l2Bias) {
this.l2Bias = l2Bias;
return this;
}
/**
* Dropout probability. This is the probability of retaining an activation. So dropOut(x) will keep an
* activation with probability x, and set to 0 with probability 1-x.
* dropOut(0.0) is disabled (default).
*
* Note: This sets the probability per-layer. Care should be taken when setting lower values for complex networks.
*
*
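* For example, {@code .dropOut(0.8)} retains each activation with probability 0.8 and zeroes it with probability 0.2.
*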
* @param dropOut Dropout probability (probability of retaining an activation)
*/
public Builder dropOut(double dropOut) {
this.dropOut = dropOut;
return this;
}
/**
* Momentum rate
* Used only when Updater is set to {@link Updater#NESTEROVS}
*
* @deprecated Use {@code .updater(new Nesterovs(momentum))} instead
*/
@Deprecated
public Builder momentum(double momentum) {
this.momentum = momentum;
return this;
}
/**
* Momentum schedule. Map of the iteration to the momentum rate to apply at that iteration
* Used only when Updater is set to {@link Updater#NESTEROVS}
*
* @deprecated Use {@code .updater(Nesterovs.builder().momentumSchedule(schedule).build())} instead
*/
@Deprecated
public Builder momentumAfter(Map<Integer, Double> momentumAfter) {
this.momentumSchedule = momentumAfter;
return this;
}
/**
* Gradient updater. For example, Updater.SGD for standard stochastic gradient descent,
* Updater.NESTEROVS for Nesterov momentum, Updater.RMSPROP for RMSProp, etc.
* Note: default hyperparameters are used with this method. Use {@link #updater(IUpdater)} to configure
* the updater-specific hyperparameters.
*
* @see Updater
*/
public Builder updater(Updater updater) {
this.updater = updater;
return updater(updater.getIUpdaterWithDefaultConfig());
}
/**
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam}
* or {@link org.nd4j.linalg.learning.config.Nesterovs}
*
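* A sketch (the updater choice and epsilon value are illustrative only; default hyperparameters are used otherwise):
* {@code .updater(new Adam())} or {@code .updater(Adam.builder().epsilon(1e-8).build())}
*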
* @param updater Updater to use
*/
public Builder updater(IUpdater updater) {
//Ensure legacy field is set...
if(updater instanceof Sgd) this.updater = Updater.SGD;
else if(updater instanceof Adam) this.updater = Updater.ADAM;
else if(updater instanceof AdaMax) this.updater = Updater.ADAMAX;
else if(updater instanceof AdaDelta) this.updater = Updater.ADADELTA;
else if(updater instanceof Nesterovs) this.updater = Updater.NESTEROVS;
else if(updater instanceof Nadam) this.updater = Updater.NADAM;
else if(updater instanceof AdaGrad) this.updater = Updater.ADAGRAD;
else if(updater instanceof RmsProp) this.updater = Updater.RMSPROP;
else if(updater instanceof NoOp) this.updater = Updater.NONE;
this.iUpdater = updater;
return this;
}
/**
* Ada delta coefficient
*
* @param rho
* @deprecated use {@code .updater(new AdaDelta(rho,epsilon))} instead
*/
@Deprecated
public Builder rho(double rho) {
this.rho = rho;
return this;
}
/**
* Epsilon value for updaters: Adam, RMSProp, Adagrad, Adadelta
*
* @param epsilon Epsilon value to use
* @deprecated Use {@code .updater(Adam.builder().epsilon(epsilon).build())} or similar instead
*/
@Deprecated
public Builder epsilon(double epsilon) {
this.epsilon = epsilon;
return this;
}
/**
* Decay rate for RMSProp. Only applies if using .updater(Updater.RMSPROP)
*
* @deprecated use {@code .updater(new RmsProp(rmsDecay))} instead
*/
@Deprecated
public Builder rmsDecay(double rmsDecay) {
this.rmsDecay = rmsDecay;
return this;
}
/**
* Mean decay rate for Adam updater. Only applies if using .updater(Updater.ADAM)
*
* @deprecated use {@code .updater(Adam.builder().beta1(adamMeanDecay).build())} instead
*/
@Deprecated
public Builder adamMeanDecay(double adamMeanDecay) {
this.adamMeanDecay = adamMeanDecay;
return this;
}
/**
* Variance decay rate for Adam updater. Only applies if using .updater(Updater.ADAM)
*
* @deprecated use {@code .updater(Adam.builder().beta2(adamVarDecay).build())} instead
*/
@Deprecated
public Builder adamVarDecay(double adamVarDecay) {
this.adamVarDecay = adamVarDecay;
return this;
}
/**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
*
* @param gradientNormalization Type of normalization to use. Defaults to None.
* @see GradientNormalization
*/
public Builder gradientNormalization(GradientNormalization gradientNormalization) {
this.gradientNormalization = gradientNormalization;
return this;
}
/**
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
* GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
* Not used otherwise.
* L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
*/
public Builder gradientNormalizationThreshold(double threshold) {
this.gradientNormalizationThreshold = threshold;
return this;
}
/**
* Learning rate decay policy. Used to adapt learning rate based on policy.
*
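* For example, a step decay policy (the values are illustrative only) also requires the decay rate and step size,
* per the validation applied in {@code build()}; see {@link #lrPolicyDecayRate(double)} and {@link #lrPolicySteps(double)}:
*
* {@code
* .learningRateDecayPolicy(LearningRatePolicy.Step)
* .lrPolicyDecayRate(0.5)
* .lrPolicySteps(1000)
* }
*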
* @param policy Type of policy to use. Defaults to None.
*/
public Builder learningRateDecayPolicy(LearningRatePolicy policy) {
this.learningRatePolicy = policy;
return this;
}
/**
* Set the decay rate for the learning rate decay policy.
*
* @param lrPolicyDecayRate rate.
*/
public Builder lrPolicyDecayRate(double lrPolicyDecayRate) {
this.lrPolicyDecayRate = lrPolicyDecayRate;
return this;
}
/**
* Set the number of steps used for step-based learning rate decay policies (e.g., Step, Sigmoid).
*
* @param lrPolicySteps number of steps
*/
public Builder lrPolicySteps(double lrPolicySteps) {
this.lrPolicySteps = lrPolicySteps;
return this;
}
/**
* Set the power used for learning rate inverse policy.
*
* @param lrPolicyPower power
*/
public Builder lrPolicyPower(double lrPolicyPower) {
this.lrPolicyPower = lrPolicyPower;
return this;
}
public Builder convolutionMode(ConvolutionMode convolutionMode) {
this.convolutionMode = convolutionMode;
return this;
}
private void learningRateValidation(String layerName) {
if (learningRatePolicy != LearningRatePolicy.None && Double.isNaN(lrPolicyDecayRate)) {
//LR policy, if used, should have a decay rate. 2 exceptions: Map for schedule, and Poly + power param
if (!(learningRatePolicy == LearningRatePolicy.Schedule && learningRateSchedule != null)
&& !(learningRatePolicy == LearningRatePolicy.Poly && !Double.isNaN(lrPolicyPower)))
throw new IllegalStateException("Layer \"" + layerName
+ "\" learning rate policy decay rate (lrPolicyDecayRate) must be set to use learningRatePolicy.");
}
switch (learningRatePolicy) {
case Inverse:
case Poly:
if (Double.isNaN(lrPolicyPower))
throw new IllegalStateException("Layer \"" + layerName
+ "\" learning rate policy power (lrPolicyPower) must be set to use "
+ learningRatePolicy);
break;
case Step:
case Sigmoid:
if (Double.isNaN(lrPolicySteps))
throw new IllegalStateException("Layer \"" + layerName
+ "\" learning rate policy steps (lrPolicySteps) must be set to use "
+ learningRatePolicy);
break;
case Schedule:
if (learningRateSchedule == null)
throw new IllegalStateException("Layer \"" + layerName
+ "\" learning rate policy schedule (learningRateSchedule) must be set to use "
+ learningRatePolicy);
break;
}
if (!Double.isNaN(lrPolicyPower) && (learningRatePolicy != LearningRatePolicy.Inverse
&& learningRatePolicy != LearningRatePolicy.Poly))
throw new IllegalStateException("Layer \"" + layerName
+ "\" power has been set but will not be applied unless the learning rate policy is set to Inverse or Poly.");
if (!Double.isNaN(lrPolicySteps) && (learningRatePolicy != LearningRatePolicy.Step
&& learningRatePolicy != LearningRatePolicy.Sigmoid
&& learningRatePolicy != LearningRatePolicy.TorchStep))
throw new IllegalStateException("Layer \"" + layerName
+ "\" steps have been set but will not be applied unless the learning rate policy is set to Step or Sigmoid.");
if ((learningRateSchedule != null) && (learningRatePolicy != LearningRatePolicy.Schedule))
throw new IllegalStateException("Layer \"" + layerName
+ "\" learning rate schedule has been set but will not be applied unless the learning rate policy is set to Schedule.");
}
////////////////
/**
* Return a configuration based on this builder
*
* @return the built configuration
*/
public NeuralNetConfiguration build() {
NeuralNetConfiguration conf = new NeuralNetConfiguration();
conf.minimize = minimize;
conf.maxNumLineSearchIterations = maxNumLineSearchIterations;
conf.layer = layer;
conf.numIterations = numIterations;
conf.useRegularization = useRegularization;
conf.optimizationAlgo = optimizationAlgo;
conf.seed = seed;
conf.stepFunction = stepFunction;
conf.useDropConnect = useDropConnect;
conf.miniBatch = miniBatch;
conf.learningRatePolicy = learningRatePolicy;
conf.lrPolicyDecayRate = lrPolicyDecayRate;
conf.lrPolicySteps = lrPolicySteps;
conf.lrPolicyPower = lrPolicyPower;
conf.pretrain = pretrain;
conf.cacheMode = this.cacheMode;
configureLayer(layer);
if (layer instanceof FrozenLayer) {
configureLayer(((FrozenLayer) layer).getLayer());
}
return conf;
}
private void configureLayer(Layer layer) {
String layerName;
if (layer == null || layer.getLayerName() == null)
layerName = "Layer not named";
else
layerName = layer.getLayerName();
learningRateValidation(layerName);
if (layer != null) {
copyConfigToLayer(layerName, layer);
}
if (layer instanceof FrozenLayer) {
copyConfigToLayer(layerName, ((FrozenLayer) layer).getLayer());
}
if (layer instanceof ConvolutionLayer) {
ConvolutionLayer cl = (ConvolutionLayer) layer;
if (cl.getConvolutionMode() == null) {
cl.setConvolutionMode(convolutionMode);
}
}
if (layer instanceof SubsamplingLayer) {
SubsamplingLayer sl = (SubsamplingLayer) layer;
if (sl.getConvolutionMode() == null) {
sl.setConvolutionMode(convolutionMode);
}
}
LayerValidation.generalValidation(layerName, layer, useRegularization, useDropConnect, dropOut, l2, l2Bias,
l1, l1Bias, dist);
}
private void copyConfigToLayer(String layerName, Layer layer) {
if (Double.isNaN(layer.getDropOut()))
layer.setDropOut(dropOut);
if (layer instanceof BaseLayer) {
BaseLayer bLayer = (BaseLayer) layer;
if (Double.isNaN(bLayer.getLearningRate()))
bLayer.setLearningRate(learningRate);
if (Double.isNaN(bLayer.getBiasLearningRate())) {
//Two possibilities when bias LR isn't set for layer:
// (a) If global bias LR *is* set -> set it to that
// (b) Otherwise, set to layer LR (and, by extension, the global LR)
if (!Double.isNaN(biasLearningRate)) {
//Global bias LR is set
bLayer.setBiasLearningRate(biasLearningRate);
} else {
bLayer.setBiasLearningRate(bLayer.getLearningRate());
}
}
if (bLayer.getLearningRateSchedule() == null)
bLayer.setLearningRateSchedule(learningRateSchedule);
if (Double.isNaN(bLayer.getL1()))
bLayer.setL1(l1);
if (Double.isNaN(bLayer.getL2()))
bLayer.setL2(l2);
if (bLayer.getActivationFn() == null)
bLayer.setActivationFn(activationFn);
if (bLayer.getWeightInit() == null)
bLayer.setWeightInit(weightInit);
if (Double.isNaN(bLayer.getBiasInit()))
bLayer.setBiasInit(biasInit);
if (bLayer.getUpdater() == null)
bLayer.setUpdater(updater);
if (bLayer.getIUpdater() == null) {
bLayer.setIUpdater(iUpdater.clone());
}
LayerValidation.updaterValidation(layerName, layer, learningRate, momentum, momentumSchedule,
adamMeanDecay, adamVarDecay, rho, rmsDecay, epsilon);
if (bLayer.getGradientNormalization() == null)
bLayer.setGradientNormalization(gradientNormalization);
if (Double.isNaN(bLayer.getGradientNormalizationThreshold()))
bLayer.setGradientNormalizationThreshold(gradientNormalizationThreshold);
}
}
}
}