All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.ops.SDRNN Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.autodiff.samediff.ops;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;

import java.util.Arrays;
import java.util.List;

/**
 * SameDiff Recurrent Neural Network operations.
 * Accessible via {@link SameDiff#rnn()}.
 * See also {@link SDNN} (accessible via {@link SameDiff#nn()}) for general neural network ops.
 * See also {@link SDCNN} (accessible via {@link SameDiff#cnn()}) for convolutional neural network ops.
 *
 * @author Alex Black
 */
public class SDRNN extends SDOps {

    public SDRNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * The gru cell.
     * Creates a {@link GRUCell} op from the configuration and returns all of its output variables.
     *
     * @param configuration the configuration to use
     * @return the output variables of the GRU cell op, as a fixed-size list backed by the op's output array
     */
    public List<SDVariable> gru(GRUCellConfiguration configuration) {
        GRUCell c = new GRUCell(sd, configuration);
        return Arrays.asList(c.outputVariables());
    }

    /**
     * The gru cell.
     * Creates a {@link GRUCell} op from the configuration and returns all of its output variables,
     * requested with the given base name.
     *
     * @param baseName      the base name for the gru cell
     * @param configuration the configuration to use
     * @return the output variables of the GRU cell op, as a fixed-size list backed by the op's output array
     */
    public List<SDVariable> gru(String baseName, GRUCellConfiguration configuration) {
        GRUCell c = new GRUCell(sd, configuration);
        return Arrays.asList(c.outputVariables(baseName));
    }

    /**
     * LSTM unit.
     * Only the first output variable (index 0) of the {@link LSTMCell} op is returned.
     *
     * @param baseName      the base name for outputs
     * @param configuration the configuration to use
     * @return the first output variable of the LSTM cell op
     */
    public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) {
        return new LSTMCell(sd, configuration).outputVariables(baseName)[0];
    }

    /**
     * LSTM block cell.
     * Creates an {@link LSTMBlockCell} op from the configuration and returns all of its output variables.
     *
     * @param name          the name passed to the op when requesting its output variables
     * @param configuration the configuration to use
     * @return the output variables of the LSTM block cell op, as a fixed-size list backed by the op's output array
     */
    public List<SDVariable> lstmBlockCell(String name, LSTMBlockCellConfiguration configuration) {
        SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name);
        return Arrays.asList(v);
    }

    /**
     * LSTM layer.
     * Creates an {@link LSTMLayer} op from the configuration and returns all of its output variables.
     *
     * @param name          the name passed to the op when requesting its output variables
     * @param configuration the configuration to use
     * @return the output variables of the LSTM layer op, as a fixed-size list backed by the op's output array
     */
    public List<SDVariable> lstmLayer(String name, LSTMConfiguration configuration) {
        SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name);
        return Arrays.asList(v);
    }

    /**
     * Simple recurrent unit.
     * Only the first output variable (index 0) of the {@link SRU} op is returned.
     *
     * @param configuration the configuration for the sru
     * @return the first output variable of the SRU op
     */
    public SDVariable sru(SRUConfiguration configuration) {
        return new SRU(sd, configuration).outputVariables()[0];
    }

    /**
     * Simple recurrent unit.
     * Only the first output variable (index 0) of the {@link SRU} op is returned.
     *
     * @param baseName      the base name to use for output variables
     * @param configuration the configuration for the sru
     * @return the first output variable of the SRU op
     */
    public SDVariable sru(String baseName, SRUConfiguration configuration) {
        return new SRU(sd, configuration).outputVariables(baseName)[0];
    }

    /**
     * An sru cell.
     * Only the first output variable (index 0) of the {@link SRUCell} op is returned.
     *
     * @param configuration the configuration for the sru cell
     * @return the first output variable of the SRU cell op
     */
    public SDVariable sruCell(SRUCellConfiguration configuration) {
        return new SRUCell(sd, configuration).outputVariables()[0];
    }

    /**
     * An sru cell.
     * Only the first output variable (index 0) of the {@link SRUCell} op is returned.
     *
     * @param baseName      the base name to use for the output variables
     * @param configuration the configuration for the sru cell
     * @return the first output variable of the SRU cell op
     */
    public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
        return new SRUCell(sd, configuration).outputVariables(baseName)[0];
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy