/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.transferlearning;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.util.*;
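/**
* Utility for transfer learning with models that contain frozen layers/vertices.
* The frozen part of the network acts as a fixed "featurizer": inputs can be passed through it once
* via featurize(...), and the smaller unfrozen part can then be trained on the featurized data via
* fitFeaturized(...), with the updated parameters copied back into the original model after each fit.
*/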
public class TransferLearningHelper {
private boolean isGraph = true;
private boolean applyFrozen = false;
private ComputationGraph origGraph;
private MultiLayerNetwork origMLN;
private int frozenTill;
private String[] frozenOutputAt;
private ComputationGraph unFrozenSubsetGraph;
private MultiLayerNetwork unFrozenSubsetMLN;
Set<String> frozenInputVertices = new HashSet<>(); //holds vertex names, so it can safely be shared between the original and subset graphs
List<String> graphInputs;
int frozenInputLayer = 0;
/**
* Will modify the given comp graph (in place!) to freeze vertices from input to the vertex specified.
*
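* A minimal usage sketch (hypothetical: assumes a trained single-input/single-output graph
* {@code pretrainedGraph}, a vertex name {@code "fc2"} up to which params should be frozen, and a
* {@code DataSetIterator trainIter}):
* <pre>{@code
* TransferLearningHelper helper = new TransferLearningHelper(pretrainedGraph, "fc2");
* while (trainIter.hasNext()) {
*     DataSet featurized = helper.featurize(trainIter.next()); // forward pass through the frozen part only
*     helper.fitFeaturized(featurized);                        // trains the unfrozen subset and updates the original graph
* }
* }</pre>
*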
* @param orig Comp graph
* @param frozenOutputAt names of the vertices to freeze at (their params, and those of all vertices between them and the inputs, are held constant during training)
*/
public TransferLearningHelper(ComputationGraph orig, String... frozenOutputAt) {
origGraph = orig;
this.frozenOutputAt = frozenOutputAt;
applyFrozen = true;
initHelperGraph();
}
/**
* Expects a computation graph where some vertices are frozen
*
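* A sketch of how such a graph might be created beforehand (hypothetical: {@code pretrainedGraph} is a
* trained graph and {@code "fc2"} a vertex name; assumes TransferLearning.GraphBuilder's
* setFeatureExtractor is used to freeze the vertices):
* <pre>{@code
* ComputationGraph frozenGraph = new TransferLearning.GraphBuilder(pretrainedGraph)
*         .setFeatureExtractor("fc2") // freeze "fc2" and everything between it and the inputs
*         .build();
* TransferLearningHelper helper = new TransferLearningHelper(frozenGraph);
* }</pre>
*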
* @param orig a computation graph with frozen vertices
*/
public TransferLearningHelper(ComputationGraph orig) {
origGraph = orig;
initHelperGraph();
}
/**
* Will modify the given MLN (in place!) to freeze the specified layer and all layers below it (their params are held constant during training)
*
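* A minimal usage sketch (hypothetical: {@code pretrainedNet} is a trained MultiLayerNetwork whose
* layers 0..2 should be frozen, and {@code trainIter} is a DataSetIterator):
* <pre>{@code
* TransferLearningHelper helper = new TransferLearningHelper(pretrainedNet, 2); // freezes layers 0, 1 and 2
* while (trainIter.hasNext()) {
*     helper.fitFeaturized(helper.featurize(trainIter.next())); // trains only layers 3 and above
* }
* }</pre>
*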
* @param orig MLN to freeze
* @param frozenTill index of the last layer to freeze; this layer and all layers below it are frozen
*/
public TransferLearningHelper(MultiLayerNetwork orig, int frozenTill) {
isGraph = false;
this.frozenTill = frozenTill;
applyFrozen = true;
origMLN = orig;
initHelperMLN();
}
/**
* Expects a MLN where some layers are frozen
*
* @param orig a MultiLayerNetwork with frozen layers
*/
public TransferLearningHelper(MultiLayerNetwork orig) {
isGraph = false;
origMLN = orig;
initHelperMLN();
}
public void errorIfGraphIfMLN() {
if (isGraph)
throw new IllegalArgumentException(
"This instance was initialized with a computation graph. Cannot apply methods related to MLN");
else
throw new IllegalArgumentException(
"This instance was initialized with a MultiLayerNetwork. Cannot apply methods related to computation graphs");
}
/**
* Returns the unfrozen subset of the original computation graph as a computation graph
* Note that with each call to fitFeaturized the parameters of the original computation graph are also updated
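*
* For example (hypothetical; assumes a helper built around a ComputationGraph):
* <pre>{@code
* ComputationGraph unfrozen = helper.unfrozenGraph();
* System.out.println(unfrozen.summary()); // inspect the trainable subset
* }</pre>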
*/
public ComputationGraph unfrozenGraph() {
if (!isGraph)
errorIfGraphIfMLN();
return unFrozenSubsetGraph;
}
/**
* Returns the unfrozen layers of the original MultiLayerNetwork as a MultiLayerNetwork
* Note that with each call to fitFeaturized the parameters of the original MLN are also updated
*/
public MultiLayerNetwork unfrozenMLN() {
if (isGraph)
errorIfGraphIfMLN();
return unFrozenSubsetMLN;
}
/**
* Use to get the output from a featurized input
*
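* For example (hypothetical: {@code testData} is a MultiDataSet and {@code helper} wraps a ComputationGraph):
* <pre>{@code
* MultiDataSet featurized = helper.featurize(testData);
* INDArray[] predictions = helper.outputFromFeaturized(featurized.getFeatures());
* }</pre>
*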
* @param input featurized data
* @return output
*/
public INDArray[] outputFromFeaturized(INDArray[] input) {
if (!isGraph)
errorIfGraphIfMLN();
return unFrozenSubsetGraph.output(input);
}
/**
* Use to get the output from a featurized input
*
* @param input featurized data
* @return output
*/
public INDArray outputFromFeaturized(INDArray input) {
if (isGraph) {
if (unFrozenSubsetGraph.getNumOutputArrays() > 1) {
throw new IllegalArgumentException(
"Graph has more than one output. Expecting an input array with outputFromFeaturized method call");
}
return unFrozenSubsetGraph.output(input)[0];
} else {
return unFrozenSubsetMLN.output(input);
}
}
/**
* Runs through the comp graph and saves off a new model that is simply the "unfrozen" part of the original model.
* This "unfrozen" model is then used for training with featurized data.
*/
private void initHelperGraph() {
int[] backPropOrder = origGraph.topologicalSortOrder().clone();
ArrayUtils.reverse(backPropOrder);
Set<String> allFrozen = new HashSet<>();
if (applyFrozen) {
Collections.addAll(allFrozen, frozenOutputAt);
}
for (int i = 0; i < backPropOrder.length; i++) {
GraphVertex gv = origGraph.getVertices()[backPropOrder[i]];
if (applyFrozen && allFrozen.contains(gv.getVertexName())) {
if (gv.hasLayer()) {
//Need to freeze this layer
org.deeplearning4j.nn.api.Layer l = gv.getLayer();
gv.setLayerAsFrozen();
//We also need to place the layer in the CompGraph Layer[] (replacing the old one)
//This could no doubt be done more efficiently
org.deeplearning4j.nn.api.Layer[] layers = origGraph.getLayers();
for (int j = 0; j < layers.length; j++) {
if (layers[j] == l) {
layers[j] = gv.getLayer(); //Place the new frozen layer to replace the original layer
break;
}
}
}
//Also mark this vertex's inputs as frozen
VertexIndices[] inputs = gv.getInputVertices();
if (inputs != null && inputs.length > 0) {
for (int j = 0; j < inputs.length; j++) {
int inputVertexIdx = inputs[j].getVertexIndex();
String alsoFreeze = origGraph.getVertices()[inputVertexIdx].getVertexName();
allFrozen.add(alsoFreeze);
}
}
} else {
if (gv.hasLayer()) {
if (gv.getLayer() instanceof FrozenLayer) {
allFrozen.add(gv.getVertexName());
//also need to add parents to list of allFrozen
VertexIndices[] inputs = gv.getInputVertices();
if (inputs != null && inputs.length > 0) {
for (int j = 0; j < inputs.length; j++) {
int inputVertexIdx = inputs[j].getVertexIndex();
String alsoFrozen = origGraph.getVertices()[inputVertexIdx].getVertexName();
allFrozen.add(alsoFrozen);
}
}
}
}
}
}
for (int i = 0; i < backPropOrder.length; i++) {
GraphVertex gv = origGraph.getVertices()[backPropOrder[i]];
String gvName = gv.getVertexName();
//is it an unfrozen vertex that has an input vertex that is frozen?
if (!allFrozen.contains(gvName) && !gv.isInputVertex()) {
VertexIndices[] inputs = gv.getInputVertices();
for (int j = 0; j < inputs.length; j++) {
int inputVertexIdx = inputs[j].getVertexIndex();
String inputVertex = origGraph.getVertices()[inputVertexIdx].getVertexName();
if (allFrozen.contains(inputVertex)) {
frozenInputVertices.add(inputVertex);
}
}
}
}
TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(origGraph);
for (String toRemove : allFrozen) {
if (frozenInputVertices.contains(toRemove)) {
builder.removeVertexKeepConnections(toRemove);
} else {
builder.removeVertexAndConnections(toRemove);
}
}
Set<String> frozenInputVerticesSorted = new HashSet<>();
frozenInputVerticesSorted.addAll(origGraph.getConfiguration().getNetworkInputs());
frozenInputVerticesSorted.removeAll(allFrozen);
//remove input vertices - just to add back in a predictable order
for (String existingInput : frozenInputVerticesSorted) {
builder.removeVertexKeepConnections(existingInput);
}
frozenInputVerticesSorted.addAll(frozenInputVertices);
//Sort all inputs to the computation graph - in order to have a predictable order
graphInputs = new ArrayList<>(frozenInputVerticesSorted);
Collections.sort(graphInputs);
for (String asInput : graphInputs) {
//add back in the right order
builder.addInputs(asInput);
}
unFrozenSubsetGraph = builder.build();
copyOrigParamsToSubsetGraph();
//unFrozenSubsetGraph.setListeners(origGraph.getListeners());
if (frozenInputVertices.isEmpty()) {
throw new IllegalArgumentException("No frozen layers found");
}
}
private void initHelperMLN() {
if (applyFrozen) {
org.deeplearning4j.nn.api.Layer[] layers = origMLN.getLayers();
for (int i = frozenTill; i >= 0; i--) {
//unchecked?
layers[i] = new FrozenLayer(layers[i]);
}
origMLN.setLayers(layers);
}
for (int i = 0; i < origMLN.getnLayers(); i++) {
if (origMLN.getLayer(i) instanceof FrozenLayer) {
frozenInputLayer = i;
}
}
List<NeuralNetConfiguration> allConfs = new ArrayList<>();
for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) {
allConfs.add(origMLN.getLayer(i).conf());
}
MultiLayerConfiguration c = origMLN.getLayerWiseConfigurations();
unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder()
.inputPreProcessors(c.getInputPreProcessors())
.backpropType(c.getBackpropType()).tBPTTForwardLength(c.getTbpttFwdLength())
.tBPTTBackwardLength(c.getTbpttBackLength()).confs(allConfs)
.dataType(origMLN.getLayerWiseConfigurations().getDataType())
.build());
unFrozenSubsetMLN.init();
//copy over params
for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) {
unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).setParams(origMLN.getLayer(i).params());
}
//unFrozenSubsetMLN.setListeners(origMLN.getListeners());
}
/**
* During training, frozen vertices/layers can be treated as "featurizing" the input.
* The forward pass through these frozen layers/vertices can be done in advance and the dataset saved to disk,
* allowing quick iteration on the smaller unfrozen part of the model.
* Currently does not support datasets with feature masks.
*
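* A sketch of pre-featurizing a training set and saving it to disk (hypothetical variables; assumes
* org.nd4j.linalg.dataset.MultiDataSet#save(java.io.File) and that the caller handles its IOException):
* <pre>{@code
* int count = 0;
* while (trainIter.hasNext()) { // trainIter: a MultiDataSetIterator
*     MultiDataSet featurized = helper.featurize((MultiDataSet) trainIter.next()); // this overload takes the concrete nd4j MultiDataSet
*     featurized.save(new File(featurizedDir, "train-" + count++ + ".bin"));
* }
* }</pre>
*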
* @param input multidataset to feed into the computation graph with frozen layer vertices
* @return a multidataset with input features that are the outputs of the frozen layer vertices and the original labels.
*/
public MultiDataSet featurize(MultiDataSet input) {
if (!isGraph) {
throw new IllegalArgumentException("Cannot use multidatasets with MultiLayerNetworks.");
}
INDArray[] labels = input.getLabels();
INDArray[] features = input.getFeatures();
if (input.getFeaturesMaskArrays() != null) {
throw new IllegalArgumentException("Currently cannot support featurizing datasets with feature masks");
}
INDArray[] featureMasks = null;
INDArray[] labelMasks = input.getLabelsMaskArrays();
INDArray[] featuresNow = new INDArray[graphInputs.size()];
Map<String, INDArray> activationsNow = origGraph.feedForward(features, false);
for (int i = 0; i < graphInputs.size(); i++) {
String anInput = graphInputs.get(i);
if (origGraph.getVertex(anInput).isInputVertex()) {
//was an original input to the graph
int inputIndex = origGraph.getConfiguration().getNetworkInputs().indexOf(anInput);
featuresNow[i] = origGraph.getInput(inputIndex);
} else {
//needs to be grabbed from the internal activations
featuresNow[i] = activationsNow.get(anInput);
}
}
return new MultiDataSet(featuresNow, labels, featureMasks, labelMasks);
}
/**
* During training, frozen vertices/layers can be treated as "featurizing" the input.
* The forward pass through these frozen layers/vertices can be done in advance and the dataset saved to disk,
* allowing quick iteration on the smaller unfrozen part of the model.
* Currently does not support datasets with feature masks.
*
* @param input dataset to feed into the model with frozen layers/vertices
* @return a dataset with input features that are the outputs of the frozen layers/vertices and the original labels.
*/
public DataSet featurize(DataSet input) {
if (isGraph) {
//trying to featurize for a computation graph
if (origGraph.getNumInputArrays() > 1 || origGraph.getNumOutputArrays() > 1) {
throw new IllegalArgumentException(
"Input or output size to a computation graph is greater than one. Requires use of a MultiDataSet.");
} else {
if (input.getFeaturesMaskArray() != null) {
throw new IllegalArgumentException(
"Currently cannot support featurizing datasets with feature masks");
}
MultiDataSet inbW = new MultiDataSet(new INDArray[] {input.getFeatures()},
new INDArray[] {input.getLabels()}, null, new INDArray[] {input.getLabelsMaskArray()});
MultiDataSet ret = featurize(inbW);
return new DataSet(ret.getFeatures()[0], input.getLabels(), null,
input.getLabelsMaskArray());
}
} else {
if (input.getFeaturesMaskArray() != null)
throw new UnsupportedOperationException("Feature masks not supported with featurizing currently");
return new DataSet(origMLN.feedForwardToLayer(frozenInputLayer + 1, input.getFeatures(), false)
.get(frozenInputLayer + 1), input.getLabels(), null, input.getLabelsMaskArray());
}
}
/**
* Fit from a featurized dataset.
* The fit is conducted on an internally instantiated subset model that is representative of the unfrozen part of the original model.
* After each call to fit, the parameters of the original model are updated.
*
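* A sketch of fitting for several epochs from featurized data (hypothetical: {@code featurizedIter}
* iterates over previously featurized MultiDataSets):
* <pre>{@code
* for (int epoch = 0; epoch < numEpochs; epoch++) {
*     helper.fitFeaturized(featurizedIter); // updates the unfrozen subset and the original graph
*     featurizedIter.reset();
* }
* }</pre>
*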
* @param iter iterator over featurized data
*/
public void fitFeaturized(MultiDataSetIterator iter) {
unFrozenSubsetGraph.fit(iter);
copyParamsFromSubsetGraphToOrig();
}
public void fitFeaturized(MultiDataSet input) {
unFrozenSubsetGraph.fit(input);
copyParamsFromSubsetGraphToOrig();
}
public void fitFeaturized(DataSet input) {
if (isGraph) {
unFrozenSubsetGraph.fit(input);
copyParamsFromSubsetGraphToOrig();
} else {
unFrozenSubsetMLN.fit(input);
copyParamsFromSubsetMLNToOrig();
}
}
public void fitFeaturized(DataSetIterator iter) {
if (isGraph) {
unFrozenSubsetGraph.fit(iter);
copyParamsFromSubsetGraphToOrig();
} else {
unFrozenSubsetMLN.fit(iter);
copyParamsFromSubsetMLNToOrig();
}
}
private void copyParamsFromSubsetGraphToOrig() {
for (GraphVertex aVertex : unFrozenSubsetGraph.getVertices()) {
if (!aVertex.hasLayer())
continue;
origGraph.getVertex(aVertex.getVertexName()).getLayer().setParams(aVertex.getLayer().params());
}
}
private void copyOrigParamsToSubsetGraph() {
for (GraphVertex aVertex : unFrozenSubsetGraph.getVertices()) {
if (!aVertex.hasLayer())
continue;
aVertex.getLayer().setParams(origGraph.getLayer(aVertex.getVertexName()).params());
}
}
private void copyParamsFromSubsetMLNToOrig() {
for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) {
origMLN.getLayer(i).setParams(unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).params());
}
}
}