All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.deeplearning4j.arbiter.ComputationGraphSpace Maven / Gradle / Ivy
/*-
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*/
package org.deeplearning4j.arbiter;
import lombok.*;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.annotation.JsonTypeName;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* ComputationGraphSpace: Defines the space of valid hyperparameters for a ComputationGraph.
* Note that this for fixed graph structures only
*
* @author Alex Black
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON ser/de
@Data
@EqualsAndHashCode(callSuper = true)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonTypeName("ComputationGraphSpace")
public class ComputationGraphSpace extends BaseNetworkSpace {
@JsonProperty
protected List layerSpaces = new ArrayList<>();
@JsonProperty
protected List vertices = new ArrayList<>();
@JsonProperty
protected String[] networkInputs;
@JsonProperty
protected String[] networkOutputs;
@JsonProperty
protected ParameterSpace inputTypes;
@JsonProperty
protected int numParameters;
//Early stopping configuration / (fixed) number of epochs:
protected EarlyStoppingConfiguration earlyStoppingConfiguration;
protected ComputationGraphSpace(Builder builder) {
super(builder);
this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
this.layerSpaces = builder.layerList;
this.vertices = builder.vertexList;
this.networkInputs = builder.networkInputs;
this.networkOutputs = builder.networkOutputs;
this.inputTypes = builder.inputTypes;
//Determine total number of parameters:
List list = LeafUtils.getUniqueObjects(collectLeaves());
for (ParameterSpace ps : list)
numParameters += ps.numParameters();
}
@Override
public GraphConfiguration getValue(double[] values) {
//Create ComputationGraphConfiguration...
NeuralNetConfiguration.Builder builder = randomGlobalConf(values);
ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder();
graphBuilder.addInputs(this.networkInputs);
graphBuilder.setOutputs(this.networkOutputs);
if (inputTypes != null)
graphBuilder.setInputTypes(inputTypes.getValue(values));
//Build/add our layers and vertices:
for (LayerConf c : layerSpaces) {
org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values);
graphBuilder.addLayer(c.getLayerName(), l, c.getInputs());
}
for (VertexConf gv : vertices) {
graphBuilder.addVertex(gv.getVertexName(), gv.getGraphVertex(), gv.getInputs());
}
if (backprop != null)
graphBuilder.backprop(backprop.getValue(values));
if (pretrain != null)
graphBuilder.pretrain(pretrain.getValue(values));
if (backpropType != null)
graphBuilder.backpropType(backpropType.getValue(values));
if (tbpttFwdLength != null)
graphBuilder.tBPTTForwardLength(tbpttFwdLength.getValue(values));
if (tbpttBwdLength != null)
graphBuilder.tBPTTBackwardLength(tbpttBwdLength.getValue(values));
ComputationGraphConfiguration configuration = graphBuilder.build();
return new GraphConfiguration(configuration, earlyStoppingConfiguration, numEpochs);
}
@Override
public int numParameters() {
return numParameters;
}
@Override
public List collectLeaves() {
List list = super.collectLeaves();
for (LayerConf lc : layerSpaces) {
list.addAll(lc.layerSpace.collectLeaves());
}
if (inputTypes != null)
list.add(inputTypes);
return list;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder(super.toString());
for (LayerConf conf : layerSpaces) {
sb.append("Layer config: \"").append(conf.layerName).append("\", ").append(conf.layerSpace)
.append(", inputs: ").append(conf.inputs == null ? "[]" : Arrays.toString(conf.inputs))
.append("\n");
}
for (VertexConf conf : vertices) {
sb.append("GraphVertex: \"").append(conf.vertexName).append("\", ").append(conf.graphVertex)
.append(", inputs: ").append(conf.inputs == null ? "[]" : Arrays.toString(conf.inputs))
.append("\n");
}
if (earlyStoppingConfiguration != null) {
sb.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n");
} else {
sb.append("Training # epochs:").append(numEpochs).append("\n");
}
if (inputTypes != null) {
sb.append("Input types: ").append(inputTypes).append("\n");
}
return sb.toString();
}
@AllArgsConstructor
@Data
@NoArgsConstructor //For Jackson JSON
protected static class VertexConf {
protected GraphVertex graphVertex;
protected String vertexName;
protected String[] inputs;
}
public static class Builder extends BaseNetworkSpace.Builder {
protected List layerList = new ArrayList<>();
protected List vertexList = new ArrayList<>();
protected EarlyStoppingConfiguration earlyStoppingConfiguration;
protected String[] networkInputs;
protected String[] networkOutputs;
protected ParameterSpace inputTypes;
//Need: input types
//Early stopping configuration
//Graph nodes
/**
* Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is
* present, early stopping will be used in preference.
*/
public Builder earlyStoppingConfiguration(
EarlyStoppingConfiguration earlyStoppingConfiguration) {
this.earlyStoppingConfiguration = earlyStoppingConfiguration;
return this;
}
public Builder addLayer(String layerName, LayerSpace extends Layer> layerSpace, String... layerInputs) {
layerList.add(new LayerConf(layerSpace, layerName, layerInputs, new FixedValue<>(1), false));
return this;
}
public Builder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
vertexList.add(new VertexConf(vertex, vertexName, vertexInputs));
return this;
}
public Builder addInputs(String... networkInputs) {
this.networkInputs = networkInputs;
return this;
}
public Builder setOutputs(String... networkOutputs) {
this.networkOutputs = networkOutputs;
return this;
}
public Builder setInputTypes(InputType... inputTypes) {
return setInputTypes(new FixedValue(inputTypes));
}
public Builder setInputTypes(ParameterSpace inputTypes) {
this.inputTypes = inputTypes;
return this;
}
@SuppressWarnings("unchecked")
public ComputationGraphSpace build() {
return new ComputationGraphSpace(this);
}
}
/**
* Instantiate a computation graph space from
* a raw json string
* @param json
* @return
*/
public static ComputationGraphSpace fromJson(String json) {
try {
return JsonMapper.getMapper().readValue(json, ComputationGraphSpace.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Instantiate a computation graph space
* from a raw yaml string
* @param yaml
* @return
*/
public static ComputationGraphSpace fromYaml(String yaml) {
try {
return YamlMapper.getMapper().readValue(yaml, ComputationGraphSpace.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}