
org.deeplearning4j.nn.api.layers.RecurrentLayer

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.api.layers;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;

import java.util.Map;

public interface RecurrentLayer extends Layer {

    /**
     * Do one or more time steps using the previous time step state stored in stateMap.
     * Can be used to efficiently do the forward pass one or n steps at a time (instead of always
     * doing the forward pass from t=0).
     * If stateMap is empty, default initialization (usually zeros) is used.
     * Implementations also update stateMap at the end of this method.
     *
     * @param input Input to this layer
     * @return activations
     */
    INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr);

    /**
     * Returns a shallow copy of the RNN stateMap (which contains the stored history for use in
     * methods such as rnnTimeStep).
     */
    Map<String, INDArray> rnnGetPreviousState();

    /**
     * Set the stateMap (stored history). Values set using this method will be used
     * in the next call to rnnTimeStep().
     */
    void rnnSetPreviousState(Map<String, INDArray> stateMap);

    /**
     * Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
     */
    void rnnClearPreviousState();

    /**
     * Similar to rnnTimeStep, this method is used for activations using the state
     * stored in the stateMap as the initialization. However, unlike rnnTimeStep this
     * method does not alter the stateMap; therefore, unlike rnnTimeStep, multiple calls to
     * this method (with identical input) will:
     * (a) result in the same output, and
     * (b) leave the state maps (both stateMap and tBpttStateMap) in an identical state.
     *
     * @param input             Layer input
     * @param training          if true: training. Otherwise: test
     * @param storeLastForTBPTT if true: store the final state in tBpttStateMap for use in truncated BPTT training
     * @return Layer activations
     */
    INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT,
                                         LayerWorkspaceMgr workspaceMgr);

    /**
     * Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
     * The TBPTT state is used to store intermediate activations/state between parameter updates
     * when doing TBPTT learning.
     *
     * @return State for the RNN layer
     */
    Map<String, INDArray> rnnGetTBPTTState();

    /**
     * Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
     * The TBPTT state is used to store intermediate activations/state between parameter updates
     * when doing TBPTT learning.
     *
     * @param state TBPTT state to set
     */
    void rnnSetTBPTTState(Map<String, INDArray> state);

    /**
     * Truncated BPTT equivalent of Layer.backpropGradient(). The primary difference is that the
     * forward pass in the context of truncated BPTT uses the stored state as initialization,
     * rather than the zero initialization used for standard BPTT.
     *
     * @param epsilon         Error signal (gradient with respect to this layer's output)
     * @param tbpttBackLength Number of time steps to backpropagate through
     * @return Gradient for this layer, plus the epsilon (error signal) for the layer below
     */
    Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength,
                                                   LayerWorkspaceMgr workspaceMgr);
}
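
Usage note (not part of the source file): these methods are rarely called on a layer directly. MultiLayerNetwork exposes matching rnnTimeStep / rnnGetPreviousState / rnnSetPreviousState / rnnClearPreviousState methods that delegate to its recurrent layers. A minimal streaming-inference sketch, assuming the standard deeplearning4j-nn builder API; the architecture and sizes are illustrative, not taken from this file:

import java.util.Map;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class RnnTimeStepSketch {
    public static void main(String[] args) {
        int nIn = 10, nHidden = 20, nOut = 5;   // illustrative sizes

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new LSTM.Builder().nIn(nIn).nOut(nHidden)
                        .activation(Activation.TANH).build())
                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(nHidden).nOut(nOut).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // rnnTimeStep uses (and updates) the stored stateMap, so consecutive
        // calls continue the sequence instead of restarting from t=0.
        // Input shape: [minibatch, nIn, timeSteps]; here, one step at a time.
        INDArray out1 = net.rnnTimeStep(Nd4j.rand(new int[]{1, nIn, 1}));
        INDArray out2 = net.rnnTimeStep(Nd4j.rand(new int[]{1, nIn, 1})); // continues from the first call's state

        // Snapshot and restore the stored state of layer 0 (a RecurrentLayer):
        Map<String, INDArray> state = net.rnnGetPreviousState(0);
        net.rnnSetPreviousState(0, state);

        // Clear stored state before starting an unrelated sequence:
        net.rnnClearPreviousState();
    }
}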




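The tBpttStateMap-related methods above (rnnActivateUsingStoredState with storeLastForTBPTT, rnnGetTBPTTState, rnnSetTBPTTState) are driven by the framework during fitting rather than called from user code. A configuration sketch for enabling truncated BPTT, again assuming the standard builder API, with illustrative segment lengths:

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TbpttConfigSketch {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new LSTM.Builder().nIn(10).nOut(20)
                        .activation(Activation.TANH).build())
                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(20).nOut(5).build())
                // Train with truncated BPTT: gradients are backpropagated at most
                // 50 steps per segment; between segments, the tBpttStateMap carries
                // the forward-pass state forward (see storeLastForTBPTT above).
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(50)
                .tBPTTBackwardLength(50)
                .build();
        // Pass conf to new MultiLayerNetwork(conf) and fit as usual.
    }
}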