gov.sandia.cognition.statistics.KullbackLeiblerDivergence Maven / Gradle / Ivy
/*
* File: KullbackLeiblerDivergence.java
* Authors: Tom Brounstein
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright March 31, 2014, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import java.util.Collections;
import java.util.Set;
/**
* A class used for measuring how similar two distributions are using Kullback--Leibler Divergence.
* @author trbroun
* @since 3.4.2
* @param The type for the domain of the two distributions which this class is comparing.
*/
@PublicationReference(
author="Wikipedia",
title="Kullback--Leibler Divergence",
type=PublicationType.WebPage,
year=2014,
url="http://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence"
)
public class KullbackLeiblerDivergence
{
private final DiscreteDistribution firstDistribution;
private final DiscreteDistribution secondDistribution;
/**
* Basic constructor to find the Kullback--Leibler Divergence between the two supplied distributions.
* The two distributions must be over the same domain.
* @param firstDistribution
* @param secondDistribution
*/
@SuppressWarnings("unchecked")
public KullbackLeiblerDivergence(DiscreteDistribution firstDistribution,
DiscreteDistribution secondDistribution)
{
if (firstDistribution == null) {
throw new IllegalArgumentException("First distribution is null.");
}
if (secondDistribution == null) {
throw new IllegalArgumentException("Second distribution is null.");
}
for (DomainType term : firstDistribution.getDomain()) {
double temp = firstDistribution.getProbabilityFunction().evaluate(term);
if (temp != 0 && (!secondDistribution.getDomain().contains(term) ||
secondDistribution.getProbabilityFunction().evaluate(term)==0)) {
throw new IllegalArgumentException("Domain mismatch; a non-zero probability in first distribution requires a non-zero probability in second distribution");
}
}
this.firstDistribution = (DiscreteDistribution) firstDistribution.clone();
this.secondDistribution = (DiscreteDistribution) secondDistribution.clone();
}
/**
* Gets the domain of the distributions.
* @return The domain of the distributions.
*/
public Set extends DomainType> getDomain()
{
return Collections.unmodifiableSet(secondDistribution.getDomain());
}
/**
* Computes the Kullback--Leibler Divergence.
* @return The divergence value.
*/
public double compute()
{
double sum = 0.0;
Set extends DomainType> domain = firstDistribution.getDomain();
ProbabilityMassFunction pmfP =
firstDistribution.getProbabilityFunction();
ProbabilityMassFunction pmfQ =
secondDistribution.getProbabilityFunction();
for (DomainType element : domain)
{
double PTerm = pmfP.evaluate(element);
double QTerm = pmfQ.evaluate(element);
if (QTerm == 0 || PTerm == 0) {
continue;
}
double temp = PTerm/QTerm;
temp = Math.log(temp);
double termSolution = temp*PTerm;
sum += termSolution;
}
return sum;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy