package org.deeplearning4j.nn.transferlearning;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* The transfer learning API can be used to modify the architecture or the learning parameters of an existing MultiLayerNetwork or ComputationGraph.
* It allows one to
* - change nOut of an existing layer
* - remove existing layers/vertices and add new ones
* - fine tune the learning configuration (learning rate, updater, etc.)
* - hold the parameters of specified layers constant
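*
* A minimal usage sketch (illustrative only; the network {@code origNet}, the layer indices and the
* values passed to the builder methods are hypothetical and depend on the model being modified):
* <pre>{@code
* FineTuneConfiguration fineTune = new FineTuneConfiguration.Builder()
*         .seed(123)                               // other fine-tune options follow the same builder pattern
*         .build();
*
* MultiLayerNetwork newNet = new TransferLearning.Builder(origNet)
*         .fineTuneConfiguration(fineTune)         // apply the fine-tune settings to all layers
*         .setFeatureExtractor(2)                  // freeze layers 0..2 (parameters held constant)
*         .nOutReplace(3, 10, WeightInit.XAVIER)   // change nOut of layer 3 to 10 (the next layer, if any, gets nIn=10)
*         .build();                                // no init() needed; the returned network can be fit directly
* }</pre>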
*/
@Slf4j
public class TransferLearning {
public static class Builder {
private MultiLayerConfiguration origConf;
private MultiLayerNetwork origModel;
private MultiLayerNetwork editedModel;
private FineTuneConfiguration finetuneConfiguration;
private int frozenTill = -1;
private int popN = 0;
private boolean prepDone = false;
private Set<Integer> editedLayers = new HashSet<>();
private Map<Integer, Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>>> editedLayersMap =
        new HashMap<>();
private List<INDArray> editedParams = new ArrayList<>();
private List<NeuralNetConfiguration> editedConfs = new ArrayList<>();
private List<INDArray> appendParams = new ArrayList<>(); //these could be new arrays, and views from origParams
private List<NeuralNetConfiguration> appendConfs = new ArrayList<>();
private Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
private InputType inputType;
/**
* Multilayer Network to tweak for transfer learning
* @param origModel
*/
public Builder(MultiLayerNetwork origModel) {
this.origModel = origModel;
this.origConf = origModel.getLayerWiseConfigurations().clone();
this.inputPreProcessors = origConf.getInputPreProcessors();
}
/**
* Fine tune configuration that will overwrite the existing configuration, if any
* Usage example: specifying a learning rate will set that learning rate on all layers
* Refer to the FineTuneConfiguration class for more details
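* <p>
* Illustrative sketch ({@code tlBuilder} is a hypothetical TransferLearning.Builder instance; the
* builder options shown are examples only):
* <pre>{@code
* FineTuneConfiguration fineTune = new FineTuneConfiguration.Builder()
*         .seed(123)
*         .build();
* tlBuilder.fineTuneConfiguration(fineTune);
* }</pre>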
* @param finetuneConfiguration
* @return Builder
*/
public Builder fineTuneConfiguration(FineTuneConfiguration finetuneConfiguration) {
this.finetuneConfiguration = finetuneConfiguration;
return this;
}
/**
* Specify a layer to set as a "feature extractor"
* The specified layer and the layers preceding it will be "frozen" with parameters staying constant
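* <p>
* For example (the index is illustrative), assuming {@code builder} is a TransferLearning.Builder:
* <pre>{@code
* builder.setFeatureExtractor(4);   // on build(), layers 0..4 are wrapped as FrozenLayer and no longer updated
* }</pre>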
* @param layerNum
* @return Builder
*/
public Builder setFeatureExtractor(int layerNum) {
this.frozenTill = layerNum;
return this;
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
*
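* <p>
* Sketch (layer index and width are illustrative; {@code builder} is a TransferLearning.Builder):
* <pre>{@code
* builder.nOutReplace(2, 256, WeightInit.XAVIER);   // layer 2 gets nOut=256 and the following layer gets nIn=256; both are re-initialized
* }</pre>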
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight init scheme to use for params in layerNum and layerNum+1
* @return Builder
*/
public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme) {
return nOutReplace(layerNum, nOut, scheme, scheme, null, null);
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Distribution to use in conjunction with weight init DISTRIBUTION for params in layerNum and layerNum+1
* @return Builder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public Builder nOutReplace(int layerNum, int nOut, Distribution dist) {
return nOutReplace(layerNum, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, dist);
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight Init scheme to use for params in the layerNum
* @param schemeNext Weight Init scheme to use for params in the layerNum+1
* @return Builder
*/
public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext) {
return nOutReplace(layerNum, nOut, scheme, schemeNext, null, null);
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Distribution to use for params in the layerNum
* @param distNext Distribution to use for params in layerNum+1
* @return Builder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public Builder nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) {
return nOutReplace(layerNum, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, distNext);
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight init scheme to use for params in layerNum
* @param distNext Distribution to use for params in layerNum+1
* @return Builder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, Distribution distNext) {
return nOutReplace(layerNum, nOut, scheme, WeightInit.DISTRIBUTION, null, distNext);
}
/**
* Modify the architecture of a layer by changing nOut
* Note this will also affect the layer that follows the layer specified, unless it is the output layer
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerNum The index of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Distribution to use for params in layerNum
* @param schemeNext Weight init scheme to use for params in layerNum+1
* @return Builder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public Builder nOutReplace(int layerNum, int nOut, Distribution dist, WeightInit schemeNext) {
return nOutReplace(layerNum, nOut, WeightInit.DISTRIBUTION, schemeNext, dist, null);
}
private Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext, Distribution dist,
Distribution distNext) {
editedLayers.add(layerNum);
ImmutableTriple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>> t =
        new ImmutableTriple<>(nOut, new ImmutablePair<>(scheme, dist),
                new ImmutablePair<>(schemeNext, distNext));
editedLayersMap.put(layerNum, t);
return this;
}
/**
* Helper method to remove the outputLayer of the net.
* Only one of the two - removeOutputLayer() or removeLayersFromOutput(layerNum) - can be specified
* When removing layers at the very least an output layer should be added with .addLayer(...)
*
* @return Builder
*/
public Builder removeOutputLayer() {
popN = 1;
return this;
}
/**
* Remove last "n" layers of the net
* At least an output layer must be added back in
* @param layerNum number of layers to remove
* @return Builder
*/
public Builder removeLayersFromOutput(int layerNum) {
if (popN == 0) {
popN = layerNum;
} else {
throw new IllegalArgumentException("Remove layers from can only be called once");
}
return this;
}
/**
* Add layers to the net
* Required if layers are removed. Can be called multiple times; layers will be added in the order in which the calls were made.
* At the very least an output layer must be added (as per the note on order, the output layer should be added last)
* Learning configs (like updaters, learning rate, etc.) specified with the layer here will be honored
*
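* <p>
* Sketch: replace the original output layer with a new one (sizes and loss function are illustrative;
* OutputLayer and LossFunctions require the usual imports):
* <pre>{@code
* builder.removeOutputLayer()
*        .addLayer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
*                .nIn(256).nOut(5).build());
* }</pre>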
* @param layer layer conf to add (similar to NeuralNetConfiguration.list().layer(...))
* @return Builder
*/
public Builder addLayer(Layer layer) {
if (!prepDone) {
doPrep();
}
// Use the fineTune config to create the required NeuralNetConfiguration + Layer instances
//instantiate dummy layer to get the params
//Build a nn config builder with settings from finetune. Set layer with the added layer
NeuralNetConfiguration layerConf =
finetuneConfiguration.appliedNeuralNetConfigurationBuilder().layer(layer).build();
int numParams = layer.initializer().numParams(layerConf);
INDArray params;
if (numParams > 0) {
params = Nd4j.create(1, numParams);
org.deeplearning4j.nn.api.Layer someLayer = layer.instantiate(layerConf, null, 0, params, true);
appendParams.add(someLayer.params());
appendConfs.add(someLayer.conf());
} else {
appendConfs.add(layerConf);
}
return this;
}
/**
* Specify the preprocessor for the added layers
* for cases where they cannot be inferred automatically.
*
* @param layer the index of the layer the preprocessor applies to
* @param processor the preprocessor to be used on the data
* @return Builder
*/
public Builder setInputPreProcessor(int layer, InputPreProcessor processor) {
inputPreProcessors.put(layer, processor);
return this;
}
/**
* Returns a model with the fine tune configuration and specified architecture changes.
* .init() need not be called. Can be directly fit.
*
* @return MultiLayerNetwork
*/
public MultiLayerNetwork build() {
if (!prepDone) {
doPrep();
}
editedModel = new MultiLayerNetwork(constructConf(), constructParams());
if (frozenTill != -1) {
org.deeplearning4j.nn.api.Layer[] layers = editedModel.getLayers();
for (int i = frozenTill; i >= 0; i--) {
//unchecked?
layers[i] = new FrozenLayer(layers[i]);
}
editedModel.setLayers(layers);
}
return editedModel;
}
private void doPrep() {
//first set finetune configs on all layers in model
fineTuneConfigurationBuild();
//editParams gets original model params
for (int i = 0; i < origModel.getnLayers(); i++) {
if (origModel.getLayer(i).numParams() > 0) {
//dup only if params are there
editedParams.add(origModel.getLayer(i).params().dup());
} else {
editedParams.add(origModel.getLayer(i).params());
}
}
//apply changes to nout/nin if any in sorted order and save to editedParams
if (!editedLayers.isEmpty()) {
Integer[] editedLayersSorted = editedLayers.toArray(new Integer[editedLayers.size()]);
Arrays.sort(editedLayersSorted);
for (int i = 0; i < editedLayersSorted.length; i++) {
int layerNum = editedLayersSorted[i];
nOutReplaceBuild(layerNum, editedLayersMap.get(layerNum).getLeft(),
editedLayersMap.get(layerNum).getMiddle(),
editedLayersMap.get(layerNum).getRight());
}
}
//finally pop layers specified
int i = 0;
while (i < popN) {
Integer layerNum = origModel.getnLayers() - i;
if (inputPreProcessors.containsKey(layerNum)) {
inputPreProcessors.remove(layerNum);
}
editedConfs.remove(editedConfs.size() - 1);
editedParams.remove(editedParams.size() - 1);
i++;
}
prepDone = true;
}
private void fineTuneConfigurationBuild() {
for (int i = 0; i < origConf.getConfs().size(); i++) {
NeuralNetConfiguration layerConf;
if (finetuneConfiguration != null) {
NeuralNetConfiguration nnc = origConf.getConf(i).clone();
finetuneConfiguration.applyToNeuralNetConfiguration(nnc);
layerConf = nnc;
} else {
layerConf = origConf.getConf(i).clone();
}
editedConfs.add(layerConf);
}
}
private void nOutReplaceBuild(int layerNum, int nOut, Pair<WeightInit, Distribution> schemedist,
        Pair<WeightInit, Distribution> schemedistNext) {
NeuralNetConfiguration layerConf = editedConfs.get(layerNum);
Layer layerImpl = layerConf.getLayer(); //not a clone need to modify nOut in place
layerImpl.setWeightInit(schemedist.getLeft());
layerImpl.setDist(schemedist.getRight());
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setNOut(nOut);
int numParams = layerImpl.initializer().numParams(layerConf);
INDArray params = Nd4j.create(1, numParams);
org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true);
editedParams.set(layerNum, someLayer.params());
if (layerNum + 1 < editedConfs.size()) {
layerConf = editedConfs.get(layerNum + 1);
layerImpl = layerConf.getLayer(); //modify in place
layerImpl.setWeightInit(schemedistNext.getLeft());
layerImpl.setDist(schemedistNext.getRight());
layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setNIn(nOut);
numParams = layerImpl.initializer().numParams(layerConf);
if (numParams > 0) {
params = Nd4j.create(1, numParams);
someLayer = layerImpl.instantiate(layerConf, null, 0, params, true);
editedParams.set(layerNum + 1, someLayer.params());
}
}
}
private INDArray constructParams() {
//some params will be null for subsampling etc
INDArray keepView = null;
for (INDArray aParam : editedParams) {
if (aParam != null) {
if (keepView == null) {
keepView = aParam;
} else {
keepView = Nd4j.hstack(keepView, aParam);
}
}
}
if (!appendParams.isEmpty()) {
INDArray appendView = Nd4j.hstack(appendParams);
return Nd4j.hstack(keepView, appendView);
} else {
return keepView;
}
}
private MultiLayerConfiguration constructConf() {
//use the editedConfs list to make a new config
List<NeuralNetConfiguration> allConfs = new ArrayList<>();
allConfs.addAll(editedConfs);
allConfs.addAll(appendConfs);
//Set default layer names, if not set - as per NeuralNetConfiguration.ListBuilder.build()
for (int i = 0; i < allConfs.size(); i++) {
if (allConfs.get(i).getLayer().getLayerName() == null) {
allConfs.get(i).getLayer().setLayerName("layer" + i);
}
}
// return new MultiLayerConfiguration.Builder().backprop(backprop).inputPreProcessors(inputPreProcessors).
// pretrain(pretrain).backpropType(backpropType).tBPTTForwardLength(tbpttFwdLength)
// .tBPTTBackwardLength(tbpttBackLength)
// .setInputType(this.inputType)
// .confs(allConfs).build();
MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().inputPreProcessors(inputPreProcessors)
.setInputType(this.inputType).confs(allConfs).build();
if (finetuneConfiguration != null) {
finetuneConfiguration.applyToMultiLayerConfiguration(conf);
}
return conf;
}
}
public static class GraphBuilder {
private ComputationGraph origGraph;
private ComputationGraphConfiguration origConfig;
private FineTuneConfiguration fineTuneConfiguration;
private ComputationGraphConfiguration.GraphBuilder editedConfigBuilder;
private String[] frozenOutputAt;
private boolean hasFrozen = false;
private Set<String> editedVertices = new HashSet<>();
/**
* Computation Graph to tweak for transfer learning
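* <p>
* Illustrative sketch (the vertex names "fc" and "out", the fine-tune settings and the new output layer
* configuration are hypothetical and depend on the original graph):
* <pre>{@code
* ComputationGraph newGraph = new TransferLearning.GraphBuilder(origGraph)
*         .fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(123).build())
*         .setFeatureExtractor("fc")                  // freeze "fc" and everything feeding into it
*         .removeVertexKeepConnections("out")         // drop the old output layer but keep its connections
*         .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
*                 .nIn(256).nOut(5).build(), "fc")    // add a replacement under the same name
*         .build();
* }</pre>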
* @param origGraph
*/
public GraphBuilder(ComputationGraph origGraph) {
this.origGraph = origGraph;
this.origConfig = origGraph.getConfiguration().clone();
}
/**
* Set parameters to selectively override the existing learning parameters
* Usage example: specify a lower learning rate, which will be applied to all layers
* @param fineTuneConfiguration
* @return GraphBuilder
*/
public GraphBuilder fineTuneConfiguration(FineTuneConfiguration fineTuneConfiguration) {
this.fineTuneConfiguration = fineTuneConfiguration;
this.editedConfigBuilder = new ComputationGraphConfiguration.GraphBuilder(origConfig,
fineTuneConfiguration.appliedNeuralNetConfigurationBuilder());
Map<String, GraphVertex> vertices = this.editedConfigBuilder.getVertices();
for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
if (gv.getValue() instanceof LayerVertex) {
LayerVertex lv = (LayerVertex) gv.getValue();
NeuralNetConfiguration nnc = lv.getLayerConf().clone();
fineTuneConfiguration.applyToNeuralNetConfiguration(nnc);
vertices.put(gv.getKey(), new LayerVertex(nnc, lv.getPreProcessor()));
nnc.getLayer().setLayerName(gv.getKey());
}
}
return this;
}
/**
* Specify a layer vertex to set as a "feature extractor"
* The specified layer vertex and the layers on the path from an input vertex to it will be "frozen" with parameters staying constant
* @param layerName
* @return Builder
*/
public GraphBuilder setFeatureExtractor(String... layerName) {
this.hasFrozen = true;
this.frozenOutputAt = layerName;
return this;
}
/**
* Modify the architecture of a vertex layer by changing nOut
* Note this will also affect the vertex layer that follows the layer specified, unless it is the output layer
* Currently does not support modifying nOut of layers that feed into non-layer vertices (merge, subset, etc.)
* To modify nOut in such cases, use removeVertexKeepConnections followed by addVertex
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
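* <p>
* Sketch (vertex name and width are illustrative; {@code graphBuilder} is a TransferLearning.GraphBuilder):
* <pre>{@code
* graphBuilder.nOutReplace("fc2", 512, WeightInit.XAVIER);   // "fc2" gets nOut=512; layer vertices fed by it get nIn=512
* }</pre>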
* @param layerName The name of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param scheme Weight init scheme to use for params in layerName and the layers following it
* @return GraphBuilder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme) {
return nOutReplace(layerName, nOut, scheme, scheme, null, null);
}
/**
* Modify the architecture of a vertex layer by changing nOut
* Note this will also affect the vertex layer that follows the layer specified, unless it is the output layer
* Currently does not support modifying nOut of layers that feed into non-layer vertices like merge, subset etc
* To modify nOut for such vertices use remove vertex, followed by add vertex
* Can specify different weight init schemes for the specified layer and the layer that follows it.
*
* @param layerName The name of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Weight distribution scheme to use
* @return GraphBuilder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist) {
return nOutReplace(layerName, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, dist);
}
/**
* Modify the nOut of the specified layer. Also affects the layers following layerName, unless they are output layers
* @param layerName The name of the layer to change nOut of
* @param nOut Value of nOut to change to
* @param dist Weight distribution scheme to use for layerName
* @param distNext Weight distribution scheme for layers following layerName
* @return GraphBuilder
* @see org.deeplearning4j.nn.weights.WeightInit DISTRIBUTION
*/
public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, Distribution distNext) {
return nOutReplace(layerName, nOut, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, dist, distNext);
}
public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, Distribution dist) {
return nOutReplace(layerName, nOut, scheme, WeightInit.DISTRIBUTION, null, dist);
}
public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, WeightInit scheme) {
return nOutReplace(layerName, nOut, WeightInit.DISTRIBUTION, scheme, dist, null);
}
public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, WeightInit schemeNext) {
return nOutReplace(layerName, nOut, scheme, schemeNext, null, null);
}
private GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, WeightInit schemeNext,
Distribution dist, Distribution distNext) {
initBuilderIfReq();
if (origGraph.getVertex(layerName).hasLayer()) {
NeuralNetConfiguration layerConf = origGraph.getLayer(layerName).conf();
Layer layerImpl = layerConf.getLayer().clone();
layerImpl.resetLayerDefaultConfig();
layerImpl.setWeightInit(scheme);
layerImpl.setDist(dist);
FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setNOut(nOut);
editedConfigBuilder.removeVertex(layerName, false);
LayerVertex lv = (LayerVertex) origConfig.getVertices().get(layerName);
String[] lvInputs = origConfig.getVertexInputs().get(layerName).toArray(new String[0]);
editedConfigBuilder.addLayer(layerName, layerImpl, lv.getPreProcessor(), lvInputs);
editedVertices.add(layerName);
//collect other vertices that have this vertex as inputs
List<String> fanoutVertices = new ArrayList<>();
for (Map.Entry<String, List<String>> entry : origConfig.getVertexInputs().entrySet()) {
String currentVertex = entry.getKey();
if (!currentVertex.equals(layerName)) {
if (entry.getValue().contains(layerName)) {
fanoutVertices.add(currentVertex);
}
}
}
//change nIn of fanout
for (String fanoutVertexName : fanoutVertices) {
if (!origGraph.getVertex(fanoutVertexName).hasLayer()) {
throw new UnsupportedOperationException(
"Cannot modify nOut of a layer vertex that feeds non-layer vertices. Use removeVertexKeepConnections followed by addVertex instead");
}
layerConf = origGraph.getLayer(fanoutVertexName).conf();
layerImpl = layerConf.getLayer().clone();
layerImpl.setWeightInit(schemeNext);
layerImpl.setDist(distNext);
layerImplF = (FeedForwardLayer) layerImpl;
layerImplF.setNIn(nOut);
editedConfigBuilder.removeVertex(fanoutVertexName, false);
lv = (LayerVertex) origConfig.getVertices().get(fanoutVertexName);
lvInputs = origConfig.getVertexInputs().get(fanoutVertexName).toArray(new String[0]);
editedConfigBuilder.addLayer(fanoutVertexName, layerImpl, lv.getPreProcessor(), lvInputs);
editedVertices.add(fanoutVertexName);
}
} else {
throw new IllegalArgumentException("noutReplace can only be applied to layer vertices. " + layerName
+ " is not a layer vertex");
}
return this;
}
/**
* Remove the specified vertex from the computation graph but keep its connections.
* Note the expectation here is to then add back another vertex with the same name, or else the graph will be left in an invalid state,
* possibly with references to vertices that no longer exist
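* <p>
* Typical pattern (names and layer configuration are illustrative): remove the vertex, then add a
* replacement under the same name so the existing connections remain valid.
* <pre>{@code
* graphBuilder.removeVertexKeepConnections("out")
*             .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
*                     .nIn(128).nOut(3).build(), "fc");
* }</pre>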
* @param outputName
* @return
*/
public GraphBuilder removeVertexKeepConnections(String outputName) {
initBuilderIfReq();
editedConfigBuilder.removeVertex(outputName, false);
return this;
}
/**
* Remove the specified vertex and its connections from the computation graph
* @param vertexName
* @return
*/
public GraphBuilder removeVertexAndConnections(String vertexName) {
initBuilderIfReq();
editedConfigBuilder.removeVertex(vertexName, true);
return this;
}
/**
* Add a layer of the specified configuration to the computation graph
* @param layerName
* @param layer
* @param layerInputs
* @return
*/
public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
initBuilderIfReq();
editedConfigBuilder.addLayer(layerName, layer, null, layerInputs);
editedVertices.add(layerName);
return this;
}
/**
* Add a layer with a specified preprocessor
* @param layerName
* @param layer
* @param preProcessor
* @param layerInputs
* @return
*/
public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor,
String... layerInputs) {
initBuilderIfReq();
editedConfigBuilder.addLayer(layerName, layer, preProcessor, layerInputs);
editedVertices.add(layerName);
return this;
}
/**
* Add a vertex of the given configuration to the computation graph
* @param vertexName
* @param vertex
* @param vertexInputs
* @return
*/
public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
initBuilderIfReq();
editedConfigBuilder.addVertex(vertexName, vertex, vertexInputs);
editedVertices.add(vertexName);
return this;
}
/**
* Set the outputs of the computation graph; these will be added to any existing outputs
* Also determines the output order, as in ComputationGraphConfiguration
* @param outputNames
* @return
*/
public GraphBuilder setOutputs(String... outputNames) {
initBuilderIfReq();
editedConfigBuilder.setOutputs(outputNames);
return this;
}
private void initBuilderIfReq() {
if (editedConfigBuilder == null) {
//No fine tune config has been set. One isn't required, but we need one to create the editedConfigBuilder
//So: create an empty finetune config, which won't override anything
//but keep the seed
fineTuneConfiguration(new FineTuneConfiguration.Builder()
.seed(origConfig.getDefaultConfiguration().getSeed()).build());
}
}
protected GraphBuilder addInputs(String... inputNames) {
editedConfigBuilder.addInputs(inputNames);
return this;
}
/**
* Returns a computation graph built to the given specifications.
* init() has already been called internally, so the graph can be fit directly.
* @return Computation graph
*/
public ComputationGraph build() {
initBuilderIfReq();
ComputationGraphConfiguration newConfig = editedConfigBuilder.build();
ComputationGraph newGraph = new ComputationGraph(newConfig);
newGraph.init();
int[] topologicalOrder = newGraph.topologicalSortOrder();
org.deeplearning4j.nn.graph.vertex.GraphVertex[] vertices = newGraph.getVertices();
if (!editedVertices.isEmpty()) {
//set params from orig graph as necessary to new graph
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[topologicalOrder[i]].hasLayer())
continue;
org.deeplearning4j.nn.api.Layer layer = vertices[topologicalOrder[i]].getLayer();
String layerName = vertices[topologicalOrder[i]].getVertexName();
int range = layer.numParams();
if (range <= 0)
continue; //some layers have no params
if (editedVertices.contains(layerName))
continue; //keep the changed params
layer.setParams(origGraph.getLayer(layerName).params().dup()); //copy over origGraph params
}
} else {
newGraph.setParams(origGraph.params());
}
//Freeze layers as necessary. Note: we can't simply say "everything before frozen layer X needs to be frozen
// also" as this won't always work. For example, in1->A->C, in2->B->C, freeze B; A shouldn't be frozen, even
// if A is before B in the topological sort order.
//How it should be handled: use the graph structure + topological sort order.
// If a vertex is marked to be frozen: freeze it
// Any descendants of a frozen layer should also be frozen
if (hasFrozen) {
//Store all frozen layers, and any vertices inheriting from said layers
Set<String> allFrozen = new HashSet<>();
Collections.addAll(allFrozen, frozenOutputAt);
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
org.deeplearning4j.nn.graph.vertex.GraphVertex gv = vertices[topologicalOrder[i]];
if (allFrozen.contains(gv.getVertexName())) {
if (gv.hasLayer()) {
//Need to freeze this layer
org.deeplearning4j.nn.api.Layer l = gv.getLayer();
gv.setLayerAsFrozen();
//We also need to place the layer in the CompGraph Layer[] (replacing the old one)
//This could no doubt be done more efficiently
org.deeplearning4j.nn.api.Layer[] layers = newGraph.getLayers();
for (int j = 0; j < layers.length; j++) {
if (layers[j] == l) {
layers[j] = gv.getLayer(); //Place the new frozen layer to replace the original layer
break;
}
}
}
//Also: mark any inputs as to be frozen also
VertexIndices[] inputs = gv.getInputVertices();
if (inputs != null && inputs.length > 0) {
for (int j = 0; j < inputs.length; j++) {
int inputVertexIdx = inputs[j].getVertexIndex();
String alsoFreeze = vertices[inputVertexIdx].getVertexName();
allFrozen.add(alsoFreeze);
}
}
}
}
newGraph.initGradientsView();
}
return newGraph;
}
}
}