// Source listing: org.deeplearning4j.nn.conf.layers.FeedForwardLayer (Maven / Gradle / Ivy)
package org.deeplearning4j.nn.conf.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
/**
* Created by jeffreytang on 7/21/15.
*/
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public abstract class FeedForwardLayer extends BaseLayer {
protected int nIn;
protected int nOut;
public FeedForwardLayer(Builder builder) {
super(builder);
this.nIn = builder.nIn;
this.nOut = builder.nOut;
}
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || (inputType.getType() != InputType.Type.FF
&& inputType.getType() != InputType.Type.CNNFlat)) {
throw new IllegalStateException("Invalid input type (layer index = " + layerIndex + ", layer name=\""
+ getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
}
return InputType.feedForward(nOut);
}
@Override
public void setNIn(InputType inputType, boolean override) {
if (inputType == null || (inputType.getType() != InputType.Type.FF
&& inputType.getType() != InputType.Type.CNNFlat)) {
throw new IllegalStateException("Invalid input type (layer name=\"" + getLayerName()
+ "\"): expected FeedForward input type. Got: " + inputType);
}
if (nIn <= 0 || override) {
if (inputType.getType() == InputType.Type.FF) {
InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType;
this.nIn = f.getSize();
} else {
InputType.InputTypeConvolutionalFlat f = (InputType.InputTypeConvolutionalFlat) inputType;
this.nIn = f.getFlattenedSize();
}
}
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) {
throw new IllegalStateException(
"Invalid input for layer (layer name = \"" + getLayerName() + "\"): input type is null");
}
switch (inputType.getType()) {
case FF:
case CNNFlat:
//FF -> FF and CNN (flattened format) -> FF: no preprocessor necessary
return null;
case RNN:
//RNN -> FF
return new RnnToFeedForwardPreProcessor();
case CNN:
//CNN -> FF
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getDepth());
default:
throw new RuntimeException("Unknown input type: " + inputType);
}
}
@Override
public double getL1ByParam(String paramName) {
switch (paramName) {
case DefaultParamInitializer.WEIGHT_KEY:
return l1;
case DefaultParamInitializer.BIAS_KEY:
return l1Bias;
default:
throw new IllegalStateException("Unknown parameter: \"" + paramName + "\"");
}
}
@Override
public double getL2ByParam(String paramName) {
switch (paramName) {
case DefaultParamInitializer.WEIGHT_KEY:
return l2;
case DefaultParamInitializer.BIAS_KEY:
return l2Bias;
default:
throw new IllegalStateException("Unknown parameter: \"" + paramName + "\"");
}
}
@Override
public double getLearningRateByParam(String paramName) {
switch (paramName) {
case DefaultParamInitializer.WEIGHT_KEY:
return learningRate;
case DefaultParamInitializer.BIAS_KEY:
if (!Double.isNaN(biasLearningRate)) {
//Bias learning rate has been explicitly set
return biasLearningRate;
} else {
return learningRate;
}
default:
throw new IllegalStateException("Unknown parameter: \"" + paramName + "\"");
}
}
@Override
public boolean isPretrainParam(String paramName) {
return false; //No pretrain params in standard FF layers
}
public abstract static class Builder> extends BaseLayer.Builder {
protected int nIn = 0;
protected int nOut = 0;
public T nIn(int nIn) {
this.nIn = nIn;
return (T) this;
}
public T nOut(int nOut) {
this.nOut = nOut;
return (T) this;
}
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy