All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.layers.recurrent.SimpleRnn Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.layers.recurrent;

import lombok.val;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Quad;

import static org.nd4j.linalg.indexing.NDArrayIndex.*;

public class SimpleRnn extends BaseRecurrentLayer {
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";


    public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    @Override
    public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        INDArray last = stateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = activateHelper(last, false, false, workspaceMgr).getFirst();
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
            stateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
        }
        return out;
    }

    @Override
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        INDArray last = tBpttStateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = activateHelper(last, training, false, workspaceMgr).getFirst();
        if(storeLastForTBPTT){
            try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
                tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, out.get(all(), all(), point(out.size(2)-1)).dup());
            }
        }
        return out;
    }

    @Override
    public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return tbpttBackpropGradient(epsilon, -1, workspaceMgr);
    }

    @Override
    public Pair tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        if(epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon))
            epsilon = epsilon.dup('f');

        val nOut = layerConf().getNOut();

        INDArray input = this.input.castTo(dataType);   //No-op if correct type
        input = permuteIfNWC(input);

        //First: Do forward pass to get gate activations and Zs
        Quad p = activateHelper(null, true, true, workspaceMgr);

        INDArray w = getParamWithNoise(SimpleRnnParamInitializer.WEIGHT_KEY, true, workspaceMgr);
        INDArray rw = getParamWithNoise(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY, true, workspaceMgr);
        INDArray b = getParamWithNoise(SimpleRnnParamInitializer.BIAS_KEY, true, workspaceMgr);
        INDArray g = (hasLayerNorm() ? getParamWithNoise(SimpleRnnParamInitializer.GAIN_KEY, true, workspaceMgr) : null);
        INDArray gx = (g != null ? g.get(interval(0, 0, true), interval(0, nOut)) : null);
        INDArray gr = (g != null ? g.get(interval(0, 0, true), interval(nOut, nOut * 2)) : null);

        INDArray wg = gradientViews.get(SimpleRnnParamInitializer.WEIGHT_KEY);
        INDArray rwg = gradientViews.get(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY);
        INDArray bg = gradientViews.get(SimpleRnnParamInitializer.BIAS_KEY);
        INDArray gg = (hasLayerNorm() ? gradientViews.get(SimpleRnnParamInitializer.GAIN_KEY) : null);
        INDArray gxg = (gg != null ? gg.get(interval(0, 0, true), interval(0, nOut)) : null);
        INDArray grg = (gg != null ? gg.get(interval(0, 0, true), interval(nOut, nOut * 2)) : null);

        gradientsFlattened.assign(0);

        IActivation a = layerConf().getActivationFn();

        val tsLength = input.size(2);

        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'f');

        INDArray dldzNext = null;
        long end;
        if(tbpttBackLength > 0){
            end = Math.max(0, tsLength-tbpttBackLength);
        } else {
            end = 0;
        }
        epsilon = permuteIfNWC(epsilon);
        for( long i = tsLength - 1; i >= end; i--) {
            INDArray dldaCurrent = epsilon.get(all(), all(), point(i)).dup();
            INDArray aCurrent = p.getFirst().get(all(), all(), point(i));
            INDArray zCurrent = p.getSecond().get(all(), all(), point(i));
            INDArray nCurrent = (hasLayerNorm() ? p.getThird().get(all(), all(), point(i)) : null);
            INDArray rCurrent = (hasLayerNorm() ? p.getFourth().get(all(), all(), point(i)) : null);
            INDArray inCurrent = input.get(all(), all(), point(i));
            INDArray epsOutCurrent = epsOut.get(all(), all(), point(i));

            if(dldzNext != null){
                //Backprop the component of dL/da (for current time step) from the recurrent connections
                Nd4j.gemm(dldzNext, rw, dldaCurrent, false, true, 1.0, 1.0);

                //Recurrent weight gradients:
                Nd4j.gemm(aCurrent, dldzNext, rwg, true, false, 1.0, 1.0);
            }
            INDArray dldzCurrent = a.backprop(zCurrent.dup(), dldaCurrent).getFirst();

            //Handle masking
            INDArray maskCol = null;
            if( maskArray != null) {
                //Mask array: shape [minibatch, tsLength]
                //If mask array is present (for example, with bidirectional RNN) -> need to zero out these errors to
                // avoid using errors from a masked time step to calculate the parameter gradients
                maskCol = maskArray.getColumn(i, true).castTo(dataType);
                dldzCurrent.muliColumnVector(maskCol);
            }

            INDArray dldnCurrent;
            if(hasLayerNorm()) {
                dldnCurrent = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), gxg.shape());
                INDArray bgCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, bg.dataType(), bg.shape());
                Nd4j.getExecutioner().exec(new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, ggCur, bgCur, true, 1));
                gxg.addi(ggCur);
                bg.addi(bgCur);
            }else{
                dldnCurrent = dldzCurrent;
                //Bias gradients
                bg.addi(dldzCurrent.sum(0));
            }

            //weight gradients:
            Nd4j.gemm(inCurrent, dldnCurrent, wg, true, false, 1.0, 1.0);

            //Epsilon out to layer below (i.e., dL/dIn)
            Nd4j.gemm(dldnCurrent, w, epsOutCurrent, false, true, 1.0, 0.0);

            // propagate epsilon to previous iteration
            if(hasLayerNorm() && i > end){
                dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape());
                Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, true, 1));
                grg.addi(ggCur);
            }else{
                dldzNext = dldzCurrent;
            }

            if( maskArray != null){
                //If mask array is present: Also need to zero out errors to avoid sending anything but 0s to layer below for masked steps
                epsOutCurrent.muliColumnVector(maskCol);
            }
        }

        weightNoiseParams.clear();

        Gradient grad = new DefaultGradient(gradientsFlattened);
        grad.gradientForVariable().put(SimpleRnnParamInitializer.WEIGHT_KEY, wg);
        grad.gradientForVariable().put(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY, rwg);
        grad.gradientForVariable().put(SimpleRnnParamInitializer.BIAS_KEY, bg);
        if(hasLayerNorm()){
            grad.gradientForVariable().put(SimpleRnnParamInitializer.GAIN_KEY, gg);
        }

        epsOut = backpropDropOutIfPresent(epsOut);
        epsOut = permuteIfNWC(epsOut);
        return new Pair<>(grad, epsOut);
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
        return activateHelper(null, training, false, workspaceMgr).getFirst();
    }

    private Quad activateHelper(INDArray prevStepOut, boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr){
        assertInputSet(false);
        Preconditions.checkState(input.rank() == 3,
                "3D input expected to RNN layer expected, got " + input.rank());
        Preconditions.checkState(prevStepOut == null || prevStepOut.size(0) == input.size(0),
                "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches?" +
                        " Previous step output = [batch, nIn] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape", prevStepOut, input);

        applyDropOutIfNecessary(training, workspaceMgr);

        INDArray input = this.input.castTo(dataType);    //No-op if correct type
        input = permuteIfNWC(input);
        val m = input.size(0);
        val tsLength = input.size(2);
        val nOut = layerConf().getNOut();

        INDArray w = getParamWithNoise(SimpleRnnParamInitializer.WEIGHT_KEY, training, workspaceMgr);
        INDArray rw = getParamWithNoise(SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr);
        INDArray b = getParamWithNoise(SimpleRnnParamInitializer.BIAS_KEY, training, workspaceMgr);
        INDArray g = (hasLayerNorm() ? getParamWithNoise(SimpleRnnParamInitializer.GAIN_KEY, training, workspaceMgr) : null);
        INDArray gx = (g != null ? g.get(interval(0, 0, true), interval(0, nOut)) : null);
        INDArray gr = (g != null ? g.get(interval(0, 0, true), interval(nOut, nOut * 2)) : null);

        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, w.dataType(), new long[]{m, nOut, tsLength}, 'f');
        INDArray outZ = (forBackprop ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape()) : null);
        INDArray outPreNorm = (forBackprop && hasLayerNorm() ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null);
        INDArray recPreNorm = (forBackprop && hasLayerNorm() ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null);

        if(input.ordering() != 'f' || Shape.strideDescendingCAscendingF(input))
            input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f');

        //TODO implement 'mmul across time' optimization

        if(!hasLayerNorm()) {
            //Minor performance optimization: do the "add bias" first:
            Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, b, out, 1));
        }

        IActivation a = layerConf().getActivationFn();

        for( int i = 0; i < tsLength; i++) {
            //out = activationFn(in*w + last*rw + bias)
            INDArray currOut = out.get(all(), all(), point(i)); //F order
            INDArray currIn = input.get(all(), all(), point(i));
            if(hasLayerNorm()){
                INDArray currOutPreNorm = (forBackprop ? outPreNorm : out).get(all(), all(), point(i));
                Nd4j.gemm(currIn, w, currOutPreNorm, false, false, 1.0, 0.0);
                Nd4j.getExecutioner().exec(new LayerNorm(currOutPreNorm, gx, b, currOut, true, 1));
            }else{
                Nd4j.gemm(currIn, w, currOut, false, false, 1.0, 1.0);  //beta = 1.0 to keep previous contents (bias)
            }

            if(i > 0 || prevStepOut != null) {
                if(hasLayerNorm()){
                    INDArray currRecPreNorm = forBackprop ? recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');;
                    Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0);
                    INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
                    Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, true, 1));
                    currOut.addi(recNorm);
                }else {
                    Nd4j.gemm(prevStepOut, rw, currOut, false, false, 1.0, 1.0);    //beta = 1.0 to keep previous contents
                }
            }

            if(forBackprop){
                outZ.get(all(), all(), point(i)).assign(currOut);
            }

            a.getActivation(currOut, training);

            if( maskArray != null){
                //If mask array is present: Also need to zero out errors to avoid sending anything but 0s to layer below for masked steps
                INDArray maskCol = maskArray.getColumn(i, true).castTo(dataType);
                currOut.muliColumnVector(maskCol);
            }

            prevStepOut = currOut;
        }

        //Apply mask, if present:
        if(maskArray != null) {
            //Mask should be shape [minibatch, tsLength]
            INDArray mask = maskArray.castTo(dataType);
            Nd4j.getExecutioner().exec(new BroadcastMulOp(out, mask, out, 0, 2));
            if(forBackprop){
                Nd4j.getExecutioner().exec(new BroadcastMulOp(outZ, mask, outZ, 0, 2));
            }
        }
        if (!forBackprop) {
            out = permuteIfNWC(out);
            outZ = permuteIfNWC(outZ);
            outPreNorm = permuteIfNWC(outPreNorm);
            recPreNorm = permuteIfNWC(recPreNorm);
        }
        return new Quad<>(out, outZ, outPreNorm, recPreNorm);
    }

    @Override
    public boolean hasLayerNorm(){
        return layerConf().hasLayerNorm();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy