/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.layers.variational;
import lombok.Data;
import lombok.val;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
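/**
 * Gaussian reconstruction distribution for a variational autoencoder.
 * <p>
 * Each output value is modelled by an independent Gaussian whose parameters are produced by the decoder:
 * the first half of the pre-output activations holds the means, the second half holds log(stdev^2).
 * The log-variance parameterization keeps the variance positive without constraining the network output.
 */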
@Data
public class GaussianReconstructionDistribution implements ReconstructionDistribution {
private static final double NEG_HALF_LOG_2PI = -0.5 * Math.log(2 * Math.PI);
private final IActivation activationFn;
/**
* Create a GaussianReconstructionDistribution with the default identity activation function.
*/
public GaussianReconstructionDistribution() {
this(Activation.IDENTITY);
}
/**
* @param activationFn Activation function for the reconstruction distribution. Typically identity or tanh.
*/
public GaussianReconstructionDistribution(Activation activationFn) {
this(activationFn.getActivationFunction());
}
/**
* @param activationFn Activation function for the reconstruction distribution. Typically identity or tanh.
*/
public GaussianReconstructionDistribution(IActivation activationFn) {
this.activationFn = activationFn;
}
@Override
public boolean hasLossFunction() {
return false;
}
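// Each data value requires two distribution parameters: a mean and a log variance.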
@Override
public int distributionInputSize(int dataSize) {
return 2 * dataSize;
}
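// Negative log likelihood of x under N(mean, sigma^2), summed over all values:
// -log p(x) = sum_i [ 0.5*log(2*pi*sigma_i^2) + (x_i - mean_i)^2 / (2*sigma_i^2) ]
// calcLogProbArrayExConstants supplies the log(sigma^2) and squared-error terms;
// the constant -0.5*log(2*pi) per value is added here.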
@Override
public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
val size = preOutDistributionParams.size(1) / 2;
INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);
double logProb = x.size(0) * size * NEG_HALF_LOG_2PI - 0.5 * logProbArrays[0].sumNumber().doubleValue()
- logProbArrays[1].sumNumber().doubleValue();
if (average) {
return -logProb / x.size(0);
} else {
return -logProb;
}
}
@Override
public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
val size = preOutDistributionParams.size(1) / 2;
INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);
return logProbArrays[0].sum(true, 1).muli(0.5).subi(size * NEG_HALF_LOG_2PI)
.addi(logProbArrays[1].sum(true, 1));
}
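// Applies the activation function, splits the parameters into mean and log(sigma^2) halves, and returns
// {log(sigma^2), (x - mean)^2 / (2*sigma^2)}; the constant -0.5*log(2*pi) terms are added by the callers.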
private INDArray[] calcLogProbArrayExConstants(INDArray x, INDArray preOutDistributionParams) {
INDArray output = preOutDistributionParams.dup();
activationFn.getActivation(output, false);
val size = output.size(1) / 2;
INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));
INDArray sigmaSquared = Transforms.exp(logStdevSquared, true);
INDArray lastTerm = x.sub(mean.castTo(x.dataType()));
lastTerm.muli(lastTerm);
lastTerm.divi(sigmaSquared.castTo(lastTerm.dataType())).divi(2);
return new INDArray[] {logStdevSquared, lastTerm};
}
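// Gradient of the negative log likelihood with respect to the pre-activation distribution parameters.
// For the log likelihood: d/dmean = (x - mean)/sigma^2 and, via the chain rule with d(sigma)/d(log sigma^2) = sigma/2,
// d/d(log sigma^2) = (sigma/2) * (-1/sigma + (x - mean)^2/sigma^3). The result is negated (negative log likelihood)
// and then backpropagated through the activation function.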
@Override
public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
INDArray output = preOutDistributionParams.dup();
activationFn.getActivation(output, true);
val size = output.size(1) / 2;
INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));
INDArray sigmaSquared = Transforms.exp(logStdevSquared, true).castTo(x.dataType());
INDArray xSubMean = x.sub(mean.castTo(x.dataType()));
INDArray xSubMeanSq = xSubMean.mul(xSubMean);
INDArray dLdmu = xSubMean.divi(sigmaSquared);
INDArray sigma = Transforms.sqrt(sigmaSquared, true);
INDArray sigma3 = Transforms.pow(sigmaSquared, 3.0 / 2);
INDArray dLdsigma = sigma.rdiv(-1).addi(xSubMeanSq.divi(sigma3));
INDArray dLdlogSigma2 = sigma.divi(2).muli(dLdsigma);
INDArray dLdx = Nd4j.createUninitialized(preOutDistributionParams.dataType(), output.shape());
dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, size)}, dLdmu);
dLdx.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)}, dLdlogSigma2);
dLdx.negi();
//dL/dz
return activationFn.backprop(preOutDistributionParams.dup(), dLdx).getFirst();
}
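// Draw a sample x ~ N(mean, sigma^2) for each value via the reparameterization x = mean + sigma * epsilon,
// with epsilon drawn from a standard normal distribution.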
@Override
public INDArray generateRandom(INDArray preOutDistributionParams) {
INDArray output = preOutDistributionParams.dup();
activationFn.getActivation(output, true);
val size = output.size(1) / 2;
INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));
INDArray sigma = Transforms.exp(logStdevSquared, true);
Transforms.sqrt(sigma, false);
INDArray e = Nd4j.randn(sigma.shape());
return e.muli(sigma).addi(mean); //mu + sigma * N(0,1) ~ N(mu,sigma^2)
}
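// Deterministic reconstruction: return the activated means and ignore the variance parameters.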
@Override
public INDArray generateAtMean(INDArray preOutDistributionParams) {
val size = preOutDistributionParams.size(1) / 2;
INDArray mean = preOutDistributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size)).dup();
activationFn.getActivation(mean, false);
return mean;
}
@Override
public String toString() {
return "GaussianReconstructionDistribution(afn=" + activationFn + ")";
}
}
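Usage sketch (not part of this file): one way this distribution is typically wired into a DL4J VariationalAutoencoder layer configuration. The class name, layer sizes, and activation choices below are illustrative assumptions rather than values taken from this source.

import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.activations.Activation;

public class GaussianVaeConfigExample {
    public static void main(String[] args) {
        // Example (assumed) sizes: 784 input features, 32-dimensional latent space.
        VariationalAutoencoder vae = new VariationalAutoencoder.Builder()
                .nIn(784)
                .nOut(32)
                .encoderLayerSizes(256, 128)
                .decoderLayerSizes(128, 256)
                .activation(Activation.RELU)
                // Identity activation on the distribution parameters: the decoder produces
                // mean and log(sigma^2) directly, i.e. two values per input feature
                // (see distributionInputSize above).
                .reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
                .build();
        System.out.println(vae);
    }
}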