// org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder (Maven / Gradle / Ivy)
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers.variational;
import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.util.ArrayUtil;
import java.util.Collection;
import java.util.Map;
/**
 * Variational Autoencoder layer configuration, as per Kingma and Welling (2013):
 * "Auto-Encoding Variational Bayes" - https://arxiv.org/abs/1312.6114
 * <p>
 * This VAE implementation allows multiple encoder and decoder layers, the number and sizes of which
 * can be set independently via {@link Builder#encoderLayerSizes(int...)} and
 * {@link Builder#decoderLayerSizes(int...)}. Note that the size of the latent space Z is set via
 * {@link Builder#nOut(int)}.
 */
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class VariationalAutoencoder extends BasePretrainNetwork {

    private int[] encoderLayerSizes;
    private int[] decoderLayerSizes;
    private ReconstructionDistribution outputDistribution;
    private IActivation pzxActivationFn;
    private int numSamples;

    private VariationalAutoencoder(Builder builder) {
        super(builder);
        this.encoderLayerSizes = builder.encoderLayerSizes;
        this.decoderLayerSizes = builder.decoderLayerSizes;
        this.outputDistribution = builder.outputDistribution;
        this.pzxActivationFn = builder.pzxActivationFn;
        this.numSamples = builder.numSamples;
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
                    int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        //nIn/nOut must be set before instantiation (nOut = latent space size)
        LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut());

        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret =
                        new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return VariationalAutoencoderParamInitializer.getInstance();
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        //Decoder, p(z|x) log variance, and p(x|z) parameters are used in unsupervised pretraining only
        return paramName.startsWith(VariationalAutoencoderParamInitializer.DECODER_PREFIX)
                        || paramName.startsWith(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_PREFIX)
                        || paramName.startsWith(VariationalAutoencoderParamInitializer.PXZ_PREFIX);
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        //For training: we'll assume unsupervised pretraining, as this has higher memory requirements
        InputType outputType = getOutputType(-1, inputType);

        val actElementsPerEx = outputType.arrayElementsPerExample();
        val numParams = initializer().numParams(this);
        int updaterStateSize = (int) getIUpdater().stateSize(numParams);

        int inferenceWorkingMemSizePerEx = 0;
        //Forward pass size through the encoder:
        for (int i = 1; i < encoderLayerSizes.length; i++) {
            inferenceWorkingMemSizePerEx += encoderLayerSizes[i];
        }

        //Forward pass size through the decoder, during training
        //p(Z|X) mean and stdev; pzxSigmaSquared, pzxSigma -> all size equal to nOut
        long decoderFwdSizeWorking = 4 * nOut;
        //plus, nSamples * decoder size
        //For each decoding: random sample (nOut), z (nOut), activations for each decoder layer
        decoderFwdSizeWorking += numSamples * (2 * nOut + ArrayUtil.sum(getDecoderLayerSizes()));
        //Plus, component of score
        decoderFwdSizeWorking += nOut;

        //Backprop size through the decoder and encoder: approx. 2x forward pass size
        long trainWorkingMemSize = 2 * (inferenceWorkingMemSizePerEx + decoderFwdSizeWorking);

        if (getIDropout() != null) {
            //Assume we dup the input when applying dropout during training
            //TODO drop connect: would instead dup the weights, which does NOT depend on the minibatch size
            trainWorkingMemSize += inputType.arrayElementsPerExample();
        }

        return new LayerMemoryReport.Builder(layerName, VariationalAutoencoder.class, inputType, outputType)
                        .standardMemory(numParams, updaterStateSize)
                        .workingMemory(0, inferenceWorkingMemSizePerEx, 0, trainWorkingMemSize)
                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
                        .build();
    }

    @Getter
    @Setter
    public static class Builder extends BasePretrainNetwork.Builder<Builder> {

        /**
         * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers (set via
         * {@link #decoderLayerSizes(int...)} is similar to the encoder layers.
         */
        private int[] encoderLayerSizes = new int[] {100};

        /**
         * Size of the decoder layers, in units. Each decoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers is similar
         * to the encoder layers (set via {@link #encoderLayerSizes(int...)}.
         */
        private int[] decoderLayerSizes = new int[] {100};

        /**
         * The reconstruction distribution for the data given the hidden state - i.e., P(data|Z). This should be
         * selected carefully based on the type of data being modelled. For example:<br>
         * - {@link GaussianReconstructionDistribution} + {identity or tanh} for real-valued (Gaussian) data<br>
         * - {@link BernoulliReconstructionDistribution} + sigmoid for binary-valued (0 or 1) data<br>
         */
        private ReconstructionDistribution outputDistribution = new GaussianReconstructionDistribution(Activation.TANH);

        /**
         * Activation function for the input to P(z|data). Care should be taken with this, as some activation
         * functions (relu, etc) are not suitable due to being bounded in range [0,infinity).
         */
        private IActivation pzxActivationFn = new ActivationIdentity();

        /**
         * Set the number of samples per data point (from VAE state Z) used when doing pretraining. Default value: 1.
         * <p>
         * This is parameter L from Kingma and Welling: "In our experiments we found that the number of samples L per
         * datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M = 100."
         */
        private int numSamples = 1;

        /**
         * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers (set via
         * {@link #decoderLayerSizes(int...)} is similar to the encoder layers.
         *
         * @param encoderLayerSizes Size of each encoder layer in the variational autoencoder
         */
        public Builder encoderLayerSizes(int... encoderLayerSizes) {
            this.setEncoderLayerSizes(encoderLayerSizes);
            return this;
        }

        /**
         * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers (set via
         * {@link #decoderLayerSizes(int...)} is similar to the encoder layers.
         *
         * @param encoderLayerSizes Size of each encoder layer in the variational autoencoder
         * @throws IllegalArgumentException if the array is null or empty
         */
        public void setEncoderLayerSizes(int... encoderLayerSizes) {
            if (encoderLayerSizes == null || encoderLayerSizes.length < 1) {
                throw new IllegalArgumentException("Encoder layer sizes array must have length > 0");
            }
            this.encoderLayerSizes = encoderLayerSizes;
        }

        /**
         * Size of the decoder layers, in units. Each decoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers is similar
         * to the encoder layers (set via {@link #encoderLayerSizes(int...)}.
         *
         * @param decoderLayerSizes Size of each decoder layer in the variational autoencoder
         */
        public Builder decoderLayerSizes(int... decoderLayerSizes) {
            this.setDecoderLayerSizes(decoderLayerSizes);
            return this;
        }

        /**
         * Size of the decoder layers, in units. Each decoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers is similar
         * to the encoder layers (set via {@link #encoderLayerSizes(int...)}.
         *
         * @param decoderLayerSizes Size of each decoder layer in the variational autoencoder
         * @throws IllegalArgumentException if the array is null or empty
         */
        public void setDecoderLayerSizes(int... decoderLayerSizes) {
            if (decoderLayerSizes == null || decoderLayerSizes.length < 1) {
                throw new IllegalArgumentException("Decoder layer sizes array must have length > 0");
            }
            this.decoderLayerSizes = decoderLayerSizes;
        }

        /**
         * The reconstruction distribution for the data given the hidden state - i.e., P(data|Z). This should be
         * selected carefully based on the type of data being modelled. For example:<br>
         * - {@link GaussianReconstructionDistribution} + {identity or tanh} for real-valued (Gaussian) data<br>
         * - {@link BernoulliReconstructionDistribution} + sigmoid for binary-valued (0 or 1) data<br>
         *
         * @param distribution Reconstruction distribution
         */
        public Builder reconstructionDistribution(ReconstructionDistribution distribution) {
            this.setOutputDistribution(distribution);
            return this;
        }

        /**
         * Configure the VAE to use the specified loss function for the reconstruction, instead of a
         * ReconstructionDistribution. Note that this is NOT following the standard VAE design (as per Kingma &amp;
         * Welling), which assumes a probabilistic output - i.e., some p(x|z). It is however a valid network
         * configuration, allowing for optimization of more traditional objectives such as mean squared error.<br>
         * Note: clearly, setting the loss function here will override any previously set reconstruction distribution
         *
         * @param outputActivationFn Activation function for the output/reconstruction
         * @param lossFunction Loss function to use
         */
        public Builder lossFunction(IActivation outputActivationFn, LossFunctions.LossFunction lossFunction) {
            return lossFunction(outputActivationFn, lossFunction.getILossFunction());
        }

        /**
         * Configure the VAE to use the specified loss function for the reconstruction, instead of a
         * ReconstructionDistribution. Note that this is NOT following the standard VAE design (as per Kingma &amp;
         * Welling), which assumes a probabilistic output - i.e., some p(x|z). It is however a valid network
         * configuration, allowing for optimization of more traditional objectives such as mean squared error.<br>
         * Note: clearly, setting the loss function here will override any previously set reconstruction distribution
         *
         * @param outputActivationFn Activation function for the output/reconstruction
         * @param lossFunction Loss function to use
         */
        public Builder lossFunction(Activation outputActivationFn, LossFunctions.LossFunction lossFunction) {
            return lossFunction(outputActivationFn.getActivationFunction(), lossFunction.getILossFunction());
        }

        /**
         * Configure the VAE to use the specified loss function for the reconstruction, instead of a
         * ReconstructionDistribution. Note that this is NOT following the standard VAE design (as per Kingma &amp;
         * Welling), which assumes a probabilistic output - i.e., some p(x|z). It is however a valid network
         * configuration, allowing for optimization of more traditional objectives such as mean squared error.<br>
         * Note: clearly, setting the loss function here will override any previously set reconstruction distribution
         *
         * @param outputActivationFn Activation function for the output/reconstruction
         * @param lossFunction Loss function to use
         */
        public Builder lossFunction(IActivation outputActivationFn, ILossFunction lossFunction) {
            //Wrap the loss function in a ReconstructionDistribution so the rest of the VAE machinery is unchanged
            return reconstructionDistribution(new LossFunctionWrapper(outputActivationFn, lossFunction));
        }

        /**
         * Activation function for the input to P(z|data). Care should be taken with this, as some activation
         * functions (relu, etc) are not suitable due to being bounded in range [0,infinity).
         *
         * @param activationFunction Activation function for p(z|x)
         */
        public Builder pzxActivationFn(IActivation activationFunction) {
            this.setPzxActivationFn(activationFunction);
            return this;
        }

        /**
         * Activation function for the input to P(z|data). Care should be taken with this, as some activation
         * functions (relu, etc) are not suitable due to being bounded in range [0,infinity).
         *
         * @param activation Activation function for p(z|x)
         */
        public Builder pzxActivationFunction(Activation activation) {
            return pzxActivationFn(activation.getActivationFunction());
        }

        /**
         * Set the size of the VAE state Z. This is the output size during standard forward pass, and the size of the
         * distribution P(Z|data) during pretraining.
         *
         * @param nOut Size of P(Z|data) and output size
         */
        @Override
        public Builder nOut(int nOut) {
            super.nOut(nOut);
            return this;
        }

        /**
         * Set the number of samples per data point (from VAE state Z) used when doing pretraining. Default value: 1.
         * <p>
         * This is parameter L from Kingma and Welling: "In our experiments we found that the number of samples L per
         * datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M = 100."
         *
         * @param numSamples Number of samples per data point for pretraining
         */
        public Builder numSamples(int numSamples) {
            this.setNumSamples(numSamples);
            return this;
        }

        @Override
        @SuppressWarnings("unchecked")
        public VariationalAutoencoder build() {
            return new VariationalAutoencoder(this);
        }
    }
}