All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn.conf.layers.variational;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.LayerValidation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.Collection;
import java.util.Map;

/**
 * Variational Autoencoder layer
 *

* See: Kingma & Welling, 2013: Auto-Encoding Variational Bayes - https://arxiv.org/abs/1312.6114 *

* This implementation allows multiple encoder and decoder layers, the number and sizes of which can be set independently. *

* A note on scores during pretraining: This implementation minimizes the negative of the variational lower bound objective * as described in Kingma & Welling; the mathematics in that paper is based on maximization of the variational lower bound instead. * Thus, scores reported during pretraining in DL4J are the negative of the variational lower bound equation in the paper. * The backpropagation and learning procedure is otherwise as described there. * * @author Alex Black */ @Data @NoArgsConstructor @EqualsAndHashCode(callSuper = true) public class VariationalAutoencoder extends BasePretrainNetwork { private int[] encoderLayerSizes; private int[] decoderLayerSizes; private ReconstructionDistribution outputDistribution; private IActivation pzxActivationFn; private int numSamples; private VariationalAutoencoder(Builder builder) { super(builder); this.encoderLayerSizes = builder.encoderLayerSizes; this.decoderLayerSizes = builder.decoderLayerSizes; this.outputDistribution = builder.outputDistribution; this.pzxActivationFn = builder.pzxActivationFn; this.numSamples = builder.numSamples; } @Override public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) { LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(conf, layerParamsView, initializeParams); ret.setParamTable(paramTable); ret.setConf(conf); return ret; } @Override public ParamInitializer initializer() { return VariationalAutoencoderParamInitializer.getInstance(); } @Override public double getLearningRateByParam(String paramName) { if (paramName.endsWith("b")) { if 
(!Double.isNaN(biasLearningRate)) { //Bias learning rate has been explicitly set return biasLearningRate; } else { return learningRate; } } else { return learningRate; } } @Override public double getL1ByParam(String paramName) { if (paramName.endsWith(VariationalAutoencoderParamInitializer.BIAS_KEY_SUFFIX)) return l1Bias; return l1; } @Override public double getL2ByParam(String paramName) { if (paramName.endsWith(VariationalAutoencoderParamInitializer.BIAS_KEY_SUFFIX)) return l2Bias; return l2; } @Override public boolean isPretrainParam(String paramName) { if (paramName.startsWith(VariationalAutoencoderParamInitializer.DECODER_PREFIX)) { return true; } if (paramName.startsWith(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_PREFIX)) { return true; } if (paramName.startsWith(VariationalAutoencoderParamInitializer.PXZ_PREFIX)) { return true; } return false; } @Override public LayerMemoryReport getMemoryReport(InputType inputType) { //For training: we'll assume unsupervised pretraining, as this has higher memory requirements InputType outputType = getOutputType(-1, inputType); int actElementsPerEx = outputType.arrayElementsPerExample(); int numParams = initializer().numParams(this); int updaterStateSize = (int) getIUpdater().stateSize(numParams); int inferenceWorkingMemSizePerEx = 0; //Forward pass size through the encoder: for (int i = 1; i < encoderLayerSizes.length; i++) { inferenceWorkingMemSizePerEx += encoderLayerSizes[i]; } //Forward pass size through the decoder, during training //p(Z|X) mean and stdev; pzxSigmaSquared, pzxSigma -> all size equal to nOut int decoderFwdSizeWorking = 4 * nOut; //plus, nSamples * decoder size //For each decoding: random sample (nOut), z (nOut), activations for each decoder layer decoderFwdSizeWorking += numSamples * (2 * nOut + ArrayUtil.sum(getDecoderLayerSizes())); //Plus, component of score decoderFwdSizeWorking += nOut; //Backprop size through the decoder and decoder: approx. 
2x forward pass size int trainWorkingMemSize = 2 * (inferenceWorkingMemSizePerEx + decoderFwdSizeWorking); if (getDropOut() > 0) { if (false) { //TODO drop connect //Dup the weights... note that this does NOT depend on the minibatch size... } else { //Assume we dup the input trainWorkingMemSize += inputType.arrayElementsPerExample(); } } return new LayerMemoryReport.Builder(layerName, VariationalAutoencoder.class, inputType, outputType) .standardMemory(numParams, updaterStateSize) .workingMemory(0, inferenceWorkingMemSizePerEx, 0, trainWorkingMemSize) .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching .build(); } public static class Builder extends BasePretrainNetwork.Builder { private int[] encoderLayerSizes = new int[] {100}; private int[] decoderLayerSizes = new int[] {100}; private ReconstructionDistribution outputDistribution = new GaussianReconstructionDistribution(Activation.TANH); private IActivation pzxActivationFn = new ActivationIdentity(); private int numSamples = 1; /** * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link org.deeplearning4j.nn.conf.layers.DenseLayer}. * Typically the number and size of the decoder layers (set via {@link #decoderLayerSizes(int...)} is similar to the encoder layers. * * @param encoderLayerSizes Size of each encoder layer in the variational autoencoder */ public Builder encoderLayerSizes(int... encoderLayerSizes) { if (encoderLayerSizes == null || encoderLayerSizes.length < 1) { throw new IllegalArgumentException("Encoder layer sizes array must have length > 0"); } this.encoderLayerSizes = encoderLayerSizes; return this; } /** * Size of the decoder layers, in units. Each decoder layer is functionally equivalent to a {@link org.deeplearning4j.nn.conf.layers.DenseLayer}. * Typically the number and size of the decoder layers is similar to the encoder layers (set via {@link #encoderLayerSizes(int...)}. 
* * @param decoderLayerSizes Size of each deccoder layer in the variational autoencoder */ public Builder decoderLayerSizes(int... decoderLayerSizes) { if (encoderLayerSizes == null || encoderLayerSizes.length < 1) { throw new IllegalArgumentException("Decoder layer sizes array must have length > 0"); } this.decoderLayerSizes = decoderLayerSizes; return this; } /** * The reconstruction distribution for the data given the hidden state - i.e., P(data|Z).
* This should be selected carefully based on the type of data being modelled. For example:
* - {@link GaussianReconstructionDistribution} + {identity or tanh} for real-valued (Gaussian) data
* - {@link BernoulliReconstructionDistribution} + sigmoid for binary-valued (0 or 1) data
* * @param distribution Reconstruction distribution */ public Builder reconstructionDistribution(ReconstructionDistribution distribution) { this.outputDistribution = distribution; return this; } /** * Configure the VAE to use the specified loss function for the reconstruction, instead of a ReconstructionDistribution. * Note that this is NOT following the standard VAE design (as per Kingma & Welling), which assumes a probabilistic * output - i.e., some p(x|z). It is however a valid network configuration, allowing for optimization of more traditional * objectives such as mean squared error.
* Note: clearly, setting the loss function here will override any previously set recontruction distribution * * @param outputActivationFn Activation function for the output/reconstruction * @param lossFunction Loss function to use */ public Builder lossFunction(IActivation outputActivationFn, LossFunctions.LossFunction lossFunction) { return lossFunction(outputActivationFn, lossFunction.getILossFunction()); } /** * Configure the VAE to use the specified loss function for the reconstruction, instead of a ReconstructionDistribution. * Note that this is NOT following the standard VAE design (as per Kingma & Welling), which assumes a probabilistic * output - i.e., some p(x|z). It is however a valid network configuration, allowing for optimization of more traditional * objectives such as mean squared error.
* Note: clearly, setting the loss function here will override any previously set recontruction distribution * * @param outputActivationFn Activation function for the output/reconstruction * @param lossFunction Loss function to use */ public Builder lossFunction(Activation outputActivationFn, LossFunctions.LossFunction lossFunction) { return lossFunction(outputActivationFn.getActivationFunction(), lossFunction.getILossFunction()); } /** * Configure the VAE to use the specified loss function for the reconstruction, instead of a ReconstructionDistribution. * Note that this is NOT following the standard VAE design (as per Kingma & Welling), which assumes a probabilistic * output - i.e., some p(x|z). It is however a valid network configuration, allowing for optimization of more traditional * objectives such as mean squared error.
* Note: clearly, setting the loss function here will override any previously set recontruction distribution * * @param outputActivationFn Activation function for the output/reconstruction * @param lossFunction Loss function to use */ public Builder lossFunction(IActivation outputActivationFn, ILossFunction lossFunction) { return reconstructionDistribution(new LossFunctionWrapper(outputActivationFn, lossFunction)); } /** * Activation function for the input to P(z|data).
* Care should be taken with this, as some activation functions (relu, etc) are not suitable due to being * bounded in range [0,infinity). * * @param activationFunction Activation function for p(z|x) */ public Builder pzxActivationFn(IActivation activationFunction) { this.pzxActivationFn = activationFunction; return this; } /** * @deprecated Use {@link #pzxActivationFunction(Activation)} */ @Deprecated public Builder pzxActivationFunction(String activationFunction) { return pzxActivationFn(Activation.fromString(activationFunction).getActivationFunction()); } /** * Activation function for the input to P(z|data).
* Care should be taken with this, as some activation functions (relu, etc) are not suitable due to being * bounded in range [0,infinity). * * @param activation Activation function for p(z|x) */ public Builder pzxActivationFunction(Activation activation) { return pzxActivationFn(activation.getActivationFunction()); } /** * Set the size of the VAE state Z. This is the output size during standard forward pass, and the size of the * distribution P(Z|data) during pretraining. * * @param nOut Size of P(Z|data) and output size */ @Override public Builder nOut(int nOut) { super.nOut(nOut); return this; } /** * Set the number of samples per data point (from VAE state Z) used when doing pretraining. Default value: 1. *

* This is parameter L from Kingma and Welling: "In our experiments we found that the number of samples L per * datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M = 100." * * @param numSamples Number of samples per data point for pretraining */ public Builder numSamples(int numSamples) { this.numSamples = numSamples; return this; } @Override @SuppressWarnings("unchecked") public VariationalAutoencoder build() { return new VariationalAutoencoder(this); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy