
org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer Maven / Gradle / Ivy

/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.util.Dropout;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;

import java.util.Arrays;

/**
 * Recurrent Neural Network Output Layer.
 * Handles calculation of gradients etc. for various objective functions.
 * Functionally the same as OutputLayer, but handles output and label reshaping
 * automatically.
 * Input and output activations are the same as for other RNN layers: 3 dimensions, with shape
 * [miniBatchSize,nIn,timeSeriesLength] and [miniBatchSize,nOut,timeSeriesLength] respectively.
 *
 * @author Alex Black
 * @see BaseOutputLayer
 * @see org.deeplearning4j.nn.layers.OutputLayer
 */
public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.RnnOutputLayer> {

    public RnnOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public RnnOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (input.rank() != 3)
            throw new UnsupportedOperationException("Input is not rank 3");
        INDArray inputTemp = input;
        this.input = TimeSeriesUtils.reshape3dTo2d(input);
        Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon);
        this.input = inputTemp;
        INDArray epsilon2d = gradAndEpsilonNext.getSecond();
        INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, input.size(0));
        return new Pair<>(gradAndEpsilonNext.getFirst(), epsilon3d);
    }

    /** {@inheritDoc} */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        if (examples.rank() == 3)
            examples = TimeSeriesUtils.reshape3dTo2d(examples);
        if (labels.rank() == 3)
            labels = TimeSeriesUtils.reshape3dTo2d(labels);
        return super.f1Score(examples, labels);
    }

    public INDArray getInput() {
        return input;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        setInput(x);
        return TimeSeriesUtils.reshape2dTo3d(preOutput2d(training), input.size(0));
    }

    @Override
    protected INDArray preOutput2d(boolean training) {
        if (input.rank() == 3) {
            //Case when called from RnnOutputLayer
            INDArray inputTemp = input;
            input = TimeSeriesUtils.reshape3dTo2d(input);
            INDArray out = super.preOutput(input, training);
            this.input = inputTemp;
            return out;
        } else {
            //Case when called from BaseOutputLayer
            return super.preOutput(input, training);
        }
    }

    @Override
    protected INDArray getLabels2d() {
        if (labels.rank() == 3)
            return TimeSeriesUtils.reshape3dTo2d(labels);
        return labels;
    }

    @Override
    public INDArray output(INDArray input) {
        if (input.rank() != 3)
            throw new IllegalArgumentException("Input must be rank 3 (is: " + input.rank() + ")");
        //Returns 3d activations from 3d input
        setInput(input);
        return output(false);
    }

    @Override
    public INDArray output(boolean training) {
        //Assume that input is 3d
        if (input.rank() != 3)
            throw new IllegalArgumentException("input must be rank 3");
        INDArray preOutput2d = preOutput2d(training);

        if (conf.getLayer().getActivationFn() instanceof ActivationSoftmax) {
            INDArray out2d = Nd4j.getExecutioner().execAndReturn(new SoftMax(preOutput2d));
            if (maskArray != null) {
                out2d.muliColumnVector(maskArray);
            }
            return TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0));
        }

        if (training)
            applyDropOutIfNecessary(training);
        INDArray origInput = input;
        this.input = TimeSeriesUtils.reshape3dTo2d(input);
        INDArray out = super.activate(true);
        this.input = origInput;
        if (maskArray != null) {
            out.muliColumnVector(maskArray);
        }
        return TimeSeriesUtils.reshape2dTo3d(out, input.size(0));
    }

    @Override
    public INDArray activate(boolean training) {
        if (input.rank() != 3)
            throw new UnsupportedOperationException("Input must be rank 3");
        INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
        INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);

        if (conf.isUseDropConnect() && training) {
            W = Dropout.applyDropConnect(this, DefaultParamInitializer.WEIGHT_KEY);
        }

        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input);
        INDArray act2d = conf.getLayer().getActivationFn()
                        .getActivation(input2d.mmul(W).addiRowVector(b), training);
        if (maskArray != null) {
            act2d.muliColumnVector(maskArray);
        }
        return TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0));
    }

    @Override
    public void setMaskArray(INDArray maskArray) {
        if (maskArray != null) {
            //Two possible cases:
            //(a) per time step masking - rank 2 mask array -> reshape to rank 1 (column vector)
            //(b) per output masking - rank 3 mask array -> reshape to rank 2 ([minibatch*timeSeriesLength, nOut])
            if (maskArray.rank() == 2) {
                this.maskArray = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(maskArray);
            } else if (maskArray.rank() == 3) {
                this.maskArray = TimeSeriesUtils.reshape3dTo2d(maskArray);
            } else {
                throw new UnsupportedOperationException("Invalid mask array: must be rank 2 or 3 (got: rank "
                                + maskArray.rank() + ", shape = " + Arrays.toString(maskArray.shape()) + ")");
            }
        } else {
            this.maskArray = null;
        }
    }

    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState,
                    int minibatchSize) {
        //If the *input* mask array is present and active, we should use it to mask the output
        if (maskArray != null && currentMaskState == MaskState.Active) {
            this.inputMaskArray = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(maskArray);
            this.inputMaskArrayState = currentMaskState;
        } else {
            this.inputMaskArray = null;
            this.inputMaskArrayState = null;
        }

        return null; //Last layer in network
    }

    /**
     * Compute the score for each example individually, after labels and input have been set.
     *
     * @param fullNetworkL1 L1 regularization term for the entire network (or, 0.0 to not include regularization)
     * @param fullNetworkL2 L2 regularization term for the entire network (or, 0.0 to not include regularization)
     * @return A column INDArray of shape [numExamples,1], where entry i is the score of the ith example
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2) {
        //For RNN: need to sum up the score over each time step before returning.
        if (input == null || labels == null)
            throw new IllegalStateException("Cannot calculate score without input and labels");
        INDArray preOut = preOutput2d(false);

        ILossFunction lossFunction = layerConf().getLossFn();
        INDArray scoreArray = lossFunction.computeScoreArray(getLabels2d(), preOut,
                        layerConf().getActivationFn(), maskArray);
        //scoreArray: shape [minibatch*timeSeriesLength, 1]
        //Reshape it to [minibatch, timeSeriesLength] then sum over time steps
        INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, input.size(0));
        INDArray summedScores = scoreArrayTs.sum(1);

        double l1l2 = fullNetworkL1 + fullNetworkL2;
        if (l1l2 != 0.0) {
            summedScores.addi(l1l2);
        }

        return summedScores;
    }
}
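For orientation, below is a minimal usage sketch of how this layer is typically reached through the public DL4J 0.x API: it is configured via org.deeplearning4j.nn.conf.layers.RnnOutputLayer (the configuration counterpart of the implementation class above) as the last layer of a recurrent network, and it consumes and produces the rank-3 activations described in the javadoc. The layer sizes, the GravesLSTM hidden layer, the MCXENT loss, and the random input data are illustrative assumptions, not taken from the source above.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class RnnOutputLayerSketch {
    public static void main(String[] args) {
        int nIn = 10, nHidden = 20, nOut = 5;         //Illustrative sizes (assumptions)
        int miniBatchSize = 3, timeSeriesLength = 7;  //Illustrative sizes (assumptions)

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(nHidden)
                        .activation("tanh").build())
                .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder(LossFunction.MCXENT)
                        .activation("softmax") //Takes the ActivationSoftmax branch in output(boolean) above
                        .nIn(nHidden).nOut(nOut).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //Rank-3 input of shape [miniBatchSize, nIn, timeSeriesLength], as the javadoc requires
        INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength});
        INDArray output = net.output(input); //Also rank 3: [miniBatchSize, nOut, timeSeriesLength]
        System.out.println(java.util.Arrays.toString(output.shape()));
    }
}

Because the configured activation is softmax, net.output(input) exercises the ActivationSoftmax branch of output(boolean) above; the returned activations have shape [miniBatchSize, nOut, timeSeriesLength].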



