
org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.layers.recurrent;

import lombok.AllArgsConstructor;
import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.*;

import static org.nd4j.linalg.indexing.NDArrayIndex.*;

/**
 * Bidirectional is a "wrapper" layer: it wraps any uni-directional RNN layer to make it bidirectional.
 * Note that multiple different modes are supported - these specify how the activations should be combined from
 * the forward and backward RNN networks. See {@link Bidirectional.Mode} javadoc for more details.
 * Parameters are not shared here - there are 2 separate copies of the wrapped RNN layer, each with separate parameters.
 *
 * Usage: {@code .layer(new Bidirectional(new LSTM.Builder()....build()))}
 *
 * @author Alex Black
 */
public class BidirectionalLayer implements RecurrentLayer {

    private NeuralNetConfiguration conf;
    private Layer fwd;
    private Layer bwd;

    private Bidirectional layerConf;
    private INDArray paramsView;
    private INDArray gradientView;
    private transient Map<String, INDArray> gradientViews;
    private INDArray input;

    //Next 2 variables: used *only* for MUL case (needed for backprop)
    private INDArray outFwd;
    private INDArray outBwd;

    public BidirectionalLayer(@NonNull NeuralNetConfiguration conf, @NonNull Layer fwd, @NonNull Layer bwd, @NonNull INDArray paramsView) {
        this.conf = conf;
        this.fwd = fwd;
        this.bwd = bwd;
        this.layerConf = (Bidirectional) conf.getLayer();
        this.paramsView = paramsView;
    }

    @Override
    public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Cannot RnnTimeStep bidirectional layers");
    }

    @Override
    public Map<String, INDArray> rnnGetPreviousState() {
        throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore "
                + "no previous state is supported");
    }

    @Override
    public void rnnSetPreviousState(Map<String, INDArray> stateMap) {
        throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore "
                + "no previous state is supported");
    }

    @Override
    public void rnnClearPreviousState() {
        //No op
    }

    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public Map<String, INDArray> rnnGetTBPTTState() {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public void rnnSetTBPTTState(Map<String, INDArray> state) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported: cannot use this method (or truncated BPTT) with bidirectional layers");
    }

    @Override
    public void setCacheMode(CacheMode mode) {
        fwd.setCacheMode(mode);
        bwd.setCacheMode(mode);
    }

    @Override
    public double calcRegularizationScore(boolean backpropParamsOnly){
        return fwd.calcRegularizationScore(backpropParamsOnly) + bwd.calcRegularizationScore(backpropParamsOnly);
    }

    @Override
    public Type type() {
        return Type.RECURRENT;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray eFwd;
        INDArray eBwd;

        val n = epsilon.size(1) / 2;
        switch (layerConf.getMode()){
            case ADD:
                eFwd = epsilon;
                eBwd = epsilon;
                break;
            case MUL:
                eFwd = epsilon.dup(epsilon.ordering()).muli(outBwd);
                eBwd = epsilon.dup(epsilon.ordering()).muli(outFwd);
                break;
            case AVERAGE:
                eFwd = epsilon.dup(epsilon.ordering()).muli(0.5);
                eBwd = eFwd;
                break;
            case CONCAT:
                eFwd = epsilon.get(all(), interval(0, n), all());
                eBwd = epsilon.get(all(), interval(n, 2 * n), all());
                break;
            default:
                throw new RuntimeException("Unknown mode: " + layerConf.getMode());
        }

        eBwd = TimeSeriesUtils.reverseTimeSeries(eBwd, workspaceMgr, ArrayType.BP_WORKING_MEM);

        Pair<Gradient, INDArray> g1 = fwd.backpropGradient(eFwd, workspaceMgr);
        Pair<Gradient, INDArray> g2 = bwd.backpropGradient(eBwd, workspaceMgr);

        Gradient g = new DefaultGradient(gradientView);
        for(Map.Entry<String, INDArray> e : g1.getFirst().gradientForVariable().entrySet()){
            g.gradientForVariable().put(BidirectionalParamInitializer.FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for(Map.Entry<String, INDArray> e : g2.getFirst().gradientForVariable().entrySet()){
            g.gradientForVariable().put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue());
        }

        INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2.getRight(), workspaceMgr, ArrayType.BP_WORKING_MEM);
        INDArray epsOut = g1.getRight().addi(g2Reversed);

        return new Pair<>(g, epsOut);
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray out1 = fwd.activate(training, workspaceMgr);
        INDArray out2 = bwd.activate(training, workspaceMgr);

        //Reverse the output time series. Note: when using LastTimeStepLayer, output can be rank 2
        out2 = out2.rank() == 2 ? out2 : TimeSeriesUtils.reverseTimeSeries(out2, workspaceMgr, ArrayType.FF_WORKING_MEM);

        switch (layerConf.getMode()){
            case ADD:
                return out1.addi(out2);
            case MUL:
                //TODO may be more efficient ways than this...
                this.outFwd = out1.detach();
                this.outBwd = out2.detach();
                return workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2);
            case AVERAGE:
                return out1.addi(out2).muli(0.5);
            case CONCAT:
                INDArray ret = Nd4j.concat(1, out1, out2);
                return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
            default:
                throw new RuntimeException("Unknown mode: " + layerConf.getMode());
        }
    }

    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        return activate(training, workspaceMgr);
    }

    @Override
    public Collection<TrainingListener> getListeners() {
        return fwd.getListeners();
    }

    @Override
    public void setListeners(TrainingListener... listeners) {
        fwd.setListeners(listeners);
        bwd.setListeners(listeners);
    }

    @Override
    public void addListeners(TrainingListener... listener) {
        fwd.addListeners(listener);
        bwd.addListeners(listener);
    }

    @Override
    public void fit() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void update(Gradient gradient) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public double score() {
        return fwd.score() + bwd.score();
    }

    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        fwd.computeGradientAndScore(workspaceMgr);
        bwd.computeGradientAndScore(workspaceMgr);
    }

    @Override
    public INDArray params() {
        return paramsView;
    }

    @Override
    public TrainingConfig getConfig() {
        return conf.getLayer();
    }

    @Override
    public long numParams() {
        return fwd.numParams() + bwd.numParams();
    }

    @Override
    public long numParams(boolean backwards) {
        return fwd.numParams(backwards) + bwd.numParams(backwards);
    }

    @Override
    public void setParams(INDArray params) {
        this.paramsView.assign(params);
    }

    @Override
    public void setParamsViewArray(INDArray params) {
        this.paramsView = params;
        val n = params.length() / 2;
        fwd.setParamsViewArray(params.get(interval(0, 0, true), interval(0, n)));
        bwd.setParamsViewArray(params.get(interval(0, 0, true), interval(n, 2 * n)));
    }

    @Override
    public INDArray getGradientsViewArray() {
        return gradientView;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        if (this.paramsView != null && gradients.length() != numParams())
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true)
                    + ", got array of length " + gradients.length());

        this.gradientView = gradients;
        val n = gradients.length() / 2;
        INDArray g1 = gradients.get(interval(0, 0, true), interval(0, n));
        INDArray g2 = gradients.get(interval(0, 0, true), interval(n, 2 * n));
        fwd.setBackpropGradientsViewArray(g1);
        bwd.setBackpropGradientsViewArray(g2);
    }

    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int batchSize() {
        return fwd.batchSize();
    }

    @Override
    public NeuralNetConfiguration conf() {
        return conf;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    @Override
    public INDArray input() {
        return input;
    }

    @Override
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    @Override
    public INDArray getParam(String param) {
        String sub = param.substring(1);
        if(param.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){
            return fwd.getParam(sub);
        } else {
            return bwd.getParam(sub);
        }
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        Map<String, INDArray> m = new LinkedHashMap<>();
        for(Map.Entry<String, INDArray> e : fwd.paramTable(backpropParamsOnly).entrySet()){
            m.put(BidirectionalParamInitializer.FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for(Map.Entry<String, INDArray> e : bwd.paramTable(backpropParamsOnly).entrySet()){
            m.put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue());
        }
        return m;
    }

    @Override
    public boolean updaterDivideByMinibatch(String paramName) {
        String sub = paramName.substring(1);
        if(paramName.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){
            return fwd.updaterDivideByMinibatch(paramName);
        } else {
            return bwd.updaterDivideByMinibatch(paramName);
        }
    }

    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        for(Map.Entry<String, INDArray> e : paramTable.entrySet()){
            setParam(e.getKey(), e.getValue());
        }
    }

    @Override
    public void setParam(String key, INDArray val) {
        String sub = key.substring(1);
        if(key.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){
            fwd.setParam(sub, val);
        } else {
            bwd.setParam(sub, val);
        }
    }

    @Override
    public void clear() {
        fwd.clear();
        bwd.clear();
        input = null;
        outFwd = null;
        outBwd = null;
    }

    @Override
    public void applyConstraints(int iteration, int epoch) {
        fwd.applyConstraints(iteration, epoch);
        bwd.applyConstraints(iteration, epoch);
    }

    @Override
    public void init() {
        //No op
    }

    @Override
    public void setListeners(Collection<TrainingListener> listeners) {
        fwd.setListeners(listeners);
        bwd.setListeners(listeners);
    }

    @Override
    public void setIndex(int index) {
        fwd.setIndex(index);
        bwd.setIndex(index);
    }

    @Override
    public int getIndex() {
        return fwd.getIndex();
    }

    @Override
    public int getIterationCount() {
        return fwd.getIterationCount();
    }

    @Override
    public int getEpochCount() {
        return fwd.getEpochCount();
    }

    @Override
    public void setIterationCount(int iterationCount) {
        fwd.setIterationCount(iterationCount);
        bwd.setIterationCount(iterationCount);
    }

    @Override
    public void setEpochCount(int epochCount) {
        fwd.setEpochCount(epochCount);
        bwd.setEpochCount(epochCount);
    }

    @Override
    public void setInput(INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.input = input;
        fwd.setInput(input, layerWorkspaceMgr);

        INDArray reversed;
        if(!input.isAttached()){
            try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                reversed = TimeSeriesUtils.reverseTimeSeries(input);
            }
        } else {
            MemoryWorkspace ws = input.data().getParentWorkspace();
            try(MemoryWorkspace ws2 = ws.notifyScopeBorrowed()){
                //Put the reversed input into the same workspace as the original input
                reversed = TimeSeriesUtils.reverseTimeSeries(input);
            }
        }

        bwd.setInput(reversed, layerWorkspaceMgr);
    }

    @Override
    public void setInputMiniBatchSize(int size) {
        fwd.setInputMiniBatchSize(size);
        bwd.setInputMiniBatchSize(size);
    }

    @Override
    public int getInputMiniBatchSize() {
        return fwd.getInputMiniBatchSize();
    }

    @Override
    public void setMaskArray(INDArray maskArray) {
        fwd.setMaskArray(maskArray);
        bwd.setMaskArray(TimeSeriesUtils.reverseTimeSeriesMask(maskArray, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT));   //TODO
    }

    @Override
    public INDArray getMaskArray() {
        return fwd.getMaskArray();
    }

    @Override
    public boolean isPretrainLayer() {
        return fwd.isPretrainLayer();
    }

    @Override
    public void clearNoiseWeightParams() {
        fwd.clearNoiseWeightParams();
        bwd.clearNoiseWeightParams();
    }

    @Override
    public void allowInputModification(boolean allow) {
        fwd.allowInputModification(allow);
        bwd.allowInputModification(true);   //Always allow: always safe due to reverse op
    }

    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        Pair<INDArray, MaskState> ret = fwd.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
        bwd.feedForwardMaskArray(TimeSeriesUtils.reverseTimeSeriesMask(maskArray, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT),   //TODO
                currentMaskState, minibatchSize);
        return ret;
    }

    @Override
    public LayerHelper getHelper() {
        LayerHelper f = fwd.getHelper();
        LayerHelper b = bwd.getHelper();
        if(f != null || b != null){
            return new BidirectionalHelper(f, b);
        }
        return null;
    }

    @AllArgsConstructor
    private static class BidirectionalHelper implements LayerHelper {
        private final LayerHelper helperFwd;
        private final LayerHelper helperBwd;

        @Override
        public Map<String, Long> helperMemoryUse() {
            Map<String, Long> fwd = (helperFwd != null ? helperFwd.helperMemoryUse() : null);
            Map<String, Long> bwd = (helperBwd != null ? helperBwd.helperMemoryUse() : null);

            Set<String> keys = new HashSet<>();
            if(fwd != null)
                keys.addAll(fwd.keySet());
            if(bwd != null)
                keys.addAll(bwd.keySet());

            Map<String, Long> ret = new HashMap<>();
            for(String s : keys){
                long sum = 0;
                if(fwd != null && fwd.containsKey(s)){
                    sum += fwd.get(s);
                }
                if(bwd != null && bwd.containsKey(s)){
                    sum += bwd.get(s);
                }
                ret.put(s, sum);
            }
            return ret;
        }
    }
}
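
The class javadoc above shows the intended usage pattern: the Bidirectional configuration class wraps a uni-directional RNN layer, and this BidirectionalLayer implementation then runs two copies of it (forward and time-reversed) and combines their activations according to the configured mode. The following is a minimal sketch of such a configuration, not taken from this file; the layer sizes, updater, loss function, and class/method names outside the imports listed here are illustrative assumptions. Note that with Mode.CONCAT the wrapper's output size is twice the wrapped layer's nOut, so the next layer's nIn is doubled accordingly.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

//Illustrative example class (hypothetical, not part of the library)
public class BidirectionalUsageExample {
    public static void main(String[] args) {
        //Wrap an LSTM in a Bidirectional layer; CONCAT joins forward/backward activations along the feature axis
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))                                    //assumed hyperparameter
                .list()
                .layer(new Bidirectional(Bidirectional.Mode.CONCAT,
                        new LSTM.Builder().nIn(10).nOut(20).activation(Activation.TANH).build()))
                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(2 * 20)                                        //2x wrapped nOut, due to CONCAT mode
                        .nOut(5).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
    }
}

Under the ADD, MUL, or AVERAGE modes the output size equals the wrapped layer's nOut, so the doubling of nIn shown above applies only to CONCAT.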



