All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.layers.variational;

import lombok.Data;
import lombok.val;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

@Data
public class GaussianReconstructionDistribution implements ReconstructionDistribution {

    private static final double NEG_HALF_LOG_2PI = -0.5 * Math.log(2 * Math.PI);

    private final IActivation activationFn;

    /**
     * Create a GaussianReconstructionDistribution with the default identity activation function.
     */
    public GaussianReconstructionDistribution() {
        this(Activation.IDENTITY);
    }

    /**
     * @param activationFn    Activation function for the reconstruction distribution. Typically identity or tanh.
     */
    public GaussianReconstructionDistribution(Activation activationFn) {
        this(activationFn.getActivationFunction());
    }

    /**
     * @param activationFn    Activation function for the reconstruction distribution. Typically identity or tanh.
     */
    public GaussianReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
    }

    @Override
    public boolean hasLossFunction() {
        return false;
    }

    @Override
    public int distributionInputSize(int dataSize) {
        return 2 * dataSize;
    }

    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        val size = preOutDistributionParams.size(1) / 2;

        INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);
        double logProb = x.size(0) * size * NEG_HALF_LOG_2PI - 0.5 * logProbArrays[0].sumNumber().doubleValue()
                        - logProbArrays[1].sumNumber().doubleValue();

        if (average) {
            return -logProb / x.size(0);
        } else {
            return -logProb;
        }
    }

    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        val size = preOutDistributionParams.size(1) / 2;

        INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);

        return logProbArrays[0].sum(true, 1).muli(0.5).subi(size * NEG_HALF_LOG_2PI)
                        .addi(logProbArrays[1].sum(true, 1));
    }

    private INDArray[] calcLogProbArrayExConstants(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, false);

        val size = output.size(1) / 2;
        INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
        INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

        INDArray sigmaSquared = Transforms.exp(logStdevSquared, true);
        INDArray lastTerm = x.sub(mean.castTo(x.dataType()));
        lastTerm.muli(lastTerm);
        lastTerm.divi(sigmaSquared.castTo(lastTerm.dataType())).divi(2);

        return new INDArray[] {logStdevSquared, lastTerm};
    }

    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, true);

        val size = output.size(1) / 2;
        INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
        INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

        INDArray sigmaSquared = Transforms.exp(logStdevSquared, true).castTo(x.dataType());

        INDArray xSubMean = x.sub(mean.castTo(x.dataType()));
        INDArray xSubMeanSq = xSubMean.mul(xSubMean);

        INDArray dLdmu = xSubMean.divi(sigmaSquared);

        INDArray sigma = Transforms.sqrt(sigmaSquared, true);
        INDArray sigma3 = Transforms.pow(sigmaSquared, 3.0 / 2);

        INDArray dLdsigma = sigma.rdiv(-1).addi(xSubMeanSq.divi(sigma3));
        INDArray dLdlogSigma2 = sigma.divi(2).muli(dLdsigma);

        INDArray dLdx = Nd4j.createUninitialized(preOutDistributionParams.dataType(), output.shape());
        dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, dLdmu);
        dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, dLdlogSigma2);
        dLdx.negi();

        //dL/dz
        return activationFn.backprop(preOutDistributionParams.dup(), dLdx).getFirst();
    }

    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, true);

        val size = output.size(1) / 2;
        INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
        INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

        INDArray sigma = Transforms.exp(logStdevSquared, true);
        Transforms.sqrt(sigma, false);

        INDArray e = Nd4j.randn(sigma.shape());
        return e.muli(sigma).addi(mean); //mu + sigma * N(0,1) ~ N(mu,sigma^2)
    }

    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        val size = preOutDistributionParams.size(1) / 2;
        INDArray mean = preOutDistributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size)).dup();
        activationFn.getActivation(mean, false);

        return mean;
    }

    @Override
    public String toString() {
        return "GaussianReconstructionDistribution(afn=" + activationFn + ")";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy