// org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution (Maven / Gradle / Ivy)
package org.deeplearning4j.nn.conf.layers.variational;
import lombok.Data;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * CompositeReconstructionDistribution is a reconstruction distribution built from multiple other
 * {@link ReconstructionDistribution} instances.
 * <p>
 * The typical use is to combine, for example, continuous and binary data in the same model, or to combine different
 * distributions for continuous variables. In either case, this class allows users to model (for example) the first
 * 10 values as continuous/Gaussian (with a {@link GaussianReconstructionDistribution}) and the next 10 values as
 * binary/Bernoulli (with a {@link BernoulliReconstructionDistribution}).
 * <p>
 * Every operation walks the component distributions in order. Component {@code i} is handed the next
 * {@code distributionSizes[i]} data columns and the next
 * {@code reconstructionDistributions[i].distributionInputSize(distributionSizes[i])} parameter columns.
 *
 * @author Alex Black
 */
@Data
public class CompositeReconstructionDistribution implements ReconstructionDistribution {

    /** Number of data columns modelled by each component distribution, in order. */
    private final int[] distributionSizes;
    /** Component distributions; entry {@code i} models {@code distributionSizes[i]} data columns. */
    private final ReconstructionDistribution[] reconstructionDistributions;
    /** Total number of data columns; the builder sets this to the sum of {@link #distributionSizes}. */
    private final int totalSize;

    /**
     * Creates an instance directly from its three component values (annotated for JSON property binding).
     * The arguments are stored as given; they are neither copied nor validated.
     *
     * @param distributionSizes           number of data columns for each component distribution
     * @param reconstructionDistributions component distributions, in the same order as the sizes
     * @param totalSize                   total number of data columns
     */
    public CompositeReconstructionDistribution(@JsonProperty("distributionSizes") int[] distributionSizes,
                    @JsonProperty("reconstructionDistributions") ReconstructionDistribution[] reconstructionDistributions,
                    @JsonProperty("totalSize") int totalSize) {
        this.distributionSizes = distributionSizes;
        this.reconstructionDistributions = reconstructionDistributions;
        this.totalSize = totalSize;
    }

    /**
     * Copies the sizes and distributions collected by the builder into arrays and computes the total size.
     */
    private CompositeReconstructionDistribution(Builder builder) {
        distributionSizes = new int[builder.distributionSizes.size()];
        reconstructionDistributions = new ReconstructionDistribution[distributionSizes.length];
        int sizeCount = 0;
        for (int i = 0; i < distributionSizes.length; i++) {
            distributionSizes[i] = builder.distributionSizes.get(i);
            reconstructionDistributions[i] = builder.reconstructionDistributions.get(i);
            sizeCount += distributionSizes[i];
        }
        totalSize = sizeCount;
    }

    /**
     * Returns the view {@code array[:, from .. from + length)}: all rows, and {@code length} columns starting at
     * column {@code from}.
     */
    private static INDArray columnRange(INDArray array, int from, int length) {
        return array.get(NDArrayIndex.all(), NDArrayIndex.interval(from, from + length));
    }

    /**
     * Computes the score array as the sum of the score arrays of every component, each evaluated on its own
     * slice of the data and reconstruction columns.
     *
     * @param data           data array; columns are split between components according to the distribution sizes
     * @param reconstruction reconstruction array; columns are split according to each component's parameter size
     * @return the summed score array, or {@code null} if this composite has no component distributions
     * @throws IllegalStateException         if {@link #hasLossFunction()} is false
     * @throws UnsupportedOperationException if a component is neither a {@link LossFunctionWrapper} nor a
     *                                       CompositeReconstructionDistribution
     */
    public INDArray computeLossFunctionScoreArray(INDArray data, INDArray reconstruction) {
        if (!hasLossFunction()) {
            throw new IllegalStateException("Cannot compute score array unless hasLossFunction() == true");
        }

        //Sum the scores from each loss function...
        int inputSoFar = 0;
        int paramsSoFar = 0;
        INDArray reconstructionScores = null;
        for (int i = 0; i < distributionSizes.length; i++) {
            int thisInputSize = distributionSizes[i];
            int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisInputSize);

            INDArray dataSubset = columnRange(data, inputSoFar, thisInputSize);
            INDArray reconstructionSubset = columnRange(reconstruction, paramsSoFar, thisParamsSize);
            INDArray componentScores = getScoreArray(reconstructionDistributions[i], dataSubset, reconstructionSubset);

            if (i == 0) {
                reconstructionScores = componentScores;
            } else {
                // Accumulate in place into the first component's score array
                reconstructionScores.addi(componentScores);
            }

            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return reconstructionScores;
    }

    /**
     * Computes the score array for a single component: via its loss function when it is a
     * {@link LossFunctionWrapper}, or recursively when it is itself a composite.
     *
     * @throws UnsupportedOperationException for any other kind of distribution
     */
    private INDArray getScoreArray(ReconstructionDistribution reconstructionDistribution, INDArray dataSubset,
                    INDArray reconstructionSubset) {
        if (reconstructionDistribution instanceof LossFunctionWrapper) {
            ILossFunction lossFunction = ((LossFunctionWrapper) reconstructionDistribution).getLossFunction();
            //Re: the activation identity here - the reconstruction array already has the activation function applied,
            // so we don't want to apply it again. i.e., we are passing the output, not the pre-output.
            return lossFunction.computeScoreArray(dataSubset, reconstructionSubset, new ActivationIdentity(), null);
        } else if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
            return ((CompositeReconstructionDistribution) reconstructionDistribution)
                            .computeLossFunctionScoreArray(dataSubset, reconstructionSubset);
        } else {
            throw new UnsupportedOperationException("Cannot calculate composite reconstruction distribution");
        }
    }

    /**
     * @return true only if every component distribution reports {@code hasLossFunction() == true}
     *         (and therefore true when there are no components)
     */
    @Override
    public boolean hasLossFunction() {
        for (ReconstructionDistribution rd : reconstructionDistributions) {
            if (!rd.hasLossFunction()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the total number of distribution parameters: the sum of each component's
     * {@code distributionInputSize} for its own data size.
     *
     * @param dataSize size of the data; must equal the total size of this composite
     * @return the summed parameter size over all components
     * @throws IllegalStateException if {@code dataSize} differs from the total size
     */
    @Override
    public int distributionInputSize(int dataSize) {
        if (dataSize != totalSize) {
            throw new IllegalStateException("Invalid input size: Got input size " + dataSize
                            + " for data, but expected input" + " size for all distributions is " + totalSize
                            + ". Distribution sizes: " + Arrays.toString(distributionSizes));
        }

        int sum = 0;
        for (int i = 0; i < distributionSizes.length; i++) {
            sum += reconstructionDistributions[i].distributionInputSize(distributionSizes[i]);
        }
        return sum;
    }

    /**
     * Sums the negative log probability reported by every component for its slice of {@code x} and of the
     * distribution parameters.
     *
     * @param x                        data array
     * @param preOutDistributionParams distribution parameters array
     * @param average                  passed through unchanged to every component
     * @return the sum of the component values ({@code 0.0} when there are no components)
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        double negLogProbSum = 0.0;
        for (int i = 0; i < distributionSizes.length; i++) {
            int thisInputSize = distributionSizes[i];
            int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisInputSize);

            INDArray inputSubset = columnRange(x, inputSoFar, thisInputSize);
            INDArray paramsSubset = columnRange(preOutDistributionParams, paramsSoFar, thisParamsSize);

            negLogProbSum += reconstructionDistributions[i].negLogProbability(inputSubset, paramsSubset, average);

            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return negLogProbSum;
    }

    /**
     * Sums the arrays returned by every component's {@code exampleNegLogProbability} for its slice of {@code x}
     * and of the distribution parameters.
     *
     * @param x                        data array
     * @param preOutDistributionParams distribution parameters array
     * @return the summed array, or {@code null} if this composite has no component distributions
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        INDArray exampleNegLogProbSum = null;
        for (int i = 0; i < distributionSizes.length; i++) {
            int thisInputSize = distributionSizes[i];
            int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisInputSize);

            INDArray inputSubset = columnRange(x, inputSoFar, thisInputSize);
            INDArray paramsSubset = columnRange(preOutDistributionParams, paramsSoFar, thisParamsSize);
            INDArray componentResult =
                            reconstructionDistributions[i].exampleNegLogProbability(inputSubset, paramsSubset);

            if (i == 0) {
                exampleNegLogProbSum = componentResult;
            } else {
                // Accumulate in place into the first component's result
                exampleNegLogProbSum.addi(componentResult);
            }

            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return exampleNegLogProbSum;
    }

    /**
     * Builds the gradient array by computing each component's gradient on its slice and writing it into the
     * matching parameter columns of the result.
     *
     * @param x                        data array
     * @param preOutDistributionParams distribution parameters array
     * @return a new array with the same shape as {@code preOutDistributionParams}
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        // Uninitialized allocation: only the parameter columns covered by the components are written below
        INDArray gradient = Nd4j.createUninitialized(preOutDistributionParams.shape());
        for (int i = 0; i < distributionSizes.length; i++) {
            int thisInputSize = distributionSizes[i];
            int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisInputSize);

            INDArray inputSubset = columnRange(x, inputSoFar, thisInputSize);
            INDArray paramsSubset = columnRange(preOutDistributionParams, paramsSoFar, thisParamsSize);

            INDArray grad = reconstructionDistributions[i].gradient(inputSubset, paramsSubset);
            gradient.put(new INDArrayIndex[] {NDArrayIndex.all(),
                            NDArrayIndex.interval(paramsSoFar, paramsSoFar + thisParamsSize)}, grad);

            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return gradient;
    }

    /**
     * Assembles an output from each component's {@code generateRandom} applied to its parameter slice.
     *
     * @param preOutDistributionParams distribution parameters array
     * @return array with one row per row of the parameters and a column count equal to the total size
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        return randomSample(preOutDistributionParams, false);
    }

    /**
     * Assembles an output from each component's {@code generateAtMean} applied to its parameter slice.
     *
     * @param preOutDistributionParams distribution parameters array
     * @return array with one row per row of the parameters and a column count equal to the total size
     */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        return randomSample(preOutDistributionParams, true);
    }

    /**
     * Shared implementation of {@link #generateRandom(INDArray)} and {@link #generateAtMean(INDArray)}.
     *
     * @param isMean if true each component's {@code generateAtMean} is used, otherwise its {@code generateRandom}
     */
    private INDArray randomSample(INDArray preOutDistributionParams, boolean isMean) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        // Uninitialized allocation: only the data columns covered by the components are written below
        INDArray out = Nd4j.createUninitialized(new int[] {preOutDistributionParams.size(0), totalSize});
        for (int i = 0; i < distributionSizes.length; i++) {
            int thisDataSize = distributionSizes[i];
            int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisDataSize);

            INDArray paramsSubset = columnRange(preOutDistributionParams, paramsSoFar, thisParamsSize);

            INDArray thisRandomSample;
            if (isMean) {
                thisRandomSample = reconstructionDistributions[i].generateAtMean(paramsSubset);
            } else {
                thisRandomSample = reconstructionDistributions[i].generateRandom(paramsSubset);
            }

            out.put(new INDArrayIndex[] {NDArrayIndex.all(),
                            NDArrayIndex.interval(inputSoFar, inputSoFar + thisDataSize)}, thisRandomSample);

            inputSoFar += thisDataSize;
            paramsSoFar += thisParamsSize;
        }
        return out;
    }

    /**
     * Builder that collects (size, distribution) pairs in order and creates the composite distribution from them.
     */
    public static class Builder {

        // Typed lists: with raw List the element reads in the private constructor yield Object and do not compile
        private final List<Integer> distributionSizes = new ArrayList<>();
        private final List<ReconstructionDistribution> reconstructionDistributions = new ArrayList<>();

        /**
         * Add another distribution to the composite distribution. This will add the distribution for the next
         * 'distributionSize' values, after any previously added.
         * For example, calling addDistribution(10, X) once will result in values 0 to 9 (inclusive) being modelled
         * by the specified distribution X. Calling addDistribution(10, Y) after that will result in values 10 to 19
         * (inclusive) being modelled by distribution Y.
         *
         * @param distributionSize Number of values to model with the specified distribution
         * @param distribution     Distribution to model data with
         * @return this builder
         * @throws IllegalArgumentException if {@code distributionSize} is not positive or {@code distribution}
         *                                  is null
         */
        public Builder addDistribution(int distributionSize, ReconstructionDistribution distribution) {
            // Fail fast here rather than with an obscure error when the composite is first used
            if (distributionSize <= 0) {
                throw new IllegalArgumentException(
                                "Invalid distribution size: must be positive, got " + distributionSize);
            }
            if (distribution == null) {
                throw new IllegalArgumentException("Invalid distribution: cannot be null");
            }
            distributionSizes.add(distributionSize);
            reconstructionDistributions.add(distribution);
            return this;
        }

        /**
         * @return a new composite distribution containing every distribution added so far, in insertion order
         */
        public CompositeReconstructionDistribution build() {
            return new CompositeReconstructionDistribution(this);
        }
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy