// All downloads are FREE. Search and download functionality uses the official Maven repository.

// gov.sandia.cognition.statistics.method.MaximumLikelihoodDistributionEstimator (Maven / Gradle / Ivy)

/*
 * File:                MaximumLikelihoodDistributionEstimator.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Jul 12, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.statistics.method;

import gov.sandia.cognition.algorithm.AbstractParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerDirectionSetPowell;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerNelderMead;
import gov.sandia.cognition.learning.algorithm.minimization.line.LineMinimizerDerivativeFree;
import gov.sandia.cognition.learning.algorithm.minimization.line.interpolator.LineBracketInterpolatorBrent;
import gov.sandia.cognition.learning.function.cost.NegativeLogLikelihood;
import gov.sandia.cognition.learning.function.cost.ParallelNegativeLogLikelihood;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorReader;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.ClosedFormDiscreteUnivariateDistribution;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.EstimableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.SmoothUnivariateDistribution;
import gov.sandia.cognition.statistics.distribution.BetaBinomialDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;

/**
 * Estimates the most-likely distribution, and corresponding parameters, of
 * that generated the given data from a pre-determined collection of
 * candidate parameteric distributions.
 * @param  Type of data generated by the distributions.
 * @author Kevin R. Dixon
 * @since 3.1
 */
public class MaximumLikelihoodDistributionEstimator
    extends AbstractParallelAlgorithm
    implements BatchLearner,ClosedFormComputableDistribution>
{

    /**
     * Collection of Distributions to estimate the optimal parameters of
     */
    private Collection> distributions;

    /** 
     * Creates a new instance of MaximumLikelihoodDistributionEstimator
     */
    public MaximumLikelihoodDistributionEstimator()
    {
        this( null );
    }

    /**
     * Creates a new instance of MaximumLikelihoodDistributionEstimator
     * @param distributions
     * Collection of Distributions to estimate the optimal parameters of
     */
    public MaximumLikelihoodDistributionEstimator(
        Collection> distributions )
    {
        this.setDistributions(distributions);
    }

    @Override
    public MaximumLikelihoodDistributionEstimator clone()
    {
        @SuppressWarnings("unchecked")
        MaximumLikelihoodDistributionEstimator clone =
            (MaximumLikelihoodDistributionEstimator) super.clone();
        clone.setDistributions( ObjectUtil.cloneSmartElementsAsArrayList(
            this.getDistributions() ) );
        return clone;
    }

    /**
     * Getter for distributions
     * @return
     * Collection of Distributions to estimate the optimal parameters of
     */
    public Collection> getDistributions()
    {
        return this.distributions;
    }

    /**
     * Setter for distributions
     * @param distributions
     * Collection of Distributions to estimate the optimal parameters of
     */
    public void setDistributions(
        Collection> distributions)
    {
        this.distributions = distributions;
    }

    @SuppressWarnings("unchecked")
    public ClosedFormComputableDistribution learn(
        Collection data)
    {

        ArrayList> tasks =
            new ArrayList>( this.distributions.size() );
        for( ClosedFormComputableDistribution distribution : this.getDistributions() )
        {
            tasks.add( new DistributionEstimationTask( 
                (ClosedFormComputableDistribution) distribution.clone(), data ) );
        }

        ArrayList>> results;

        try
        {
            results = ParallelUtil.executeInParallel(tasks,this.getThreadPool());
        }
        catch (Exception e)
        {
            System.out.println( "Exception: " + e );
            e.printStackTrace();
            results = null;
        }

        double minCost = Double.POSITIVE_INFINITY;
        ClosedFormComputableDistribution minDistribution = null;
        for( Pair> result : results )
        {
            double cost = result.getFirst();
            if( minCost > cost )
            {
                minCost = cost;
                minDistribution = result.getSecond();
            }
        }

        return minDistribution;
        
    }

    /**
     * Estimates the optimal parameters of a single distribution
     * @param 
     * Type of data emitted by the distribution
     */
    public static class DistributionEstimationTask
        extends AbstractCloneableSerializable
        implements Callable>>
    {
        
        /**
         * Distribution to estimate
         */
        ClosedFormComputableDistribution distribution;

        /**
         * Data to use in the estimation
         */
        Collection data;

        /**
         * Creates a new instance of DistributionEstimationTask
         * @param distribution
         * Distribution to estimate
         * @param data
         * Data to use in the estimation
         */
        public DistributionEstimationTask(
            ClosedFormComputableDistribution distribution,
            Collection data)
        {
            this.distribution = distribution;
            this.data = data;
        }

        @SuppressWarnings("unchecked")
        public Pair> call()
            throws Exception
        {
            try
            {
                ParallelNegativeLogLikelihood costFunction =
                    new ParallelNegativeLogLikelihood(this.data);

//                final int N = this.data.size();
////                final double tolerance = LineBracketInterpolatorBrent.DEFAULT_TOLERANCE / N;
//                final double tolerance = 1e-100;
//                LineBracketInterpolatorBrent brent = new LineBracketInterpolatorBrent();
//                brent.setTolerance(tolerance);
//                brent.getGoldenInterpolator().setTolerance(tolerance);
//                brent.getParabolicInterpolator().setTolerance(tolerance);
//                LineMinimizerDerivativeFree liner = new LineMinimizerDerivativeFree( brent );
//                liner.setTolerance(tolerance);
//                FunctionMinimizerDirectionSetPowell minimizer =
////                    new FunctionMinimizerDirectionSetPowell();
//                    new FunctionMinimizerDirectionSetPowell( liner );
//                minimizer.setTolerance(tolerance);

                // See if the initial parameterization is "in the ballpark"
                ClosedFormComputableDistribution result1 =
                    ObjectUtil.cloneSafe( this.distribution );
                double cost1 = costFunction.evaluate(result1);
                System.out.println( "Initial Cost: " + cost1 + ", Class: " + result1.getClass().getCanonicalName() + ", Parameters: " + result1.convertToVector() );
                
                // The initial parameters don't work, so guess some more
                if( Double.isInfinite(cost1) || Double.isNaN(cost1) )
                {
                    ClosedFormComputableDistribution result2 =
                        ObjectUtil.cloneSafe( this.distribution );

                    boolean bruteForce = true;

                    int Nsub = Math.min( 1000, this.data.size()/1000 );
//                    int Nsub = (int) Math.ceil( this.data.size() / 1000 );
                    Collection subList =
                        CollectionUtil.asArrayList(this.data).subList(0, Nsub);


                    // We've got a closed-form estimator... use that next
                    if( this.distribution instanceof EstimableDistribution )
                    {

                        DistributionEstimator solver =
                            ((EstimableDistribution) this.distribution).getEstimator();

                        try
                        {
                            result2 = (ClosedFormComputableDistribution) solver.learn( this.data );
                            double cost2 = costFunction.evaluate(result2);
                            System.out.println( "Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );
                            bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2));
                        }
                        catch (Exception e)
                        {
                            System.out.println( "Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e );
                            bruteForce = true;
                            result2 = ObjectUtil.cloneSafe(this.distribution);
                        }

                        if( bruteForce )
                        {
                            try
                            {
                                result2 = (ClosedFormComputableDistribution) solver.learn( subList );
                                double cost2 = costFunction.evaluate(result2);
                                System.out.println( "Sub-Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );
                                bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2));
                            }
                            catch (Exception e)
                            {
                                System.out.println( "Sub-Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e );
                                result2 = ObjectUtil.cloneSafe(this.distribution);
                            }

                        }

                    }

                    // Nothing has worked so far, Use Nelder-Mead, which is
                    // slow but is less susceptible to numerical imprecision
                    if( bruteForce )
                    {
                        FunctionMinimizerNelderMead minimizer1 =
                            new FunctionMinimizerNelderMead();
                        minimizer1.setMaxIterations(10);
                        minimizer1.setTolerance(1.0);

                        DistributionParameterEstimator> estimator2 =
                            new DistributionParameterEstimator>(
                                ObjectUtil.cloneSafe(result2), costFunction, minimizer1 );
                        result2 = estimator2.learn(this.data);
                        double cost2 = costFunction.evaluate(result2);
                        System.out.println( "Brute Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );

                        // Damn.. nothing has worked so far... subsample the
                        // data and re-estimate.
                        if( Double.isInfinite(cost2) || Double.isNaN(cost2) )
                        {
                            minimizer1.setMaxIterations(1000);
                            costFunction.setCostParameters(subList);

                            estimator2 = new DistributionParameterEstimator>(
                                ObjectUtil.cloneSafe(result2), costFunction, minimizer1 );
                            result2 = estimator2.learn(subList);

                            costFunction.setCostParameters(this.data);
                            double cost3 = costFunction.evaluate(result2);
                            System.out.println( "Subsample Cost: " + cost3 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );


                        }

                    }

                    result1 = result2;

                }

                FunctionMinimizerDirectionSetPowell minimizer3 =
                    new FunctionMinimizerDirectionSetPowell();
                DistributionParameterEstimator> estimator3 =
                    new DistributionParameterEstimator>(
                        ObjectUtil.cloneSafe(result1), costFunction, minimizer3 );
                ClosedFormComputableDistribution result3 =
                    estimator3.learn(this.data);



                double cost3 = costFunction.evaluate(result3);
                System.out.println( "Final Cost: " + cost3 + ", Class: " + result3.getClass().getCanonicalName() + ", Parameters: " + result3.convertToVector() );
                return DefaultPair.create( cost3, result3 );
            }
            catch (Exception e)
            {
                System.out.println( this.distribution.getClass().getCanonicalName() + " barfed: " + e );
//                e.printStackTrace();
                return DefaultPair.create( Double.POSITIVE_INFINITY, (ClosedFormComputableDistribution) this.distribution.clone() );
            }
        }

    }

    /**
     * Estimates a continuous distribution.
     *
     * @param data
     *      The data to estimate a distribution for.
     * @return
     *      The estimated distribution.
     * @throws Exception
     *      If there is an error in the estimation.
     */
    public static SmoothUnivariateDistribution estimateContinuousDistribution(
        Collection data )
        throws Exception
    {
        LinkedList distributions =
            getDistributionClasses( SmoothUnivariateDistribution.class );
        MaximumLikelihoodDistributionEstimator estimator =
            new MaximumLikelihoodDistributionEstimator( distributions );
        return (SmoothUnivariateDistribution) estimator.learn(data);
    }

    /**
     * Estimates a discrete distribution.
     *
     * @param data
     *      The data to estimate a distribution for.
     * @return
     *      The estimated distribution.
     * @throws Exception
     *      If there is an error in the estimation.
     */
    @SuppressWarnings("unchecked")
    public static ClosedFormDiscreteUnivariateDistribution estimateDiscreteDistribution(
        Collection data )
        throws Exception
    {

        LinkedList distributions =
            getDistributionClasses( ClosedFormDiscreteUnivariateDistribution.class );
        MaximumLikelihoodDistributionEstimator estimator =
            new MaximumLikelihoodDistributionEstimator(
            (Collection) distributions);
        return (ClosedFormDiscreteUnivariateDistribution) estimator.learn(data);


    }


    /**
     * Gets the distribution classes for the given base distribution.
     * 
     * @param   
     *      The type of distribution.
     * @param baseDistribution
     *      The class of the base distribution.
     * @return
     *      The list of implementations of that distribution in the statistics
     *      distribution package.
     * @throws ClassNotFoundException
     * @throws IOException
     * @throws InstantiationException
     * @throws IllegalAccessException
     */
    @SuppressWarnings("unchecked")
    protected static > LinkedList getDistributionClasses(
        Class baseDistribution )
        throws ClassNotFoundException, IOException, InstantiationException, IllegalAccessException
    {

        UnivariateGaussian g = new UnivariateGaussian();
        Package p = g.getClass().getPackage();
        LinkedList> cs = getClasses( p.getName() );
        LinkedList instances =
            new LinkedList();
        for( Class c : cs )
        {
            if( baseDistribution.isAssignableFrom( c ) )
            {
                if( ProbabilityFunction.class.isAssignableFrom(c) )
                {
                    try
                    {
                        instances.add( (DistributionType) c.newInstance());
                    }
                    catch (Exception e)
                    {
                        System.out.println( "Couldn't instantiate: " + c.getCanonicalName() );
                    }
                }
            }


        }

        return instances;

    }

    /**
     * Scans all classes accessible from the context class loader which belong to the given package and subpackages.
     *
     * @param packageName The base package
     * @return The classes
     * @throws ClassNotFoundException
     * @throws IOException
     */
    private static LinkedList> getClasses(
        String packageName)
        throws ClassNotFoundException, IOException
    {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        assert classLoader != null;
        String path = packageName.replace('.', '/');
        Enumeration resources = classLoader.getResources(path);
        List dirs = new ArrayList();
        while (resources.hasMoreElements())
        {
            URL resource = resources.nextElement();
            dirs.add(new File(resource.getFile()));
        }
        LinkedList> classes = new LinkedList>();
        for (File directory : dirs)
        {
            classes.addAll(findClasses(directory, packageName));
        }
        return classes;
    }

    /**
     * Recursive method used to find all classes in a given directory and subdirs.
     *
     * @param directory   The base directory
     * @param packageName The package name for classes found inside the base directory
     * @return The classes
     * @throws ClassNotFoundException
     */
    private static LinkedList> findClasses(
        File directory,
        String packageName)
        throws ClassNotFoundException
    {
        LinkedList> classes = new LinkedList>();
        File[] files = directory.listFiles();
        for (File file : files)
        {
            if (file.getName().endsWith(".class"))
            {
                classes.add(Class.forName(packageName + '.' + file.getName().substring(0, file.getName().length() - 6)));
            }
        }
        return classes;
    }
    
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy