Please wait. This may take a few minutes...
Downloading a project requires significant resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs Maven / Gradle / Ivy
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
import java.util.Arrays;
import java.util.List;
import lombok.AccessLevel;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
/**
 * The outputs of an LSTM layer ({@link LSTMLayer}).
 *
 * <p>Wraps the seven output variables produced by the op, in the order
 * {@code i, c, f, o, z, h, y}. It also offers helpers that slice out the final
 * time step according to the configured {@link RnnDataFormat}.
 */
@Getter
public class LSTMLayerOutputs {

    /** Layout of the time/batch/unit axes for every sequence output below. */
    private final RnnDataFormat dataFormat;

    /**
     * Output - input modulation gate activations.
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable i;

    /**
     * Activations, cell state (pre tanh).
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable c;

    /**
     * Output - forget gate activations.
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable f;

    /**
     * Output - output gate activations.
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable o;

    /**
     * Output - input gate activations.
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable z;

    /**
     * Cell state, post tanh.
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable h;

    /**
     * Current cell output.
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    private final SDVariable y;

    /**
     * Wraps the raw op outputs.
     *
     * @param outputs    exactly 7 variables, in the order {@code i, c, f, o, z, h, y}
     * @param dataFormat layout of the sequence outputs; needed by
     *                   {@link #getLastOutput()} and {@link #getLastState()}
     * @throws IllegalArgumentException if {@code outputs} is null or does not have 7 entries
     */
    public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){
        // Check for null first: the length check below dereferences the array eagerly.
        Preconditions.checkArgument(outputs != null, "LSTM layer outputs array must not be null");
        Preconditions.checkArgument(outputs.length == 7,
                "Must have 7 LSTM layer outputs, got %s", outputs.length);
        i = outputs[0];
        c = outputs[1];
        f = outputs[2];
        o = outputs[3];
        z = outputs[4];
        h = outputs[5];
        y = outputs[6];
        this.dataFormat = dataFormat;
    }

    /**
     * Get all outputs returned by the cell.
     *
     * @return a fixed-size list in the order {@code i, c, f, o, z, h, y}
     */
    public List<SDVariable> getAllOutputs(){
        return Arrays.asList(i, c, f, o, z, h, y);
    }

    /**
     * Get y, the output of the cell for all time steps.
     *
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    public SDVariable getOutput(){
        return y;
    }

    /**
     * Get c, the cell's state for all time steps.
     *
     * Shape depends on data format (in layer config):
     * TNS -> [timeSteps, batchSize, numUnits]
     * NST -> [batchSize, numUnits, timeSteps]
     * NTS -> [batchSize, timeSteps, numUnits]
     */
    public SDVariable getState(){
        return c;
    }

    /** Lazily created slice of {@link #y} at the last time step; see {@link #getLastOutput()}. */
    private SDVariable lastOutput = null;

    /**
     * Get y, the output of the cell, for the last time step.
     *
     * Has shape [batchSize, numUnits]. The slice is created once and then reused.
     *
     * @throws IllegalStateException if the data format is null or not supported
     */
    public SDVariable getLastOutput(){
        if(lastOutput == null)
            lastOutput = lastTimeStep(getOutput());
        return lastOutput;
    }

    /** Lazily created slice of {@link #c} at the last time step; see {@link #getLastState()}. */
    private SDVariable lastState = null;

    /**
     * Get c, the state of the cell, for the last time step.
     *
     * Has shape [batchSize, numUnits]. The slice is created once and then reused.
     *
     * @throws IllegalStateException if the data format is null or not supported
     */
    public SDVariable getLastState(){
        if(lastState == null)
            lastState = lastTimeStep(getState());
        return lastState;
    }

    /**
     * Selects the final time step of a rank-3 sequence variable according to {@link #dataFormat}.
     * The time axis is indexed with {@code SDIndex.point(-1)}, which drops that axis.
     *
     * @param sequence a sequence output laid out per {@link #dataFormat}
     * @return the slice at the last time step
     * @throws IllegalStateException if the data format is null or not handled here
     */
    private SDVariable lastTimeStep(SDVariable sequence){
        // Explicit errors instead of an NPE from switch(null), and instead of
        // silently returning (and caching) null for an unhandled format.
        if(dataFormat == null)
            throw new IllegalStateException("Cannot select last time step: RNN data format is null");
        switch (dataFormat){
            case TNS:
                return sequence.get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
            case NST:
                return sequence.get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
            case NTS:
                return sequence.get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
            default:
                throw new IllegalStateException("Unsupported RNN data format: " + dataFormat);
        }
    }
}