All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.nn.conf.layers.variational;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * CompositeReconstructionDistribution is a reconstruction distribution built from multiple other ReconstructionDistribution
 * instances.
* The typical use is to combine for example continuous and binary data in the same model, or to combine different * distributions for continuous variables. In either case, this class allows users to model (for example) the first 10 values * as continuous/Gaussian (with a {@link GaussianReconstructionDistribution}, the next 10 values as binary/Bernoulli (with * a {@link BernoulliReconstructionDistribution}) * * @author Alex Black */ @Data public class CompositeReconstructionDistribution implements ReconstructionDistribution { private final int[] distributionSizes; private final ReconstructionDistribution[] reconstructionDistributions; private final int totalSize; public CompositeReconstructionDistribution(@JsonProperty("distributionSizes") int[] distributionSizes, @JsonProperty("reconstructionDistributions") ReconstructionDistribution[] reconstructionDistributions, @JsonProperty("totalSize") int totalSize){ this.distributionSizes = distributionSizes; this.reconstructionDistributions = reconstructionDistributions; this.totalSize = totalSize; } private CompositeReconstructionDistribution(Builder builder){ distributionSizes = new int[builder.distributionSizes.size()]; reconstructionDistributions = new ReconstructionDistribution[distributionSizes.length]; int sizeCount = 0; for( int i=0; i< distributionSizes.length; i++ ){ distributionSizes[i] = builder.distributionSizes.get(i); reconstructionDistributions[i] = builder.reconstructionDistributions.get(i); sizeCount += distributionSizes[i]; } totalSize = sizeCount; } @Override public int distributionInputSize(int dataSize) { if(dataSize != totalSize){ throw new IllegalStateException("Invalid input size: Got input size " + dataSize + " for data, but expected input" + " size for all distributions is " + totalSize + ". Distribution sizes: " + Arrays.toString(distributionSizes)); } int sum = 0; for( int i=0; i distributionSizes = new ArrayList<>(); private List reconstructionDistributions = new ArrayList<>(); /** * Add another distribution to the composite distribution. This will add the distribution for the next 'distributionSize' * values, after any previously added. * For example, calling addDistribution(10, X) once will result in values 0 to 9 (inclusive) being modelled * by the specified distribution X. Calling addDistribution(10, Y) after that will result in values 10 to 19 (inclusive) * being modelled by distribution Y. * * @param distributionSize Number of values to model with the specified distribution * @param distribution Distribution to model data with */ public Builder addDistribution(int distributionSize, ReconstructionDistribution distribution){ distributionSizes.add(distributionSize); reconstructionDistributions.add(distribution); return this; } public CompositeReconstructionDistribution build(){ return new CompositeReconstructionDistribution(this); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy