
org.deeplearning4j.arbiter.MultiLayerSpace Maven / Gradle / Ivy
package org.deeplearning4j.arbiter;
import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import java.util.ArrayList;
import java.util.List;
//public class MultiLayerSpace implements ModelParameterSpace {
public class MultiLayerSpace extends BaseNetworkSpace {
/** Optional search space for the CNN input size, as {@code int[]{height, width, depth}}; {@code null} if unset. */
private ParameterSpace<int[]> cnnInputSize;
/** Layer groups, in the order they were added via the builder. */
private List<LayerConf> layerSpaces = new ArrayList<>();
/**
 * Optional early-stopping configuration. When {@code null}, {@code toString()} reports
 * {@code numEpochs} (not declared in this class) instead.
 */
private EarlyStoppingConfiguration earlyStoppingConfiguration;
/** Sum of {@code numParameters()} over the unique leaf parameter spaces; computed in the constructor. */
private int numParameters;
/**
 * Creates the space from a populated {@link Builder}.
 *
 * <p>Copies the builder's CNN input size, early-stopping configuration and layer groups, then
 * computes {@code numParameters} by summing {@code numParameters()} over the leaf parameter
 * spaces returned by {@link #collectLeaves()}, after passing them through
 * {@code CollectionUtils.getUnique}.
 *
 * @param builder the builder holding the configured spaces
 */
private MultiLayerSpace(Builder builder){
    super(builder);
    this.cnnInputSize = builder.cnnInputSize;
    this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
    this.layerSpaces = builder.layerSpaces;

    // collectLeaves() reads the fields assigned above, so this must run after them.
    // The list must be typed: iterating a raw List yields Object, not ParameterSpace.
    // NOTE(review): getUnique is assumed to drop duplicate leaves, so a space shared by several
    // layer groups is counted once -- confirm against CollectionUtils.
    List<ParameterSpace> leaves = CollectionUtils.getUnique(collectLeaves());
    for (ParameterSpace leaf : leaves) {
        numParameters += leaf.numParameters();
    }
    //TODO inputs
}
/**
 * Maps a vector of parameter values to a {@code DL4JConfiguration}.
 *
 * <p>NOTE(review): this method is damaged in this copy of the file. The line beginning
 * {@code for( int i=0; i} runs straight into the {@code collectLeaves()} signature, so the
 * loop condition and body, the remainder of getValue, and the start of the collectLeaves()
 * declaration are missing. Restore this region from the upstream source before compiling.
 */
@Override
public DL4JConfiguration getValue(double[] values) {
//First: create layer configs
List layers = new ArrayList<>();
// Each layer group expands into n layers, with n sampled from its numLayers space.
for(LayerConf c : layerSpaces){
int n = c.numLayers.getValue(values);
if(c.duplicateConfig){
//Generate N identical configs
// A single layer config is sampled once here, before the per-layer loop.
org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
// NOTE(review): garbled line below -- see the method comment above.
for( int i=0; i collectLeaves() {
// Collects the leaf parameter spaces: those from the superclass, then each layer group's
// layer-count space and layer space, then the CNN input-size space if one was set.
// The result may contain duplicates; the constructor passes it through CollectionUtils.getUnique.
List list = super.collectLeaves();
for(LayerConf lc : layerSpaces){
list.addAll(lc.numLayers.collectLeaves());
list.addAll(lc.layerSpace.collectLeaves());
}
if(cnnInputSize != null) list.addAll(cnnInputSize.collectLeaves());
return list;
}
/**
 * Renders the base-space description followed by one line per layer group, then either the
 * early-stopping configuration or, when that is absent, the number of training epochs.
 *
 * @return a multi-line, human-readable description of this space
 */
@Override
public String toString(){
    StringBuilder out = new StringBuilder(super.toString());
    int layerIdx = 0;
    for (LayerConf lc : layerSpaces) {
        out.append("Layer config ").append(layerIdx).append(": (Number layers:").append(lc.numLayers);
        out.append(", duplicate: ").append(lc.duplicateConfig).append("), ");
        out.append(lc.layerSpace.toString()).append("\n");
        layerIdx++;
    }
    if (earlyStoppingConfiguration == null) {
        out.append("Training # epochs:").append(numEpochs).append("\n");
    } else {
        out.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
    }
    return out.toString();
}
@AllArgsConstructor
private static class LayerConf {
private final LayerSpace> layerSpace;
private final ParameterSpace numLayers;
private final boolean duplicateConfig;
}
public static class Builder extends BaseNetworkSpace.Builder {
private ParameterSpace cnnInputSize;
private List layerSpaces = new ArrayList<>();
//Early stopping configuration
private EarlyStoppingConfiguration earlyStoppingConfiguration;
public Builder cnnInputSize(int height, int width, int depth){
return cnnInputSize(new FixedValue<>(new int[]{height, width, depth}));
}
public Builder cnnInputSize(ParameterSpace cnnInputSize){
this.cnnInputSize = cnnInputSize;
return this;
}
public Builder addLayer(LayerSpace> layerSpace){
return addLayer(layerSpace,new FixedValue<>(1),true);
}
/**
* @param layerSpace
* @param numLayersDistribution Distribution for number of layers to generate
* @param duplicateConfig Only used if more than 1 layer can be generated. If true: generate N identical (stacked) layers.
* If false: generate N independent layers
*/
public Builder addLayer(LayerSpace extends org.deeplearning4j.nn.conf.layers.Layer> layerSpace,
ParameterSpace numLayersDistribution, boolean duplicateConfig){
layerSpaces.add(new LayerConf(layerSpace,numLayersDistribution,duplicateConfig));
return this;
}
/** Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is
* present, early stopping will be used in preference.
*/
public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration earlyStoppingConfiguration){
this.earlyStoppingConfiguration = earlyStoppingConfiguration;
return this;
}
@SuppressWarnings("unchecked")
public MultiLayerSpace build(){
return new MultiLayerSpace(this);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy