org.deeplearning4j.nn.layers.DropoutLayer Maven / Gradle / Ivy
package org.deeplearning4j.nn.layers;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
 * A parameter-free feed-forward layer whose forward pass returns its input after calling the
 * inherited {@code applyDropOutIfNecessary}, optionally scaled row-by-row by a column-vector mask.
 * Its backward pass forwards the incoming error signal with the same mask scaling and produces
 * an empty gradient.
 *
 * <p>Originally created by davekale on 12/7/16.
 */
public class DropoutLayer extends BaseLayer {

    /**
     * Creates the layer from its configuration.
     *
     * @param conf the layer configuration
     */
    public DropoutLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from its configuration with an initial input array.
     *
     * @param conf  the layer configuration
     * @param input the initial input array
     */
    public DropoutLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Returns the L2 regularization term, which is always zero because this layer has no parameters.
     *
     * @param backpropParamsOnly ignored
     * @return {@code 0}
     */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0;
    }

    /**
     * Returns the L1 regularization term, which is always zero because this layer has no parameters.
     *
     * @param backpropParamsOnly ignored
     * @return {@code 0}
     */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0;
    }

    /**
     * Returns the layer type.
     *
     * @return {@link Type#FEED_FORWARD}
     */
    @Override
    public Type type() {
        return Type.FEED_FORWARD;
    }

    /** Does nothing: this layer has nothing to fit. */
    @Override
    public void fit(INDArray input) {}

    /**
     * Passes the incoming error signal back to the previous layer.
     *
     * <p>The returned {@link Gradient} is empty because the layer has no parameters. If a mask is
     * set, each row of the error signal is multiplied by the corresponding mask value.
     * NOTE(review): the error is not scaled by whatever {@code applyDropOutIfNecessary} did in the
     * forward pass; confirm this pass-through is intended.
     *
     * @param epsilon the error signal from the next layer; not modified
     * @return an empty gradient paired with the (masked) error signal for the previous layer
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        // Copy so the in-place mask multiply below never alters the caller's epsilon.
        INDArray delta = epsilon.dup();
        if (maskArray != null) {
            delta.muliColumnVector(maskArray);
        }
        Gradient ret = new DefaultGradient();
        return new Pair<>(ret, delta);
    }

    /**
     * Computes the layer output: the current input after {@code applyDropOutIfNecessary},
     * multiplied row-by-row by the mask if one is set.
     *
     * <p>The stored {@code input} array is never modified by the mask. A copy is made whenever
     * training or when a mask is set. Otherwise the current {@code input} array itself is returned.
     *
     * @param training whether this is a training pass; forwarded to {@code applyDropOutIfNecessary}
     * @return the layer output
     * @throws IllegalArgumentException if no input has been set
     */
    @Override
    public INDArray preOutput(boolean training) {
        if (input == null) {
            throw new IllegalArgumentException("Cannot perform forward pass with null input " + layerId());
        }
        applyDropOutIfNecessary(training);

        // Copy when training (the original input is kept for backprop) or when the mask is about to
        // be applied in place; otherwise inference with a mask would silently overwrite `input`.
        INDArray ret = (training || maskArray != null) ? input.dup() : input;
        if (maskArray != null) {
            ret.muliColumnVector(maskArray);
        }
        return ret;
    }

    /**
     * Returns the layer activations, which are identical to {@link #preOutput(boolean)} because
     * there is no activation function.
     *
     * @param training whether this is a training pass
     * @return the layer output
     */
    @Override
    public INDArray activate(boolean training) {
        return preOutput(training);
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not supported - " + layerId());
    }

    /**
     * Reports that this layer is not a pretrain layer.
     *
     * @return {@code false}
     */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        throw new UnsupportedOperationException("Not supported " + layerId());
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException("Not supported - " + layerId());
    }

    /**
     * Returns the layer parameters. This layer has none.
     *
     * @return {@code null}
     */
    @Override
    public INDArray params() {
        return null;
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy