All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

public abstract class BaseRecurrentLayer extends BaseLayer
        implements RecurrentLayer {

    /**
     * stateMap stores the INDArrays needed to do an rnnTimeStep() forward pass.
     *
     * <p>NOTE(review): declared as a raw {@code Map}; entries are presumably String keys mapped
     * to INDArray values — confirm against the RecurrentLayer interface before parameterizing.
     */
    protected Map stateMap = new ConcurrentHashMap<>();

    /**
     * State map for use specifically in truncated BPTT training. Whereas {@link #stateMap}
     * contains the state from which the forward pass is initialized, tBpttStateMap contains the
     * state at the end of the most recent truncated BPTT pass.
     */
    protected Map tBpttStateMap = new ConcurrentHashMap<>();

    public BaseRecurrentLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public BaseRecurrentLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Returns a shallow copy of the stateMap.
     *
     * @return a new {@link HashMap} holding the current entries of {@link #stateMap}; adding or
     *     removing entries in the returned map does not affect this layer's state
     */
    @Override
    public Map rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    /**
     * Set the state map. Values set using this method will be used in the next call to
     * rnnTimeStep(). Existing entries are discarded and replaced by the entries of the argument.
     *
     * @param stateMap entries to copy into this layer's rnnTimeStep() state; must not be null
     * @throws NullPointerException if {@code stateMap} is null; the current state is left
     *     untouched in that case
     */
    @Override
    public void rnnSetPreviousState(Map stateMap) {
        replaceContents(this.stateMap, stateMap, "stateMap");
    }

    /**
     * Reset/clear the stateMap for rnnTimeStep() and tBpttStateMap for
     * rnnActivateUsingStoredState().
     */
    @Override
    public void rnnClearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }

    /**
     * Returns a shallow copy of the truncated-BPTT state map.
     *
     * @return a new {@link HashMap} holding the current entries of {@link #tBpttStateMap}
     */
    @Override
    public Map rnnGetTBPTTState() {
        return new HashMap<>(tBpttStateMap);
    }

    /**
     * Replace the truncated-BPTT state map with the entries of {@code state}.
     *
     * @param state entries to copy into {@link #tBpttStateMap}; must not be null
     * @throws NullPointerException if {@code state} is null; the current state is left untouched
     *     in that case
     */
    @Override
    public void rnnSetTBPTTState(Map state) {
        replaceContents(tBpttStateMap, state, "state");
    }

    /**
     * Replaces all entries of {@code target} with those of {@code source}.
     *
     * <p>The argument is validated before anything is cleared, so a bad call cannot wipe the
     * existing state. Passing the target map itself is a no-op: clearing first would otherwise
     * empty the very map being copied from. The clear()/putAll() pair is not atomic, so a
     * concurrent reader may briefly observe an empty or partially filled map. If the target
     * rejects an entry (a ConcurrentHashMap rejects null keys and values), putAll() can still
     * throw after clear().
     */
    @SuppressWarnings("unchecked") // raw Map fields/parameters are kept for interface compatibility
    private static void replaceContents(Map target, Map source, String argName) {
        Objects.requireNonNull(source, argName);
        if (source == target) {
            return;
        }
        target.clear();
        target.putAll(source);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy