// org.deeplearning4j.arbiter.ComputationGraphSpace Maven / Gradle / Ivy
/*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*/
package org.deeplearning4j.arbiter;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.CollectionUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/** ComputationGraphSpace: Defines the space of valid hyperparameters for a ComputationGraph.
* Note that this for essentially fixed graph structures only
*/
public class ComputationGraphSpace extends BaseNetworkSpace {
private List layerSpaces = new ArrayList<>();
private List vertices = new ArrayList<>();
private String[] networkInputs;
private String[] networkOutputs;
private InputType[] inputTypes;
private int numParameters;
//Early stopping configuration / (fixed) number of epochs:
private EarlyStoppingConfiguration earlyStoppingConfiguration;
private ComputationGraphSpace(Builder builder){
super(builder);
this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
this.layerSpaces = builder.layerList;
this.vertices = builder.vertexList;
this.networkInputs = builder.networkInputs;
this.networkOutputs = builder.networkOutputs;
this.inputTypes = builder.inputTypes;
//Determine total number of parameters:
List list = CollectionUtils.getUnique(collectLeaves());
for(ParameterSpace ps : list) numParameters += ps.numParameters();
}
@Override
public GraphConfiguration getValue(double[] values) {
//Create ComputationGraphConfiguration...
NeuralNetConfiguration.Builder builder = randomGlobalConf(values);
ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder();
graphBuilder.addInputs(this.networkInputs);
graphBuilder.setOutputs(this.networkOutputs);
if(inputTypes != null && inputTypes.length > 0) graphBuilder.setInputTypes(inputTypes);
//Build/add our layers and vertices:
for(LayerConf c : layerSpaces){
org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
graphBuilder.addLayer(c.getLayerName(),l,c.getInputs());
}
for(VertexConf gv : vertices){
graphBuilder.addVertex(gv.getVertexName(),gv.getGraphVertex(),gv.getInputs());
}
if(backprop != null) graphBuilder.backprop(backprop.getValue(values));
if(pretrain != null) graphBuilder.pretrain(pretrain.getValue(values));
if(backpropType != null) graphBuilder.backpropType(backpropType.getValue(values));
if(tbpttFwdLength != null) graphBuilder.tBPTTForwardLength(tbpttFwdLength.getValue(values));
if(tbpttBwdLength != null) graphBuilder.tBPTTBackwardLength(tbpttBwdLength.getValue(values));
ComputationGraphConfiguration configuration = graphBuilder.build();
return new GraphConfiguration(configuration,earlyStoppingConfiguration,numEpochs);
}
@Override
public int numParameters() {
return numParameters;
}
@Override
public List collectLeaves() {
List list = super.collectLeaves();
for(LayerConf lc : layerSpaces){
list.addAll(lc.layerSpace.collectLeaves());
}
if(cnnInputSize != null) list.addAll(cnnInputSize.collectLeaves());
return list;
}
@Override
public String toString(){
StringBuilder sb = new StringBuilder(super.toString());
for(LayerConf conf : layerSpaces){
sb.append("Layer config: \"").append(conf.layerName).append("\", ").append(conf.layerSpace)
.append(", inputs: ").append(conf.inputs == null ? "[]" : Arrays.toString(conf.inputs))
.append("\n");
}
for(VertexConf conf : vertices ){
sb.append("GraphVertex: \"").append(conf.vertexName).append("\", ").append(conf.graphVertex)
.append(", inputs: ").append(conf.inputs == null ? "[]" : Arrays.toString(conf.inputs))
.append("\n");
}
if(earlyStoppingConfiguration != null){
sb.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
} else {
sb.append("Training # epochs:").append(numEpochs).append("\n");
}
return sb.toString();
}
@AllArgsConstructor @Data
private static class LayerConf {
private final LayerSpace> layerSpace;
private final String layerName;
private final String[] inputs;
}
@AllArgsConstructor @Data
private static class VertexConf {
private final GraphVertex graphVertex;
private final String vertexName;
private final String[] inputs;
}
public static class Builder extends BaseNetworkSpace.Builder {
protected List layerList = new ArrayList<>();
protected List vertexList = new ArrayList<>();
protected EarlyStoppingConfiguration earlyStoppingConfiguration;
protected String[] networkInputs;
protected String[] networkOutputs;
protected InputType[] inputTypes;
//Need: input types
//Early stopping configuration
//Graph nodes
/** Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is
* present, early stopping will be used in preference.
*/
public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration earlyStoppingConfiguration){
this.earlyStoppingConfiguration = earlyStoppingConfiguration;
return this;
}
public Builder addLayer(String layerName, LayerSpace extends org.deeplearning4j.nn.conf.layers.Layer> layerSpace,
String... layerInputs){
layerList.add(new LayerConf(layerSpace,layerName,layerInputs));
return this;
}
public Builder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs ){
vertexList.add(new VertexConf(vertex,vertexName,vertexInputs));
return this;
}
public Builder addInputs(String... networkInputs){
this.networkInputs = networkInputs;
return this;
}
public Builder setOutputs(String... networkOutputs){
this.networkOutputs = networkOutputs;
return this;
}
public Builder setInputTypes(InputType... inputTypes){
this.inputTypes = inputTypes;
return this;
}
@SuppressWarnings("unchecked")
public ComputationGraphSpace build(){
return new ComputationGraphSpace(this);
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy