// Scraped page header (not part of the Java source): org.deeplearning4j.arbiter.BaseNetworkSpace Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter;
import org.deeplearning4j.arbiter.conf.dropout.DropoutSpace;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* Abstract ParameterSpace shared by the MultiLayerNetwork search space (MultiLayerSpace) and the
* ComputationGraph search space (ComputationGraphSpace).
*
* It holds the network-wide (global) hyperparameter spaces plus the list of layer spaces.
* {@link #randomGlobalConf(double[])} turns one candidate vector into a populated
* {@link org.deeplearning4j.nn.conf.NeuralNetConfiguration.Builder}.
*
* Functionality here should match {@link org.deeplearning4j.nn.conf.NeuralNetConfiguration.Builder}
*
* NOTE(review): the generic type arguments in this source appear to have been stripped when it was
* extracted (e.g. "ParameterSpace> allParamConstraints", "class Builder> {", "public abstract E build();").
* As shown, the file does not compile. Restore them from the upstream source - TODO confirm the exact types.
*
* @param <T> Type of network (MultiLayerNetwork or ComputationGraph)
* @author Alex Black
*/
@EqualsAndHashCode(callSuper = false)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Data
public abstract class BaseNetworkSpace extends AbstractParameterSpace {
// ---- Global (network-wide) hyperparameter spaces ----
// A null field means "not configured": randomGlobalConf(double[]) skips the matching
// NeuralNetConfiguration.Builder setter, so that builder's own default is used.
// Fixed RNG seed. Unlike the other settings it is not a searchable space; null = no seed is set.
protected Long seed;
protected ParameterSpace optimizationAlgo;
// Set via Builder.activationFn, or via Builder.activation, which wraps an Activation space
// in an ActivationParameterSpaceAdapter (presumably yielding ParameterSpace<IActivation> - confirm).
protected ParameterSpace activationFunction;
protected ParameterSpace biasInit;
protected ParameterSpace weightInit;
protected ParameterSpace dist;
protected ParameterSpace maxNumLineSearchIterations;
protected ParameterSpace miniBatch;
protected ParameterSpace minimize;
protected ParameterSpace stepFunction;
protected ParameterSpace l1;
protected ParameterSpace l2;
protected ParameterSpace l1Bias;
protected ParameterSpace l2Bias;
protected ParameterSpace updater;
protected ParameterSpace biasUpdater;
protected ParameterSpace weightNoise;
// NOTE(review): this field is private, unlike its protected siblings. Subclasses must therefore use the
// Lombok @Data accessors (getDropout/setDropout). Confirm that this is intentional.
private ParameterSpace dropout;
protected ParameterSpace gradientNormalization;
protected ParameterSpace gradientNormalizationThreshold;
protected ParameterSpace convolutionMode;
// Ordered layer spaces. The Builder constructor below does not assign this list; presumably the
// concrete subclasses fill it. It is a List rather than a ParameterSpace, so collectLeaves()
// walks it explicitly.
protected List layerSpaces = new ArrayList<>();
//NeuralNetConfiguration.ListBuilder/MultiLayerConfiguration.Builder options:
// randomGlobalConf does NOT apply these five; presumably the subclasses apply them.
protected ParameterSpace backprop;
protected ParameterSpace pretrain;
protected ParameterSpace backpropType;
protected ParameterSpace tbpttFwdLength;
protected ParameterSpace tbpttBwdLength;
// Constraint-list spaces. randomGlobalConf applies them via constrainAllParameters /
// constrainWeights / constrainBias.
protected ParameterSpace> allParamConstraints;
protected ParameterSpace> weightConstraints;
protected ParameterSpace> biasConstraints;
// Fixed number of training epochs (default 1), copied from Builder.numEpochs(int).
protected int numEpochs = 1;
// Register both concrete subclasses with the JSON and YAML mappers. This lets the "@class" type
// property declared by @JsonTypeInfo on this class be resolved when deserializing.
static {
JsonMapper.getMapper().registerSubtypes(ComputationGraphSpace.class, MultiLayerSpace.class);
YamlMapper.getMapper().registerSubtypes(ComputationGraphSpace.class, MultiLayerSpace.class);
}
/**
* Copies every global setting, plus numEpochs, from the builder.
* The constructor leaves layerSpaces untouched and does not copy the builder's
* earlyStoppingConfiguration.
*
* @param builder source of the settings
*/
@SuppressWarnings("unchecked")
protected BaseNetworkSpace(Builder builder) {
this.seed = builder.seed;
this.optimizationAlgo = builder.optimizationAlgo;
this.activationFunction = builder.activationFunction;
this.biasInit = builder.biasInit;
this.weightInit = builder.weightInit;
this.dist = builder.dist;
this.maxNumLineSearchIterations = builder.maxNumLineSearchIterations;
this.miniBatch = builder.miniBatch;
this.minimize = builder.minimize;
this.stepFunction = builder.stepFunction;
this.l1 = builder.l1;
this.l2 = builder.l2;
this.l1Bias = builder.l1Bias;
this.l2Bias = builder.l2Bias;
this.updater = builder.updater;
this.biasUpdater = builder.biasUpdater;
this.weightNoise = builder.weightNoise;
this.dropout = builder.dropout;
this.gradientNormalization = builder.gradientNormalization;
this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
this.convolutionMode = builder.convolutionMode;
this.allParamConstraints = builder.allParamConstraints;
this.weightConstraints = builder.weightConstraints;
this.biasConstraints = builder.biasConstraints;
this.backprop = builder.backprop;
this.pretrain = builder.pretrain;
this.backpropType = builder.backpropType;
this.tbpttFwdLength = builder.tbpttFwdLength;
this.tbpttBwdLength = builder.tbpttBwdLength;
this.numEpochs = builder.numEpochs;
}
protected BaseNetworkSpace() {
//Default constructor for Jackson json/yaml serialization
}
/**
* Samples the global (network-wide) hyperparameters into a new NeuralNetConfiguration.Builder.
* Only non-null spaces are applied, and every space is evaluated against the same candidate vector.
* The ListBuilder options (backprop, pretrain, backpropType, tbptt*) and the layers are not
* applied here.
*
* @param values candidate vector passed to each ParameterSpace#getValue(double[])
*               (presumably values in [0,1] addressed by each leaf's indices - confirm)
* @return a new builder with the sampled global settings applied
*/
protected NeuralNetConfiguration.Builder randomGlobalConf(double[] values) {
//Create the global (non-layer) configuration builder; no layers are added here.
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
if (seed != null)
builder.seed(seed);
if (optimizationAlgo != null)
builder.optimizationAlgo(optimizationAlgo.getValue(values));
if (activationFunction != null)
builder.activation(activationFunction.getValue(values));
if (biasInit != null)
builder.biasInit(biasInit.getValue(values));
if (weightInit != null)
builder.weightInit(weightInit.getValue(values));
if (dist != null)
builder.dist(dist.getValue(values));
if (maxNumLineSearchIterations != null)
builder.maxNumLineSearchIterations(maxNumLineSearchIterations.getValue(values));
if (miniBatch != null)
builder.miniBatch(miniBatch.getValue(values));
if (minimize != null)
builder.minimize(minimize.getValue(values));
if (stepFunction != null)
builder.stepFunction(stepFunction.getValue(values));
if (l1 != null)
builder.l1(l1.getValue(values));
if (l2 != null)
builder.l2(l2.getValue(values));
if (l1Bias != null)
builder.l1Bias(l1Bias.getValue(values));
if (l2Bias != null)
builder.l2Bias(l2Bias.getValue(values));
if (updater != null)
builder.updater(updater.getValue(values));
if (biasUpdater != null)
builder.biasUpdater(biasUpdater.getValue(values));
if (weightNoise != null)
builder.weightNoise(weightNoise.getValue(values));
if (dropout != null)
builder.dropOut(dropout.getValue(values));
if (gradientNormalization != null)
builder.gradientNormalization(gradientNormalization.getValue(values));
if (gradientNormalizationThreshold != null)
builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values));
if (convolutionMode != null)
builder.convolutionMode(convolutionMode.getValue(values));
// Constraint spaces: a null sampled list means "apply no constraint of this kind".
if (allParamConstraints != null){
List c = allParamConstraints.getValue(values);
if(c != null){
builder.constrainAllParameters(c.toArray(new LayerConstraint[c.size()]));
}
}
if (weightConstraints != null){
List c = weightConstraints.getValue(values);
if(c != null){
builder.constrainWeights(c.toArray(new LayerConstraint[c.size()]));
}
}
if (biasConstraints != null){
List c = biasConstraints.getValue(values);
if(c != null){
builder.constrainBias(c.toArray(new LayerConstraint[c.size()]));
}
}
return builder;
}
/**
* Collects the spaces that make up this network space. The result contains:
* the values of getNestedSpaces() (presumably the ParameterSpace-typed fields of this object), and
* the leaves of every layer space in layerSpaces.
*
* NOTE(review): the global spaces are added as-is, not via their own collectLeaves(). A wrapping
* space (e.g. ActivationParameterSpaceAdapter or DropoutSpace, created by the Builder) is therefore
* returned itself. A ParameterSpace instance that is shared between the global settings and a
* layer, or between layers, may also appear more than once. Confirm that callers handle both cases.
*
* @return a new list of the collected spaces
*/
@Override
public List collectLeaves() {
Map global = getNestedSpaces();
List list = new ArrayList<>();
list.addAll(global.values());
//Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually...
//This is because the type is a list, not a ParameterSpace
for (LayerConf layerConf : layerSpaces) {
LayerSpace ls = layerConf.getLayerSpace();
list.addAll(ls.collectLeaves());
}
return list;
}
// A network space is always a composite of other spaces, never a leaf.
@Override
public boolean isLeaf() {
return false;
}
// Only leaf spaces accept indices; this composite always throws.
@Override
public void setIndices(int... indices) {
throw new UnsupportedOperationException("Cannot set indices for non leaf");
}
/**
* Human-readable dump. It prints one "name: space" line per nested global space, followed by one
* line per LayerConf (its index, number of layers, duplicate flag and layer space).
*/
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (Map.Entry e : getNestedSpaces().entrySet()) {
sb.append(e.getKey()).append(": ").append(e.getValue()).append("\n");
}
int i = 0;
for (LayerConf conf : layerSpaces) {
sb.append("Layer config ").append(i++).append(": (Number layers:").append(conf.numLayers)
.append(", duplicate: ").append(conf.duplicateConfig).append("), ")
.append(conf.layerSpace.toString()).append("\n");
}
return sb.toString();
}
/**
* One entry of layerSpaces: a layer search space plus how it is placed in the network.
* Lombok generates the accessors, equals/hashCode/toString and both constructors. The no-args
* constructor is presumably there for Jackson deserialization.
*/
@AllArgsConstructor
@Data
@NoArgsConstructor
public static class LayerConf {
// The layer's hyperparameter space. NOTE(review): its type argument appears stripped.
protected LayerSpace> layerSpace;
// Presumably the layer/vertex name (used by ComputationGraphSpace) - confirm.
protected String layerName;
// Presumably the names of this layer's input vertices (ComputationGraphSpace) - confirm.
protected String[] inputs;
// Number of layers produced by this entry, itself a space (printed as "Number layers" by toString).
protected ParameterSpace numLayers;
// Presumably: true = all numLayers layers share one sampled config, false = each layer is sampled
// independently - confirm in the subclasses.
protected boolean duplicateConfig;
// Optional input preprocessor for this layer (presumably null when none).
protected InputPreProcessor preProcessor;
}
/**
* Fluent builder for the global settings. Each setter stores its argument and returns {@code this},
* cast to the concrete builder type T. Overloads that take a plain value wrap it in a FixedValue and
* delegate to the ParameterSpace overload.
*
* NOTE(review): the declaration of T (presumably {@code T extends Builder<T>}) appears stripped.
*/
@SuppressWarnings("unchecked")
protected abstract static class Builder> {
private Long seed;
private ParameterSpace optimizationAlgo;
private ParameterSpace activationFunction;
private ParameterSpace biasInit;
private ParameterSpace weightInit;
private ParameterSpace dist;
private ParameterSpace maxNumLineSearchIterations;
private ParameterSpace miniBatch;
private ParameterSpace minimize;
private ParameterSpace stepFunction;
private ParameterSpace l1;
private ParameterSpace l2;
private ParameterSpace l1Bias;
private ParameterSpace l2Bias;
private ParameterSpace updater;
private ParameterSpace biasUpdater;
private ParameterSpace weightNoise;
private ParameterSpace dropout;
private ParameterSpace gradientNormalization;
private ParameterSpace gradientNormalizationThreshold;
private ParameterSpace convolutionMode;
private ParameterSpace> allParamConstraints;
private ParameterSpace> weightConstraints;
private ParameterSpace> biasConstraints;
//NeuralNetConfiguration.ListBuilder/MultiLayerConfiguration.Builder options:
private ParameterSpace backprop;
private ParameterSpace pretrain;
private ParameterSpace backpropType;
private ParameterSpace tbpttFwdLength;
private ParameterSpace tbpttBwdLength;
//Early stopping configuration / (fixed) number of epochs:
// NOTE(review): earlyStoppingConfiguration is private, has no setter in this class, is never read
// here and is not copied by the BaseNetworkSpace constructor. It is either dead code or is missing
// an earlyStoppingConfiguration(...) setter.
private EarlyStoppingConfiguration earlyStoppingConfiguration;
private int numEpochs = 1;
/** Fixed RNG seed. Unlike the other settings, this is not a searchable space. */
public T seed(long seed) {
this.seed = seed;
return (T) this;
}
public T optimizationAlgo(OptimizationAlgorithm optimizationAlgorithm) {
return optimizationAlgo(new FixedValue<>(optimizationAlgorithm));
}
public T optimizationAlgo(ParameterSpace parameterSpace) {
this.optimizationAlgo = parameterSpace;
return (T) this;
}
/** Fixed Activation enum value. */
public T activation(Activation activationFunction) {
return activation(new FixedValue<>(activationFunction));
}
/** Space of Activation enum values, adapted to the activation-function space via ActivationParameterSpaceAdapter. */
public T activation(ParameterSpace activationFunction) {
return activationFn(new ActivationParameterSpaceAdapter(activationFunction));
}
/** Sets the activation-function space directly (presumably a space of IActivation instances). */
public T activationFn(ParameterSpace activationFunction) {
this.activationFunction = activationFunction;
return (T) this;
}
public T biasInit(double biasInit){
return biasInit(new FixedValue<>(biasInit));
}
public T biasInit(ParameterSpace biasInit){
this.biasInit = biasInit;
return (T) this;
}
public T weightInit(WeightInit weightInit) {
return weightInit(new FixedValue<>(weightInit));
}
public T weightInit(ParameterSpace weightInit) {
this.weightInit = weightInit;
return (T) this;
}
public T dist(Distribution dist) {
return dist(new FixedValue<>(dist));
}
public T dist(ParameterSpace dist) {
this.dist = dist;
return (T) this;
}
public T maxNumLineSearchIterations(int maxNumLineSearchIterations) {
return maxNumLineSearchIterations(new FixedValue<>(maxNumLineSearchIterations));
}
public T maxNumLineSearchIterations(ParameterSpace maxNumLineSearchIterations) {
this.maxNumLineSearchIterations = maxNumLineSearchIterations;
return (T) this;
}
public T miniBatch(boolean minibatch) {
return miniBatch(new FixedValue<>(minibatch));
}
public T miniBatch(ParameterSpace miniBatch) {
this.miniBatch = miniBatch;
return (T) this;
}
public T minimize(boolean minimize) {
return minimize(new FixedValue<>(minimize));
}
public T minimize(ParameterSpace minimize) {
this.minimize = minimize;
return (T) this;
}
public T stepFunction(StepFunction stepFunction) {
return stepFunction(new FixedValue<>(stepFunction));
}
public T stepFunction(ParameterSpace stepFunction) {
this.stepFunction = stepFunction;
return (T) this;
}
public T l1(double l1) {
return l1(new FixedValue<>(l1));
}
public T l1(ParameterSpace l1) {
this.l1 = l1;
return (T) this;
}
public T l2(double l2) {
return l2(new FixedValue<>(l2));
}
public T l2(ParameterSpace l2) {
this.l2 = l2;
return (T) this;
}
public T l1Bias(double l1Bias) {
return l1Bias(new FixedValue<>(l1Bias));
}
public T l1Bias(ParameterSpace l1Bias) {
this.l1Bias = l1Bias;
return (T) this;
}
public T l2Bias(double l2Bias) {
return l2Bias(new FixedValue<>(l2Bias));
}
public T l2Bias(ParameterSpace l2Bias) {
this.l2Bias = l2Bias;
return (T) this;
}
public T updater(IUpdater updater){
return updater(new FixedValue<>(updater));
}
public T updater(ParameterSpace updater) {
this.updater = updater;
return (T) this;
}
public T biasUpdater(IUpdater biasUpdater){
return biasUpdater(new FixedValue<>(biasUpdater));
}
public T biasUpdater(ParameterSpace biasUpdater){
this.biasUpdater = biasUpdater;
return (T)this;
}
public T weightNoise(IWeightNoise weightNoise){
return weightNoise(new FixedValue<>(weightNoise));
}
public T weightNoise(ParameterSpace weightNoise){
this.weightNoise = weightNoise;
return (T) this;
}
/** Fixed dropout value, wrapped in a {@link Dropout} instance. */
public T dropOut(double dropout){
return idropOut(new Dropout(dropout));
}
/** Searchable dropout value, wrapped in a DropoutSpace. */
public T dropOut(ParameterSpace dropOut){
return idropOut(new DropoutSpace(dropOut));
}
/** Fixed {@link IDropout} instance. */
public T idropOut(IDropout idropOut){
return idropOut(new FixedValue<>(idropOut));
}
/** Sets the dropout space directly; all other dropout overloads end up here. */
public T idropOut(ParameterSpace idropOut){
this.dropout = idropOut;
return (T) this;
}
public T gradientNormalization(GradientNormalization gradientNormalization) {
return gradientNormalization(new FixedValue<>(gradientNormalization));
}
public T gradientNormalization(ParameterSpace gradientNormalization) {
this.gradientNormalization = gradientNormalization;
return (T) this;
}
public T gradientNormalizationThreshold(double threshold) {
return gradientNormalizationThreshold(new FixedValue<>(threshold));
}
public T gradientNormalizationThreshold(ParameterSpace gradientNormalizationThreshold) {
this.gradientNormalizationThreshold = gradientNormalizationThreshold;
return (T) this;
}
// NOTE(review): this overload uses a raw FixedValue, unlike the diamond (FixedValue<>) used by every other overload.
public T convolutionMode(ConvolutionMode convolutionMode) {
return convolutionMode(new FixedValue(convolutionMode));
}
public T convolutionMode(ParameterSpace convolutionMode) {
this.convolutionMode = convolutionMode;
return (T) this;
}
// The ListBuilder options below are stored and copied into the space, but randomGlobalConf does not apply them.
public T backprop(boolean backprop) {
return backprop(new FixedValue<>(backprop));
}
public T backprop(ParameterSpace backprop) {
this.backprop = backprop;
return (T) this;
}
public T pretrain(boolean pretrain) {
return pretrain(new FixedValue<>(pretrain));
}
public T pretrain(ParameterSpace pretrain) {
this.pretrain = pretrain;
return (T) this;
}
public T backpropType(BackpropType backpropType) {
return backpropType(new FixedValue<>(backpropType));
}
public T backpropType(ParameterSpace backpropType) {
this.backpropType = backpropType;
return (T) this;
}
public T tbpttFwdLength(int tbpttFwdLength) {
return tbpttFwdLength(new FixedValue<>(tbpttFwdLength));
}
public T tbpttFwdLength(ParameterSpace tbpttFwdLength) {
this.tbpttFwdLength = tbpttFwdLength;
return (T) this;
}
public T tbpttBwdLength(int tbpttBwdLength) {
return tbpttBwdLength(new FixedValue<>(tbpttBwdLength));
}
public T tbpttBwdLength(ParameterSpace tbpttBwdLength) {
this.tbpttBwdLength = tbpttBwdLength;
return (T) this;
}
// Constraint setters. The varargs overloads fix the list (via Arrays.asList, a fixed-size view of the array);
// the ParameterSpace overloads make the list searchable.
public T constrainWeights(LayerConstraint... constraints){
return constrainWeights(new FixedValue>(Arrays.asList(constraints)));
}
public T constrainWeights(ParameterSpace> constraints){
this.weightConstraints = constraints;
return (T) this;
}
public T constrainBias(LayerConstraint... constraints){
return constrainBias(new FixedValue>(Arrays.asList(constraints)));
}
public T constrainBias(ParameterSpace> constraints){
this.biasConstraints = constraints;
return (T) this;
}
public T constrainAllParams(LayerConstraint... constraints){
return constrainAllParams(new FixedValue>(Arrays.asList(constraints)));
}
public T constrainAllParams(ParameterSpace> constraints){
this.allParamConstraints = constraints;
return (T) this;
}
/**
* Fixed number of training epochs. Default: 1
* Note if both EarlyStoppingConfiguration and number of epochs is present, early stopping will be used in preference.
* NOTE(review): this Builder has no setter for its earlyStoppingConfiguration field, so that
* preference cannot be triggered through this class.
*/
public T numEpochs(int numEpochs) {
this.numEpochs = numEpochs;
return (T) this;
}
/**
* Creates the concrete network space from this builder.
* NOTE(review): the method-level type parameter declaration for E appears stripped.
*/
public abstract E build();
}
/**
* Return a json configuration of this configuration space.
*
* @return this space serialized with the shared JsonMapper
* @throws RuntimeException wrapping any JsonProcessingException raised during serialization
*/
public String toJson() {
try {
return JsonMapper.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Return a yaml configuration of this configuration space.
*
* @return this space serialized with the shared YamlMapper
* @throws RuntimeException wrapping any JsonProcessingException raised during serialization
*/
public String toYaml() {
try {
return YamlMapper.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}