org.deeplearning4j.nn.conf.layers.variational.ExponentialReconstructionDistribution

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.conf.layers.variational;

import lombok.Data;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

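/**
 * Exponential reconstruction distribution p(x|z) for a variational autoencoder (VAE).<br>
 * Supports data in the range [0, infinity).<br>
 * The network output (after the configured activation function) is interpreted as
 * gamma = log(lambda), so that lambda = exp(gamma) is guaranteed to be positive. The
 * distribution has density p(x) = lambda * exp(-lambda * x) and mean 1/lambda.
 */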
@Data
public class ExponentialReconstructionDistribution implements ReconstructionDistribution {

    private final IActivation activationFn;

    /**
     * Create a distribution with the default identity activation function: the raw network
     * output is used directly as gamma = log(lambda).
     */
    public ExponentialReconstructionDistribution() {
        this(Activation.IDENTITY);
    }

    /**
     * @deprecated Use {@link #ExponentialReconstructionDistribution(Activation)}
     */
    @Deprecated
    public ExponentialReconstructionDistribution(String activationFn) {
        this(Activation.fromString(activationFn).getActivationFunction());
    }

    public ExponentialReconstructionDistribution(Activation activation) {
        this(activation.getActivationFunction());
    }

    public ExponentialReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
    }

    @Override
    public boolean hasLossFunction() {
        return false;
    }

    @Override
    public int distributionInputSize(int dataSize) {
        //One distribution parameter (gamma = log(lambda)) per data value
        return dataSize;
    }

    /**
     * Calculate the negative log probability of the data {@code x} given the pre-activation network
     * output. If {@code average} is true, the result is the mean over the minibatch; otherwise the sum.
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        //p(x) = lambda * exp(-lambda * x)
        //log p(x) = log(lambda) - lambda * x = gamma - lambda * x, where gamma = log(lambda)

        //Apply the activation function to the raw network output to get gamma
        INDArray gamma = preOutDistributionParams.dup();
        activationFn.getActivation(gamma, false);

        //lambda = exp(gamma) guarantees a positive rate parameter
        INDArray lambda = Transforms.exp(gamma, true);
        //Negate the sum of (gamma - lambda * x) over all entries
        double negLogProbSum = -lambda.muli(x).rsubi(gamma).sumNumber().doubleValue();
        if (average) {
            return negLogProbSum / x.size(0);
        } else {
            return negLogProbSum;
        }
    }

    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {

        //Per-example negative log probability: -sum_j (gamma_j - lambda_j * x_j) for each row of x
        INDArray gamma = preOutDistributionParams.dup();
        activationFn.getActivation(gamma, false);

        INDArray lambda = Transforms.exp(gamma, true);
        //Sum over dimension 1 (keeping dims) to get one value per example, then negate
        return lambda.muli(x).rsubi(gamma).sum(true, 1).negi();
    }

    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        //p(x) = lambda * exp(-lambda * x)
        //log p(x) = log(lambda) - lambda * x = gamma - lambda * x
        //dlog p(x)/dgamma = 1 - lambda * x   (or the negative of this for d(-log p(x))/dgamma)

        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), true);

        INDArray lambda = Transforms.exp(gamma, true);
        //Gradient of the negative log probability with respect to gamma: lambda * x - 1
        INDArray dLdgamma = x.mul(lambda).subi(1.0);

        //Backpropagate through the activation function to get dL/dz, where z is the pre-activation output
        return activationFn.backprop(preOutDistributionParams.dup(), dLdgamma).getFirst();
    }

    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), false);

        INDArray lambda = Transforms.exp(gamma, true);

        //Inverse transform sampling. Inverse CDF of the exponential distribution: F^-1(u) = -log(1-u)/lambda

        INDArray u = Nd4j.rand(preOutDistributionParams.shape());

        //If u ~ U(0,1) then 1-u ~ U(0,1), so -log(u)/lambda is an equivalent sample
        return Transforms.log(u, false).divi(lambda).negi();
    }

    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        //Input: gamma = log(lambda)    ->  lambda = exp(gamma)
        //Mean for exponential distribution: 1/lambda

        INDArray gamma = activationFn.getActivation(preOutDistributionParams.dup(), false);

        INDArray lambda = Transforms.exp(gamma, true);
        return lambda.rdivi(1.0); //mean = 1.0 / lambda
    }

    @Override
    public String toString() {
        return "ExponentialReconstructionDistribution(afn=" + activationFn + ")";
    }
}
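Below is a minimal usage sketch showing how this distribution might be wired into a DL4J
variational autoencoder layer configuration. The layer sizes and the surrounding builder calls
are illustrative assumptions, not taken from this file; only ExponentialReconstructionDistribution
and Activation come from the code above.

import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.activations.Activation;

//A sketch, assuming the standard VariationalAutoencoder builder API and hypothetical layer sizes
VariationalAutoencoder vae = new VariationalAutoencoder.Builder()
        .nIn(784)                       //Input size (e.g., flattened 28x28 images)
        .nOut(32)                       //Size of the latent space
        .encoderLayerSizes(256, 128)    //Encoder hidden layer sizes
        .decoderLayerSizes(128, 256)    //Decoder hidden layer sizes
        //Exponential p(x|z); identity activation means the raw decoder output is gamma = log(lambda)
        .reconstructionDistribution(new ExponentialReconstructionDistribution(Activation.IDENTITY))
        .build();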