
org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper

package org.deeplearning4j.nn.conf.layers.variational;

import lombok.Data;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * LossFunctionWrapper allows training of a VAE model with a standard (possibly deterministic) neural network loss function
 * for the reconstruction, instead of using a {@link ReconstructionDistribution} as would normally be done with a VAE model.
 *
 * Note: most functionality is supported, but clearly reconstruction log probability cannot be calculated when using
 * LossFunctionWrapper, as ILossFunction instances do not have either (a) a probabilistic interpretation, or (b) a
 * means of calculating the negative log probability.
 *
 * @author Alex Black
 */
@Data
public class LossFunctionWrapper implements ReconstructionDistribution {

    private final IActivation activationFn;
    private final ILossFunction lossFunction;

    public LossFunctionWrapper(@JsonProperty("activationFn") IActivation activationFn,
                               @JsonProperty("lossFunction") ILossFunction lossFunction) {
        this.activationFn = activationFn;
        this.lossFunction = lossFunction;
    }

    public LossFunctionWrapper(Activation activation, ILossFunction lossFunction) {
        this(activation.getActivationFunction(), lossFunction);
    }

    @Override
    public boolean hasLossFunction() {
        return true;
    }

    @Override
    public int distributionInputSize(int dataSize) {
        return dataSize;
    }

    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        //NOTE: The returned value here is NOT negative log probability, but it (the loss function value)
        // is equivalent, in terms of being something we want to minimize...
        return lossFunction.computeScore(x, preOutDistributionParams, activationFn, null, average);
    }

    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        return lossFunction.computeScoreArray(x, preOutDistributionParams, activationFn, null);
    }

    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        return lossFunction.computeGradient(x, preOutDistributionParams, activationFn, null);
    }

    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        //Loss functions: not probabilistic -> deterministic output
        return generateAtMean(preOutDistributionParams);
    }

    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        //Loss functions: not probabilistic -> not random
        INDArray out = preOutDistributionParams.dup();
        return activationFn.getActivation(out, true);
    }

    @Override
    public String toString() {
        return "LossFunctionWrapper(afn=" + activationFn + "," + lossFunction + ")";
    }
}
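
Usage note: the sketch below shows one way this wrapper is typically plugged into a VAE configuration. It assumes DL4J's VariationalAutoencoder layer builder with its reconstructionDistribution(...) setter and ND4J's LossMSE implementation; the layer sizes and the sigmoid/MSE pairing are illustrative choices, not requirements of the class above.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;

public class LossFunctionWrapperExample {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(0, new VariationalAutoencoder.Builder()
                        .nIn(784)                 //Input size: illustrative (e.g., a flattened 28x28 image)
                        .nOut(32)                 //Size of the latent space: illustrative
                        .encoderLayerSizes(256)   //One encoder hidden layer
                        .decoderLayerSizes(256)   //One decoder hidden layer
                        //Reconstruct with a plain deterministic loss (sigmoid + MSE)
                        // instead of a probabilistic ReconstructionDistribution:
                        .reconstructionDistribution(
                                new LossFunctionWrapper(Activation.SIGMOID, new LossMSE()))
                        .build())
                .build();

        System.out.println(conf.toJson());
    }
}

Because the wrapped loss function is deterministic, sampling-based generation (generateRandom) simply falls back to generateAtMean, and reconstruction log probability is unavailable, as the Javadoc above notes.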