All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.layers.variational;

import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.util.ArrayUtil;

import java.util.Collection;
import java.util.Map;

@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class VariationalAutoencoder extends BasePretrainNetwork {

    private int[] encoderLayerSizes;
    private int[] decoderLayerSizes;
    private ReconstructionDistribution outputDistribution;
    private IActivation pzxActivationFn;
    private int numSamples;

    private VariationalAutoencoder(Builder builder) {
        super(builder);
        this.encoderLayerSizes = builder.encoderLayerSizes;
        this.decoderLayerSizes = builder.decoderLayerSizes;
        this.outputDistribution = builder.outputDistribution;
        this.pzxActivationFn = builder.pzxActivationFn;
        this.numSamples = builder.numSamples;
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners,
                             int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut());

        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret =
                        new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf, networkDataType);

        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map paramTable = initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return VariationalAutoencoderParamInitializer.getInstance();
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        if (paramName.startsWith(VariationalAutoencoderParamInitializer.DECODER_PREFIX)) {
            return true;
        }
        if (paramName.startsWith(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_PREFIX)) {
            return true;
        }
        if (paramName.startsWith(VariationalAutoencoderParamInitializer.PXZ_PREFIX)) {
            return true;
        }
        return false;
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        //For training: we'll assume unsupervised pretraining, as this has higher memory requirements

        InputType outputType = getOutputType(-1, inputType);

        val actElementsPerEx = outputType.arrayElementsPerExample();
        val numParams = initializer().numParams(this);
        int updaterStateSize = (int) getIUpdater().stateSize(numParams);

        int inferenceWorkingMemSizePerEx = 0;
        //Forward pass size through the encoder:
        for (int i = 1; i < encoderLayerSizes.length; i++) {
            inferenceWorkingMemSizePerEx += encoderLayerSizes[i];
        }

        //Forward pass size through the decoder, during training
        //p(Z|X) mean and stdev; pzxSigmaSquared, pzxSigma -> all size equal to nOut
        long decoderFwdSizeWorking = 4 * nOut;
        //plus, nSamples * decoder size
        //For each decoding: random sample (nOut), z (nOut), activations for each decoder layer
        decoderFwdSizeWorking += numSamples * (2 * nOut + ArrayUtil.sum(getDecoderLayerSizes()));
        //Plus, component of score
        decoderFwdSizeWorking += nOut;

        //Backprop size through the decoder and decoder: approx. 2x forward pass size
        long trainWorkingMemSize = 2 * (inferenceWorkingMemSizePerEx + decoderFwdSizeWorking);

        if (getIDropout() != null) {
            if (false) {
                //TODO drop connect
                //Dup the weights... note that this does NOT depend on the minibatch size...
            } else {
                //Assume we dup the input
                trainWorkingMemSize += inputType.arrayElementsPerExample();
            }
        }

        return new LayerMemoryReport.Builder(layerName, VariationalAutoencoder.class, inputType, outputType)
                        .standardMemory(numParams, updaterStateSize)
                        .workingMemory(0, inferenceWorkingMemSizePerEx, 0, trainWorkingMemSize)
                        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
                        .build();
    }

    @Getter
    @Setter
    public static class Builder extends BasePretrainNetwork.Builder {

        /**
         * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers (set via
         * {@link #decoderLayerSizes(int...)} is similar to the encoder layers.
         *
         */
        private int[] encoderLayerSizes = new int[] {100};

        /**
         * Size of the decoder layers, in units. Each decoder layer is functionally equivalent to a {@link
         * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers is similar
         * to the encoder layers (set via {@link #encoderLayerSizes(int...)}.
         *
         */
        private int[] decoderLayerSizes = new int[] {100};
        /**
         * The reconstruction distribution for the data given the hidden state - i.e., P(data|Z).
This should be * selected carefully based on the type of data being modelled. For example:
- {@link * GaussianReconstructionDistribution} + {identity or tanh} for real-valued (Gaussian) data
- {@link * BernoulliReconstructionDistribution} + sigmoid for binary-valued (0 or 1) data
* */ private ReconstructionDistribution outputDistribution = new GaussianReconstructionDistribution(Activation.TANH); /** * Activation function for the input to P(z|data).
Care should be taken with this, as some activation * functions (relu, etc) are not suitable due to being bounded in range [0,infinity). * */ private IActivation pzxActivationFn = new ActivationIdentity(); /** * Set the number of samples per data point (from VAE state Z) used when doing pretraining. Default value: 1. *

* This is parameter L from Kingma and Welling: "In our experiments we found that the number of samples L per * datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M = 100." * */ private int numSamples = 1; /** * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers (set via * {@link #decoderLayerSizes(int...)} is similar to the encoder layers. * * @param encoderLayerSizes Size of each encoder layer in the variational autoencoder */ public Builder encoderLayerSizes(int... encoderLayerSizes) { this.setEncoderLayerSizes(encoderLayerSizes); return this; } /** * Size of the encoder layers, in units. Each encoder layer is functionally equivalent to a {@link * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers (set via * {@link #decoderLayerSizes(int...)} is similar to the encoder layers. * * @param encoderLayerSizes Size of each encoder layer in the variational autoencoder */ public void setEncoderLayerSizes(int... encoderLayerSizes) { if (encoderLayerSizes == null || encoderLayerSizes.length < 1) { throw new IllegalArgumentException("Encoder layer sizes array must have length > 0"); } this.encoderLayerSizes = encoderLayerSizes; } /** * Size of the decoder layers, in units. Each decoder layer is functionally equivalent to a {@link * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers is similar * to the encoder layers (set via {@link #encoderLayerSizes(int...)}. * * @param decoderLayerSizes Size of each deccoder layer in the variational autoencoder */ public Builder decoderLayerSizes(int... decoderLayerSizes) { this.setDecoderLayerSizes(decoderLayerSizes); return this; } /** * Size of the decoder layers, in units. 
Each decoder layer is functionally equivalent to a {@link * org.deeplearning4j.nn.conf.layers.DenseLayer}. Typically the number and size of the decoder layers is similar * to the encoder layers (set via {@link #encoderLayerSizes(int...)}. * * @param decoderLayerSizes Size of each deccoder layer in the variational autoencoder */ public void setDecoderLayerSizes(int... decoderLayerSizes) { if (decoderLayerSizes == null || decoderLayerSizes.length < 1) { throw new IllegalArgumentException("Decoder layer sizes array must have length > 0"); } this.decoderLayerSizes = decoderLayerSizes; } /** * The reconstruction distribution for the data given the hidden state - i.e., P(data|Z).
This should be * selected carefully based on the type of data being modelled. For example:
- {@link * GaussianReconstructionDistribution} + {identity or tanh} for real-valued (Gaussian) data
- {@link * BernoulliReconstructionDistribution} + sigmoid for binary-valued (0 or 1) data
* * @param distribution Reconstruction distribution */ public Builder reconstructionDistribution(ReconstructionDistribution distribution) { this.setOutputDistribution(distribution); return this; } /** * Configure the VAE to use the specified loss function for the reconstruction, instead of a * ReconstructionDistribution. Note that this is NOT following the standard VAE design (as per Kingma & * Welling), which assumes a probabilistic output - i.e., some p(x|z). It is however a valid network * configuration, allowing for optimization of more traditional objectives such as mean squared error.
Note: * clearly, setting the loss function here will override any previously set recontruction distribution * * @param outputActivationFn Activation function for the output/reconstruction * @param lossFunction Loss function to use */ public Builder lossFunction(IActivation outputActivationFn, LossFunctions.LossFunction lossFunction) { return lossFunction(outputActivationFn, lossFunction.getILossFunction()); } /** * Configure the VAE to use the specified loss function for the reconstruction, instead of a * ReconstructionDistribution. Note that this is NOT following the standard VAE design (as per Kingma & * Welling), which assumes a probabilistic output - i.e., some p(x|z). It is however a valid network * configuration, allowing for optimization of more traditional objectives such as mean squared error.
Note: * clearly, setting the loss function here will override any previously set recontruction distribution * * @param outputActivationFn Activation function for the output/reconstruction * @param lossFunction Loss function to use */ public Builder lossFunction(Activation outputActivationFn, LossFunctions.LossFunction lossFunction) { return lossFunction(outputActivationFn.getActivationFunction(), lossFunction.getILossFunction()); } /** * Configure the VAE to use the specified loss function for the reconstruction, instead of a * ReconstructionDistribution. Note that this is NOT following the standard VAE design (as per Kingma & * Welling), which assumes a probabilistic output - i.e., some p(x|z). It is however a valid network * configuration, allowing for optimization of more traditional objectives such as mean squared error.
Note: * clearly, setting the loss function here will override any previously set recontruction distribution * * @param outputActivationFn Activation function for the output/reconstruction * @param lossFunction Loss function to use */ public Builder lossFunction(IActivation outputActivationFn, ILossFunction lossFunction) { return reconstructionDistribution(new LossFunctionWrapper(outputActivationFn, lossFunction)); } /** * Activation function for the input to P(z|data).
Care should be taken with this, as some activation * functions (relu, etc) are not suitable due to being bounded in range [0,infinity). * * @param activationFunction Activation function for p(z|x) */ public Builder pzxActivationFn(IActivation activationFunction) { this.setPzxActivationFn(activationFunction); return this; } /** * Activation function for the input to P(z|data).
Care should be taken with this, as some activation * functions (relu, etc) are not suitable due to being bounded in range [0,infinity). * * @param activation Activation function for p(z|x) */ public Builder pzxActivationFunction(Activation activation) { return pzxActivationFn(activation.getActivationFunction()); } /** * Set the size of the VAE state Z. This is the output size during standard forward pass, and the size of the * distribution P(Z|data) during pretraining. * * @param nOut Size of P(Z|data) and output size */ @Override public Builder nOut(int nOut) { super.nOut(nOut); return this; } /** * Set the number of samples per data point (from VAE state Z) used when doing pretraining. Default value: 1. *

* This is parameter L from Kingma and Welling: "In our experiments we found that the number of samples L per * datapoint can be set to 1 as long as the minibatch size M was large enough, e.g. M = 100." * * @param numSamples Number of samples per data point for pretraining */ public Builder numSamples(int numSamples) { this.setNumSamples(numSamples); return this; } @Override @SuppressWarnings("unchecked") public VariationalAutoencoder build() { return new VariationalAutoencoder(this); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy