All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.statistics.KullbackLeiblerDivergence Maven / Gradle / Ivy

/*
 * File:                KullbackLeiblerDivergence.java
 * Authors:             Tom Brounstein
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright March 31, 2014, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.statistics;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import java.util.Collections;
import java.util.Set;

/**
 * A class used for measuring how similar two distributions are using Kullback--Leibler Divergence.
 * <p>
 * With P the first distribution and Q the second, {@link #compute()} returns
 * sum over x in the domain of P of P(x) * ln(P(x) / Q(x)), using the natural
 * logarithm. Both distributions are cloned on construction, so later changes
 * to the originals do not affect this object.
 *
 * @author trbroun
 * @since  3.4.2
 * @param <DomainType> The type for the domain of the two distributions which this class is comparing.
 */
@PublicationReference(
    author="Wikipedia",
    title="Kullback--Leibler Divergence",
    type=PublicationType.WebPage,
    year=2014,
    url="http://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence"
)
public class KullbackLeiblerDivergence<DomainType>
{

    /** Private clone of the first distribution (P). */
    private final DiscreteDistribution<DomainType> firstDistribution;

    /** Private clone of the second distribution (Q). */
    private final DiscreteDistribution<DomainType> secondDistribution;

    /**
     * Basic constructor to find the Kullback--Leibler Divergence between the two supplied distributions.
     * The two distributions must be over the same domain: every element with a non-zero
     * probability in the first distribution must be in the domain of the second
     * distribution with a non-zero probability there as well.
     *
     * @param firstDistribution
     *      The first distribution (P). Cloned before being stored.
     * @param secondDistribution
     *      The second distribution (Q). Cloned before being stored.
     * @throws IllegalArgumentException
     *      If either distribution is null, or if some element has a non-zero
     *      probability in the first distribution but is missing from, or has
     *      zero probability in, the second distribution.
     */
    @SuppressWarnings("unchecked")
    public KullbackLeiblerDivergence(DiscreteDistribution<DomainType> firstDistribution,
        DiscreteDistribution<DomainType> secondDistribution)
    {
        if (firstDistribution == null) {
            throw new IllegalArgumentException("First distribution is null.");
        }
        if (secondDistribution == null) {
            throw new IllegalArgumentException("Second distribution is null.");
        }

        // These lookups do not depend on the element being checked, so fetch
        // them once instead of on every iteration.
        final ProbabilityMassFunction<DomainType> pmfP =
            firstDistribution.getProbabilityFunction();
        final ProbabilityMassFunction<DomainType> pmfQ =
            secondDistribution.getProbabilityFunction();
        final Set<? extends DomainType> domainQ = secondDistribution.getDomain();

        for (DomainType term : firstDistribution.getDomain()) {
            final double p = pmfP.evaluate(term);
            // A term with P(x) != 0 but Q(x) == 0 would make the divergence
            // unbounded, so it is rejected here rather than in compute().
            if (p != 0 && (!domainQ.contains(term) || pmfQ.evaluate(term) == 0)) {
                throw new IllegalArgumentException("Domain mismatch; a non-zero probability in first distribution requires a non-zero probability in second distribution");
            }
        }

        // The casts are unchecked because clone() is not declared with the
        // generic distribution type.
        this.firstDistribution = (DiscreteDistribution<DomainType>) firstDistribution.clone();
        this.secondDistribution = (DiscreteDistribution<DomainType>) secondDistribution.clone();
    }

    /**
     * Gets the domain of the distributions.
     * <p>
     * The returned set is a read-only view of the second distribution's domain.
     *
     * @return The domain of the distributions.
     */
    public Set<? extends DomainType> getDomain()
    {
        return Collections.unmodifiableSet(secondDistribution.getDomain());
    }

    /**
     * Computes the Kullback--Leibler Divergence.
     * <p>
     * The sum runs over the domain of the first distribution and uses the
     * natural logarithm, so the result is in nats.
     *
     * @return The divergence value.
     */
    public double compute()
    {
        final ProbabilityMassFunction<DomainType> pmfP =
            firstDistribution.getProbabilityFunction();
        final ProbabilityMassFunction<DomainType> pmfQ =
            secondDistribution.getProbabilityFunction();

        double sum = 0.0;
        for (DomainType element : firstDistribution.getDomain())
        {
            final double p = pmfP.evaluate(element);
            final double q = pmfQ.evaluate(element);
            // P(x) == 0 contributes nothing (convention 0 * ln(0 / q) = 0).
            // Q(x) == 0 with P(x) != 0 is rejected by the constructor; the
            // guard keeps the division and the log well-defined regardless.
            if (p == 0 || q == 0) {
                continue;
            }
            sum += p * Math.log(p / q);
        }

        return sum;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy