/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.transferlearning;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import java.util.*;
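/**
 * Transfer learning utilities: take an existing, trained {@link MultiLayerNetwork} or
 * {@link ComputationGraph}, selectively freeze layers/vertices, modify the architecture
 * (nIn/nOut changes, adding or removing layers/vertices) and override the learning
 * configuration via {@link FineTuneConfiguration}.
 *
 * A minimal usage sketch ({@code pretrained}, the updater value and the layer indices are
 * illustrative assumptions, not part of this API):
 * <pre>{@code
 * FineTuneConfiguration ftc = new FineTuneConfiguration.Builder()
 *         .updater(new Adam(1e-4))   // assumes org.nd4j.linalg.learning.config.Adam is imported
 *         .seed(123)
 *         .build();
 *
 * MultiLayerNetwork tuned = new TransferLearning.Builder(pretrained)
 *         .fineTuneConfiguration(ftc)
 *         .setFeatureExtractor(1)                  // freeze layers 0..1
 *         .nOutReplace(2, 10, WeightInit.XAVIER)   // new output size for layer 2
 *         .build();                                // no .init() needed; ready to fit
 * }</pre>
 *
 * For ComputationGraph models, use {@link TransferLearning.GraphBuilder} in the same way,
 * addressing layers by vertex name instead of index.
 */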
@Slf4j
public class TransferLearning {
public static class Builder {
private MultiLayerConfiguration origConf;
private MultiLayerNetwork origModel;
private MultiLayerNetwork editedModel;
private FineTuneConfiguration finetuneConfiguration;
private int frozenTill = -1;
private int popN = 0;
private boolean prepDone = false;
private Set<Integer> editedLayers = new HashSet<>();
private Map<Integer, Triple<Integer, IWeightInit, IWeightInit>> editedLayersMap = new HashMap<>();
private Map<Integer, Pair<Integer, IWeightInit>> nInEditedMap = new HashMap<>();
private List<INDArray> editedParams = new ArrayList<>();
private List<NeuralNetConfiguration> editedConfs = new ArrayList<>();
private List<INDArray> appendParams = new ArrayList<>(); //these could be new arrays, or views from origParams
private List<NeuralNetConfiguration> appendConfs = new ArrayList<>();
private Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
private InputType inputType;
private Boolean validateOutputLayerConfig;
private DataType dataType;
/**
* Multilayer Network to tweak for transfer learning
* @param origModel the trained MultiLayerNetwork to modify
*/
public Builder(MultiLayerNetwork origModel) {
this.origModel = origModel;
this.origConf = origModel.getLayerWiseConfigurations().clone();
this.dataType = origModel.getLayerWiseConfigurations().getDataType();
this.inputPreProcessors = origConf.getInputPreProcessors();
}
/**
* The fine tune configuration specified here will override the corresponding values in the existing configuration, if any.
* For example, specifying a learning rate will set that learning rate on all layers.
* Refer to the FineTuneConfiguration class for more details.
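*
* A minimal sketch (assuming {@code builder} is a TransferLearning.Builder; the updater value is illustrative):
* <pre>{@code
* builder.fineTuneConfiguration(new FineTuneConfiguration.Builder()
*         .updater(new Adam(1e-4))
*         .build());
* }</pre>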
* @param finetuneConfiguration the configuration overrides to apply to every layer
* @return Builder
*/
public Builder fineTuneConfiguration(FineTuneConfiguration finetuneConfiguration) {
this.finetuneConfiguration = finetuneConfiguration;
return this;
}
/**
* Specify a layer to act as a "feature extractor".
* The specified layer and all layers preceding it will be "frozen", with parameters staying constant during further training.
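*
* For example (the layer index is illustrative; assumes {@code builder} is a TransferLearning.Builder):
* <pre>{@code
* builder.setFeatureExtractor(1);  // freezes layers 0 and 1
* }</pre>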
* @param layerNum index of the last layer to freeze (inclusive)
* @return Builder
*/
public Builder setFeatureExtractor(int layerNum) {
this.frozenTill = layerNum;
return this;
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
*
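* For example (the layer index, nOut value and init scheme are illustrative; assumes {@code builder} is a TransferLearning.Builder):
* <pre>{@code
* builder.nOutReplace(2, 10, WeightInit.XAVIER);  // layer 2 gets nOut = 10; the next layer's nIn is adjusted to match
* }</pre>
*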
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight Init scheme to use for params in layerNum and layerNum+1
* @return Builder
*/
public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme) {
return nOutReplace(layerNum, nOut, scheme, scheme);
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Distribution to use in conjunction with weight init DISTRIBUTION for params in layernum and layernum+1
* @return Builder
* @see WeightInit#DISTRIBUTION
*/
public Builder nOutReplace(int layerNum, int nOut, Distribution dist) {
return nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(dist));
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight Init scheme to use for params in the layerNum
* @param schemeNext Weight Init scheme to use for params in the layerNum+1
* @return Builder
*/
public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext) {
if(scheme == WeightInit.DISTRIBUTION || schemeNext == WeightInit.DISTRIBUTION) {
throw new UnsupportedOperationException("Not supported! Use " +
"nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
}
return nOutReplace(layerNum, nOut, scheme.getWeightInitFunction(), schemeNext.getWeightInitFunction());
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Distribution to use for params in the layerNum
* @param distNext Distribution to use for params in layerNum+1
* @return Builder
* @see WeightInitDistribution
*/
public Builder nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) {
return nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext));
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight init scheme to use for params in layerNum
* @param distNext Distribution to use for params in layerNum+1
* @return Builder
* @see WeightInitDistribution
*/
public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, Distribution distNext) {
if(scheme == WeightInit.DISTRIBUTION) {
throw new UnsupportedOperationException("Not supported! Use " +
"nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) instead!");
}
return nOutReplace(layerNum, nOut, scheme.getWeightInitFunction(), new WeightInitDistribution(distNext));
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Distribution to use for params in layerNum
* @param schemeNext Weight init scheme to use for params in layerNum+1
* @return Builder
* @see WeightInitDistribution
*/
public Builder nOutReplace(int layerNum, int nOut, Distribution dist, WeightInit schemeNext) {
return nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), schemeNext.getWeightInitFunction());
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight Init scheme to use for params in the layerNum
* @param schemeNext Weight Init scheme to use for params in the layerNum+1
*/
public Builder nOutReplace(int layerNum, int nOut, IWeightInit scheme, IWeightInit schemeNext) {
editedLayers.add(layerNum);
Triple<Integer, IWeightInit, IWeightInit> t = new Triple<>(nOut, scheme, schemeNext);
editedLayersMap.put(layerNum, t);
return this;
}
/**
* Modify the architecture of a layer by changing nIn of the specified layer.
* Note that only the specified layer will be modified - all other layers will not be changed by this call.
*
* @param layerNum The index of the layer to change nIn of
* @param nIn Value of nIn to change to
* @param scheme Weight init scheme to use for params in layerNum
* @return Builder
*/
public Builder nInReplace(int layerNum, int nIn, WeightInit scheme) {
return nInReplace(layerNum, nIn, scheme, null);
}
/**
* Modify the architecture of a layer by changing nIn of the specified layer.
* Note that only the specified layer will be modified - all other layers will not be changed by this call.
*
* @param layerNum The index of the layer to change nIn of
* @param nIn Value of nIn to change to
* @param scheme Weight init scheme to use for params in layerNum
* @param dist Distribution to use in conjunction with the weight init scheme, where applicable (may be null)
* @return Builder
*/
public Builder nInReplace(int layerNum, int nIn, WeightInit scheme, Distribution dist) {
return nInReplace(layerNum, nIn, scheme.getWeightInitFunction(dist));
}
/**
* Modify the architecture of a layer by changing nIn of the specified layer.
* Note that only the specified layer will be modified - all other layers will not be changed by this call.
*
* @param layerNum The index of the layer to change nIn of
* @param nIn Value of nIn to change to
* @param scheme Weight init scheme to use for params in layerNum
* @return Builder
*/
public Builder nInReplace(int layerNum, int nIn, IWeightInit scheme) {
Pair<Integer, IWeightInit> d = new Pair<>(nIn, scheme);
nInEditedMap.put(layerNum, d);
return this;
}
/**
* Helper method to remove the outputLayer of the net.
* Only one of the two - removeOutputLayer() or removeLayersFromOutput(layerNum) - can be specified
* When removing layers, at the very least an output layer should be added back with .addLayer(...)
*
* @return Builder
*/
public Builder removeOutputLayer() {
popN = 1;
return this;
}
/**
* Remove last "n" layers of the net
* At least an output layer must be added back in
* @param layerNum number of layers to remove
* @return Builder
*/
public Builder removeLayersFromOutput(int layerNum) {
if (popN == 0) {
popN = layerNum;
} else {
throw new IllegalArgumentException("removeLayersFromOutput can only be called once");
}
return this;
}
/**
* Add layers to the net
* Required if layers have been removed. Can be called multiple times; layers are added in the order of the calls.
* At the very least an output layer must be added (and, as per the note on ordering, it should be added last).
* Learning configs (like updaters, learning rate etc) specified with the layer here will be honored
*
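* For example (the loss, activation and layer sizes are illustrative; assumes the usual DL4J imports for
* OutputLayer, LossFunctions and Activation, and that {@code builder} is a TransferLearning.Builder):
* <pre>{@code
* builder.removeOutputLayer()
*        .addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
*                .nIn(512).nOut(10)
*                .activation(Activation.SOFTMAX)
*                .build());
* }</pre>
*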
* @param layer layer conf to add (similar to NeuralNetConfiguration.list().layer(...))
* @return Builder
*/
public Builder addLayer(Layer layer) {
if (!prepDone) {
doPrep();
}
// Use the fineTune config to create the required NeuralNetConfiguration + Layer instances
//instantiate dummy layer to get the params
//Build a nn config builder with settings from finetune. Set layer with the added layer
//Issue: fine tune config has .learningRate(x), then I add a layer with .learningRate(y)...
//We don't want that to be overridden
NeuralNetConfiguration layerConf =
finetuneConfiguration.appliedNeuralNetConfigurationBuilder().layer(layer).build();
val numParams = layer.initializer().numParams(layerConf);
INDArray params;
if (numParams > 0) {
params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), numParams);
org.deeplearning4j.nn.api.Layer someLayer = layer.instantiate(layerConf, null, 0, params, true, dataType);
appendParams.add(someLayer.params());
appendConfs.add(someLayer.conf());
} else {
appendConfs.add(layerConf);
}
return this;
}
/**
* Specify the preprocessor for the added layers
* for cases where they cannot be inferred automatically.
*
* @param layer     index of the layer the preprocessor applies to
* @param processor the InputPreProcessor to be used on the data
* @return Builder
*/
public Builder setInputPreProcessor(int layer, InputPreProcessor processor) {
inputPreProcessors.put(layer, processor);
return this;
}
public Builder validateOutputLayerConfig(boolean validate){
this.validateOutputLayerConfig = validate;
return this;
}
/**
* Returns a model with the fine tune configuration and specified architecture changes.
* .init() need not be called; the returned network can be fit directly.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork build() {
if (!prepDone) {
doPrep();
}
editedModel = new MultiLayerNetwork(constructConf(), constructParams());
if (frozenTill != -1) {
org.deeplearning4j.nn.api.Layer[] layers = editedModel.getLayers();
for (int i = frozenTill; i >= 0; i--) {
//Complication here: the inner Layer (implementation) NeuralNetConfiguration.layer (config) should keep
// the original layer config, while the network-level NNC should hold the frozen layer (for to/from JSON etc)
NeuralNetConfiguration origNNC = editedModel.getLayerWiseConfigurations().getConf(i);
NeuralNetConfiguration layerNNC = origNNC.clone();
layers[i].setConf(layerNNC);
layers[i] = new FrozenLayer(layers[i]);
if (origNNC.getVariables() != null) {
List<String> vars = origNNC.variables(true);
origNNC.clearVariables();
layerNNC.clearVariables();
for (String s : vars) {
origNNC.variables(false).add(s);
layerNNC.variables(false).add(s);
}
}
Layer origLayerConf = editedModel.getLayerWiseConfigurations().getConf(i).getLayer();
Layer newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf);
newLayerConf.setLayerName(origLayerConf.getLayerName());
editedModel.getLayerWiseConfigurations().getConf(i).setLayer(newLayerConf);
}
editedModel.setLayers(layers);
}
return editedModel;
}
private void doPrep() {
//first set finetune configs on all layers in model
fineTuneConfigurationBuild();
//editParams gets original model params
for (int i = 0; i < origModel.getnLayers(); i++) {
if (origModel.getLayer(i).numParams() > 0) {
//dup only if params are there
editedParams.add(origModel.getLayer(i).params().dup());
} else {
editedParams.add(origModel.getLayer(i).params());
}
}
//apply changes to nout/nin if any in sorted order and save to editedParams
if (!editedLayers.isEmpty()) {
Integer[] editedLayersSorted = editedLayers.toArray(new Integer[editedLayers.size()]);
Arrays.sort(editedLayersSorted);
for (int i = 0; i < editedLayersSorted.length; i++) {
int layerNum = editedLayersSorted[i];
nOutReplaceBuild(layerNum, editedLayersMap.get(layerNum).getLeft(),
editedLayersMap.get(layerNum).getMiddle(),
editedLayersMap.get(layerNum).getRight());
}
}
if(!nInEditedMap.isEmpty()) {
Integer[] editedLayersSorted = nInEditedMap.keySet().toArray(new Integer[nInEditedMap.size()]);
Arrays.sort(editedLayersSorted);
for (Integer layerNum : editedLayersSorted) {
Pair<Integer, IWeightInit> d = nInEditedMap.get(layerNum);
nInReplaceBuild(layerNum, d.getFirst(), d.getSecond());
}
}
//finally pop layers specified
int i = 0;
while (i < popN) {
Integer layerNum = origModel.getnLayers() - i;
if (inputPreProcessors.containsKey(layerNum)) {
inputPreProcessors.remove(layerNum);
}
editedConfs.remove(editedConfs.size() - 1);
editedParams.remove(editedParams.size() - 1);
i++;
}
prepDone = true;
}
private void fineTuneConfigurationBuild() {
for (int i = 0; i < origConf.getConfs().size(); i++) {
NeuralNetConfiguration layerConf;
if (finetuneConfiguration != null) {
NeuralNetConfiguration nnc = origConf.getConf(i).clone();
finetuneConfiguration.applyToNeuralNetConfiguration(nnc);
layerConf = nnc;
} else {
layerConf = origConf.getConf(i).clone();
}
editedConfs.add(layerConf);
}
}
private void nInReplaceBuild(int layerNum, int nIn, IWeightInit init) {
Preconditions.checkArgument(layerNum >= 0 && layerNum < editedConfs.size(), "Invalid layer index: must be 0 to " +
"numLayers-1 = %s inclusive, got %s", editedConfs.size() - 1, layerNum);
NeuralNetConfiguration layerConf = editedConfs.get(layerNum);
Layer layerImpl = layerConf.getLayer(); //not a clone; need to modify nIn in place
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applied on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(init);
layerImplF.setNIn(nIn);
long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), numParams);
org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType);
editedParams.set(layerNum, someLayer.params());
}
private void nOutReplaceBuild(int layerNum, int nOut, IWeightInit scheme, IWeightInit schemeNext) {
Preconditions.checkArgument(layerNum >= 0 && layerNum < editedConfs.size(), "Invalid layer index: must be 0 to " +
"numLayers-1 = %s inclusive, got %s", editedConfs.size() - 1, layerNum);
NeuralNetConfiguration layerConf = editedConfs.get(layerNum);
Layer layerImpl = layerConf.getLayer(); //not a clone need to modify nOut in place
Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applied on FeedForward layers;" +
"got layer of type %s", layerImpl.getClass().getSimpleName());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme);
layerImplF.setNOut(nOut);
long numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), numParams);
org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType);
INDArray params1 = someLayer.params();
editedParams.set(layerNum, params1.reshape(params1.length()));
if (layerNum + 1 < editedConfs.size()) {
layerConf = editedConfs.get(layerNum + 1);
layerImpl = layerConf.getLayer(); //modify in place
if(layerImpl instanceof FeedForwardLayer) {
layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(schemeNext);
layerImplF.setNIn(nOut);
numParams = layerImpl.initializer().numParams(layerConf);
if (numParams > 0) {
params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), numParams);
someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType);
params1 = someLayer.params();
editedParams.set(layerNum + 1, params1.reshape(params1.length()));
}
}
}
}
private INDArray constructParams() {
//some params will be null for subsampling etc
INDArray keepView = null;
for (INDArray aParam : editedParams) {
if (aParam != null) {
if (keepView == null) {
keepView = aParam;
} else {
keepView = Nd4j.hstack(keepView, aParam);
}
}
}
if (!appendParams.isEmpty()) {
INDArray appendView = Nd4j.hstack(appendParams);
return Nd4j.hstack(keepView, appendView);
} else {
return keepView;
}
}
private MultiLayerConfiguration constructConf() {
//use the editedConfs list to make a new config
List<NeuralNetConfiguration> allConfs = new ArrayList<>();
allConfs.addAll(editedConfs);
allConfs.addAll(appendConfs);
//Set default layer names, if not set - as per NeuralNetConfiguration.ListBuilder.build()
for (int i = 0; i < allConfs.size(); i++) {
if (allConfs.get(i).getLayer().getLayerName() == null) {
allConfs.get(i).getLayer().setLayerName("layer" + i);
}
}
MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().inputPreProcessors(inputPreProcessors)
.setInputType(this.inputType).confs(allConfs)
.validateOutputLayerConfig(validateOutputLayerConfig == null ? true : validateOutputLayerConfig)
.dataType(origConf.getDataType())
.build();
if (finetuneConfiguration != null) {
finetuneConfiguration.applyToMultiLayerConfiguration(conf);
}
return conf;
}
}
public static class GraphBuilder {
private ComputationGraph origGraph;
private ComputationGraphConfiguration origConfig;
private FineTuneConfiguration fineTuneConfiguration;
private ComputationGraphConfiguration.GraphBuilder editedConfigBuilder;
private String[] frozenOutputAt;
private boolean hasFrozen = false;
private Set<String> editedVertices = new HashSet<>();
private WorkspaceMode workspaceMode;
private Boolean validateOutputLayerConfig = null;
private Map<String, Integer> nInFromNewConfig = new HashMap<>();
/**
* Computation Graph to tweak for transfer learning
* @param origGraph the trained ComputationGraph to modify
*/
public GraphBuilder(ComputationGraph origGraph) {
this.origGraph = origGraph;
this.origConfig = origGraph.getConfiguration().clone();
}
/**
* Set parameters to selectively override the existing learning configuration.
* For example, specify a lower learning rate; it will be applied to all layers.
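*
* A minimal sketch (assuming {@code graphBuilder} is a TransferLearning.GraphBuilder; the updater value is illustrative):
* <pre>{@code
* graphBuilder.fineTuneConfiguration(new FineTuneConfiguration.Builder()
*         .updater(new Adam(1e-4))
*         .build());
* }</pre>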
* @param fineTuneConfiguration the configuration overrides to apply to every layer
* @return GraphBuilder
*/
public GraphBuilder fineTuneConfiguration(FineTuneConfiguration fineTuneConfiguration) {
this.fineTuneConfiguration = fineTuneConfiguration;
this.editedConfigBuilder = new ComputationGraphConfiguration.GraphBuilder(origConfig,
fineTuneConfiguration.appliedNeuralNetConfigurationBuilder());
Map<String, GraphVertex> vertices = this.editedConfigBuilder.getVertices();
for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
if (gv.getValue() instanceof LayerVertex) {
LayerVertex lv = (LayerVertex) gv.getValue();
NeuralNetConfiguration nnc = lv.getLayerConf().clone();
fineTuneConfiguration.applyToNeuralNetConfiguration(nnc);
vertices.put(gv.getKey(), new LayerVertex(nnc, lv.getPreProcessor()));
nnc.getLayer().setLayerName(gv.getKey());
}
}
return this;
}
/**
* Specify a layer vertex (or several) to set as a "feature extractor"
* The specified layer vertices and the layers on the path from an input vertex to them will be "frozen", with parameters staying constant
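*
* For example (vertex names are illustrative; assumes {@code graphBuilder} is a TransferLearning.GraphBuilder):
* <pre>{@code
* graphBuilder.setFeatureExtractor("block5_pool", "flatten");  // freeze these vertices and everything feeding into them
* }</pre>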
* @param layerName names of the layer vertices to freeze (along with everything on the path from the inputs to them)
* @return Builder
*/
public GraphBuilder setFeatureExtractor(String... layerName) {
this.hasFrozen = true;
this.frozenOutputAt = layerName;
return this;
}
/**
* Modify the architecture of a vertex layer by changing nOut
* Note this will also affect the vertex layer that follows the layer specified, unless it is the output layer
* Currently does not support modifying nOut of layers that feed into non-layer vertices like merge, subset etc
* To modify nOut for such vertices use remove vertex, followed by add vertex
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
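* For example (the vertex name, nOut value and init scheme are illustrative; assumes {@code graphBuilder} is a TransferLearning.GraphBuilder):
* <pre>{@code
* graphBuilder.nOutReplace("fc2", 256, WeightInit.XAVIER);  // nIn of layers fed by "fc2" is adjusted to match
* }</pre>
*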
* @param layerName The name of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight init scheme to use for params in layerName and the layers following it
* @return GraphBuilder
* @see WeightInit#DISTRIBUTION
*/
public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme) {
return nOutReplace(layerName, nOut, scheme, scheme);
}
/**
* Modify the architecture of a vertex layer by changing nOut
* Note this will also affect the vertex layer that follows the layer specified, unless it is the output layer
* Currently does not support modifying nOut of layers that feed into non-layer vertices like merge, subset etc
* To modify nOut for such vertices use remove vertex, followed by add vertex
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerName The name of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Weight distribution scheme to use
* @return GraphBuilder
* @see WeightInit#DISTRIBUTION
*/
public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist) {
return nOutReplace(layerName, nOut, dist, dist);
}
/**
* Modify nOut of the specified layer. Also affects layers following layerName, unless they are output layers
* @param layerName The name of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Weight distribution scheme to use for layerName
* @param distNext Weight distribution scheme for layers following layerName
* @return GraphBuilder
* @see WeightInit#DISTRIBUTION
*/
public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, Distribution distNext) {
return nOutReplace(layerName, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext));
}
public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, Distribution dist) {
if(scheme == WeightInit.DISTRIBUTION) {
throw new UnsupportedOperationException("Not supported! Use " +
"nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
}
return nOutReplace(layerName, nOut, scheme.getWeightInitFunction(), new WeightInitDistribution(dist));
}
public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, WeightInit scheme) {
if(scheme == WeightInit.DISTRIBUTION) {
throw new UnsupportedOperationException("Not supported! Use " +
"nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
}
return nOutReplace(layerName, nOut, new WeightInitDistribution(dist), scheme.getWeightInitFunction());
}
public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, WeightInit schemeNext) {
if(scheme == WeightInit.DISTRIBUTION || schemeNext == WeightInit.DISTRIBUTION) {
throw new UnsupportedOperationException("Not supported! Use " +
"nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
}
return nOutReplace(layerName, nOut, scheme.getWeightInitFunction(), schemeNext.getWeightInitFunction());
}
/**
* Modify the architecture of a vertex layer by changing nIn of the specified layer.
* Note that only the specified layer will be modified - all other layers will not be changed by this call.
*
* @param layerName The name of the layer to change nIn of
* @param nIn Value of nIn to change to
* @param scheme Weight init scheme to use for params in layerName
* @return GraphBuilder
*/
public GraphBuilder nInReplace(String layerName, int nIn, WeightInit scheme) {
return nInReplace(layerName, nIn, scheme, null);
}
public GraphBuilder validateOutputLayerConfig(boolean validateOutputLayerConfig){
this.validateOutputLayerConfig = validateOutputLayerConfig;
return this;
}
/**
* Modify the architecture of a vertex layer by changing nIn of the specified layer.
* Note that only the specified layer will be modified - all other layers will not be changed by this call.
*
* @param layerName The name of the layer to change nIn of
* @param nIn Value of nIn to change to
* @param scheme Weight init scheme to use for params in layerName
* @return GraphBuilder
*/
public GraphBuilder nInReplace(String layerName, int nIn, WeightInit scheme, Distribution dist) {
return nInReplace(layerName, nIn, scheme.getWeightInitFunction(dist));
}
/**
* Modify the architecture of a vertex layer by changing nIn of the specified layer.
* Note that only the specified layer will be modified - all other layers will not be changed by this call.
*
* @param layerName The name of the layer to change nIn of
* @param nIn Value of nIn to change to
* @param scheme Weight init scheme to use for params in layerName
* @return GraphBuilder
*/
public GraphBuilder nInReplace(String layerName, int nIn, IWeightInit scheme) {
Preconditions.checkState(origGraph.getVertex(layerName) != null, "Layer with name %s not found",
layerName);
Preconditions.checkState(origGraph.getVertex(layerName).hasLayer(), "nInReplace can only be applied" +
" on vertices with layers. Vertex %s does not have a layer", layerName);
initBuilderIfReq();
NeuralNetConfiguration layerConf = origGraph.getLayer(layerName).conf();
Layer layerImpl = layerConf.getLayer().clone();
Preconditions.checkState(layerImpl instanceof FeedForwardLayer, "Can only use nInReplace on FeedForward layers;" +
"got layer of type %s for layer name %s", layerImpl.getClass().getSimpleName(), layerName);
layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme);
layerImplF.setNIn(nIn);
if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
&& nInFromNewConfig.containsKey(layerName)){
Layer l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConf().getLayer();
if(l instanceof FeedForwardLayer){
layerImplF.setNIn(nInFromNewConfig.get(layerName));
}
}
editedConfigBuilder.removeVertex(layerName, false);
LayerVertex lv = (LayerVertex) origConfig.getVertices().get(layerName);
String[] lvInputs = origConfig.getVertexInputs().get(layerName).toArray(new String[0]);
editedConfigBuilder.addLayer(layerName, layerImpl, lv.getPreProcessor(), lvInputs);
editedVertices.add(layerName);
return this;
}
private GraphBuilder nOutReplace(String layerName, int nOut, IWeightInit scheme, IWeightInit schemeNext) {
initBuilderIfReq();
if (origGraph.getVertex(layerName).hasLayer()) {
NeuralNetConfiguration layerConf = origGraph.getLayer(layerName).conf();
Layer layerImpl = layerConf.getLayer().clone();
layerImpl.resetLayerDefaultConfig();
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(scheme);
layerImplF.setNOut(nOut);
if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex
&& nInFromNewConfig.containsKey(layerName)){
Layer l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConf().getLayer();
if(l instanceof FeedForwardLayer){
layerImplF.setNIn(nInFromNewConfig.get(layerName));
}
}
editedConfigBuilder.removeVertex(layerName, false);
LayerVertex lv = (LayerVertex) origConfig.getVertices().get(layerName);
String[] lvInputs = origConfig.getVertexInputs().get(layerName).toArray(new String[0]);
editedConfigBuilder.addLayer(layerName, layerImpl, lv.getPreProcessor(), lvInputs);
editedVertices.add(layerName);
//collect other vertices that have this vertex as inputs
List<String> fanoutVertices = new ArrayList<>();
for (Map.Entry<String, List<String>> entry : origConfig.getVertexInputs().entrySet()) {
String currentVertex = entry.getKey();
if (!currentVertex.equals(layerName)) {
if (entry.getValue().contains(layerName)) {
fanoutVertices.add(currentVertex);
}
}
}
//change nIn of fanout
for (String fanoutVertexName : fanoutVertices) {
if (!origGraph.getVertex(fanoutVertexName).hasLayer()) {
throw new UnsupportedOperationException(
"Cannot modify nOut of a layer vertex that feeds non-layer vertices. Use removeVertexKeepConnections followed by addVertex instead");
}
layerConf = origGraph.getLayer(fanoutVertexName).conf();
if(!(layerConf.getLayer() instanceof FeedForwardLayer))
continue;
layerImpl = layerConf.getLayer().clone();
layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setWeightInitFn(schemeNext);
layerImplF.setNIn(nOut);
nInFromNewConfig.put(fanoutVertexName, nOut);
editedConfigBuilder.removeVertex(fanoutVertexName, false);
lv = (LayerVertex) origConfig.getVertices().get(fanoutVertexName);
lvInputs = origConfig.getVertexInputs().get(fanoutVertexName).toArray(new String[0]);
editedConfigBuilder.addLayer(fanoutVertexName, layerImpl, lv.getPreProcessor(), lvInputs);
editedVertices.add(fanoutVertexName);
if(validateOutputLayerConfig != null) {
editedConfigBuilder.validateOutputLayerConfig(validateOutputLayerConfig);
}
}
} else {
throw new IllegalArgumentException("nOutReplace can only be applied to layer vertices. " + layerName
+ " is not a layer vertex");
}
return this;
}
/**
* Remove the specified vertex from the computation graph but keep its connections.
* Note that the expectation is to then add back another vertex with the same name; otherwise the graph will be left
* in an invalid state, possibly with references to vertices that no longer exist.
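*
* For example (vertex names, layer sizes and loss are illustrative; assumes the usual DL4J imports for
* OutputLayer, LossFunctions and Activation, and that {@code graphBuilder} is a TransferLearning.GraphBuilder):
* <pre>{@code
* graphBuilder.removeVertexKeepConnections("predictions")
*             .addLayer("predictions",
*                     new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
*                             .nIn(4096).nOut(10)
*                             .activation(Activation.SOFTMAX)
*                             .build(),
*                     "fc2");
* }</pre>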
* @param outputName name of the vertex to remove
* @return
*/
public GraphBuilder removeVertexKeepConnections(String outputName) {
initBuilderIfReq();
editedConfigBuilder.removeVertex(outputName, false);
return this;
}
/**
* Remove the specified vertex and its connections from the computation graph
* @param vertexName name of the vertex to remove
* @return
*/
public GraphBuilder removeVertexAndConnections(String vertexName) {
initBuilderIfReq();
editedConfigBuilder.removeVertex(vertexName, true);
return this;
}
/**
* Add a layer of the specified configuration to the computation graph
* @param layerName
* @param layer
* @param layerInputs
* @return
*/
public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
initBuilderIfReq();
editedConfigBuilder.addLayer(layerName, layer, null, layerInputs);
editedVertices.add(layerName);
return this;
}
/**
* Add a layer with a specified preprocessor
* @param layerName
* @param layer
* @param preProcessor
* @param layerInputs
* @return
*/
public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor,
String... layerInputs) {
initBuilderIfReq();
editedConfigBuilder.addLayer(layerName, layer, preProcessor, layerInputs);
editedVertices.add(layerName);
return this;
}
/**
* Add a vertex of the given configuration to the computation graph
* @param vertexName
* @param vertex
* @param vertexInputs
* @return
*/
public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
initBuilderIfReq();
editedConfigBuilder.addVertex(vertexName, vertex, vertexInputs);
editedVertices.add(vertexName);
return this;
}
/**
* Set the outputs of the computation graph; these are added to any existing outputs
* Also determines the output order, as in ComputationGraphConfiguration
* @param outputNames
* @return
*/
public GraphBuilder setOutputs(String... outputNames) {
initBuilderIfReq();
editedConfigBuilder.setOutputs(outputNames);
return this;
}
private void initBuilderIfReq() {
if (editedConfigBuilder == null) {
//No fine tune config has been set. One isn't required, but we need one to create the editedConfigBuilder
//So: create an empty finetune config, which won't override anything
//but keep the seed
fineTuneConfiguration(new FineTuneConfiguration.Builder()
.seed(origConfig.getDefaultConfiguration().getSeed()).build());
}
}
/**
* Sets new inputs for the computation graph. This method will remove any
* pre-existing inputs.
* @param inputs String names of each graph input.
* @return {@code GraphBuilder} instance.
*/
public GraphBuilder setInputs(String... inputs) {
editedConfigBuilder.setNetworkInputs(Arrays.asList(inputs));
return this;
}
/**
* Sets the input type of corresponding inputs.
* @param inputTypes The type of input (such as convolutional).
* @return {@code GraphBuilder} instance.
*/
public GraphBuilder setInputTypes(InputType... inputTypes) {
editedConfigBuilder.setInputTypes(inputTypes);
return this;
}
public GraphBuilder addInputs(String... inputNames) {
editedConfigBuilder.addInputs(inputNames);
return this;
}
public GraphBuilder setWorkspaceMode(WorkspaceMode workspaceMode) {
this.workspaceMode = workspaceMode;
return this;
}
/**
* Returns a computation graph built to the given specifications.
* init() has already been called internally; the returned graph can be fit directly.
* @return Computation graph
*/
public ComputationGraph build() {
initBuilderIfReq();
ComputationGraphConfiguration newConfig = editedConfigBuilder
.validateOutputLayerConfig(validateOutputLayerConfig == null ? true : validateOutputLayerConfig).build();
if (this.workspaceMode != null)
newConfig.setTrainingWorkspaceMode(workspaceMode);
ComputationGraph newGraph = new ComputationGraph(newConfig);
newGraph.init();
int[] topologicalOrder = newGraph.topologicalSortOrder();
org.deeplearning4j.nn.graph.vertex.GraphVertex[] vertices = newGraph.getVertices();
if (!editedVertices.isEmpty()) {
//set params from orig graph as necessary to new graph
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[topologicalOrder[i]].hasLayer())
continue;
org.deeplearning4j.nn.api.Layer layer = vertices[topologicalOrder[i]].getLayer();
String layerName = vertices[topologicalOrder[i]].getVertexName();
long range = layer.numParams();
if (range <= 0)
continue; //some layers have no params
if (editedVertices.contains(layerName))
continue; //keep the changed params
INDArray origParams = origGraph.getLayer(layerName).params();
layer.setParams(origParams.dup()); //copy over origGraph params
}
} else {
newGraph.setParams(origGraph.params());
}
//Freeze layers as necessary. Note: we can't simply say "everything before frozen layer X needs to be frozen
// also" as this won't always work. For example, in1->A->C, in2->B->C, freeze B; A shouldn't be frozen, even
// if A is before B in the topological sort order.
//How it should be handled: use the graph structure + topological sort order.
// If a vertex is marked to be frozen: freeze it
// Any descendants of a frozen layer should also be frozen
if (hasFrozen) {
//Store all frozen layers, and any vertices inheriting from said layers
Set<String> allFrozen = new HashSet<>();
Collections.addAll(allFrozen, frozenOutputAt);
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
org.deeplearning4j.nn.graph.vertex.GraphVertex gv = vertices[topologicalOrder[i]];
if (allFrozen.contains(gv.getVertexName())) {
if (gv.hasLayer()) {
//Need to freeze this layer - both the layer implementation, and the layer configuration
org.deeplearning4j.nn.api.Layer l = gv.getLayer();
gv.setLayerAsFrozen();
String layerName = gv.getVertexName();
LayerVertex currLayerVertex = (LayerVertex) newConfig.getVertices().get(layerName);
Layer origLayerConf = currLayerVertex.getLayerConf().getLayer();
Layer newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf);
newLayerConf.setLayerName(origLayerConf.getLayerName());
//Complication here (and reason for clone on next line): inner Layer (implementation)
// NeuralNetConfiguration.layer (config) should keep the original layer config. While network
// NNC should have the frozen layer
NeuralNetConfiguration newNNC = currLayerVertex.getLayerConf().clone();
currLayerVertex.setLayerConf(newNNC);
currLayerVertex.getLayerConf().setLayer(newLayerConf);
//Make sure the underlying layer doesn't change:
List<String> vars = currLayerVertex.getLayerConf().variables(true);
currLayerVertex.getLayerConf().clearVariables();
for (String s : vars) {
newNNC.variables(false).add(s);
}
//We also need to place the layer in the CompGraph Layer[] (replacing the old one)
//This could no doubt be done more efficiently
org.deeplearning4j.nn.api.Layer[] layers = newGraph.getLayers();
for (int j = 0; j < layers.length; j++) {
if (layers[j] == l) {
layers[j] = gv.getLayer(); //Place the new frozen layer to replace the original layer
break;
}
}
} else {
if(!(gv instanceof InputVertex)) {
GraphVertex currVertexConf = newConfig.getVertices().get(gv.getVertexName());
GraphVertex newVertexConf = new org.deeplearning4j.nn.conf.graph.FrozenVertex(currVertexConf);
newConfig.getVertices().put(gv.getVertexName(), newVertexConf);
vertices[topologicalOrder[i]] = new FrozenVertex(gv);
}
}
//Also: mark any inputs as to be frozen also
VertexIndices[] inputs = gv.getInputVertices();
if (inputs != null && inputs.length > 0) {
for (int j = 0; j < inputs.length; j++) {
int inputVertexIdx = inputs[j].getVertexIndex();
String alsoFreeze = vertices[inputVertexIdx].getVertexName();
allFrozen.add(alsoFreeze);
}
}
}
}
newGraph.initGradientsView();
}
return newGraph;
}
}
}