![JAR search and dependency download from the Maven repository](/logo.png)
gov.sandia.cognition.statistics.method.MaximumLikelihoodDistributionEstimator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: MaximumLikelihoodDistributionEstimator.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 12, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.method;
import gov.sandia.cognition.algorithm.AbstractParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerDirectionSetPowell;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerNelderMead;
import gov.sandia.cognition.learning.function.cost.ParallelNegativeLogLikelihood;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.ClosedFormDiscreteUnivariateDistribution;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.EstimableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.SmoothUnivariateDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
/**
* Estimates the most-likely distribution, and its corresponding parameters,
* that generated the given data, chosen from a pre-determined collection of
* candidate parametric distributions.
* @param <DataType> Type of data generated by the distributions.
* @author Kevin R. Dixon
* @since 3.1
*/
public class MaximumLikelihoodDistributionEstimator
extends AbstractParallelAlgorithm
implements BatchLearner,ClosedFormComputableDistribution>
{
/**
* Collection of Distributions to estimate the optimal parameters of
*/
private Collection extends ClosedFormComputableDistribution> distributions;
/**
 * Creates a new instance of MaximumLikelihoodDistributionEstimator with no
 * candidate distributions: the collection is passed on as null, so
 * candidates must be supplied through setDistributions before learn is
 * called (learn dereferences the collection).
 */
public MaximumLikelihoodDistributionEstimator()
{
this( null );
}
/**
 * Creates a new instance of MaximumLikelihoodDistributionEstimator
 * @param distributions
 * Collection of candidate distributions to estimate the optimal parameters
 * of. Stored by reference (no copy is made); may be null.
 */
public MaximumLikelihoodDistributionEstimator(
    Collection<? extends ClosedFormComputableDistribution<DataType>> distributions )
{
    this.setDistributions( distributions );
}
/**
 * Clones this estimator. The candidate collection of the clone is rebuilt
 * by {@code ObjectUtil.cloneSmartElementsAsArrayList} from this estimator's
 * candidates rather than shared by reference.
 * @return
 * Clone of this estimator
 */
@Override
public MaximumLikelihoodDistributionEstimator<DataType> clone()
{
    @SuppressWarnings("unchecked")
    MaximumLikelihoodDistributionEstimator<DataType> clone =
        (MaximumLikelihoodDistributionEstimator<DataType>) super.clone();
    clone.setDistributions( ObjectUtil.cloneSmartElementsAsArrayList(
        this.getDistributions() ) );
    return clone;
}
/**
* Getter for distributions
* @return
* Collection of Distributions to estimate the optimal parameters of
*/
public Collection extends ClosedFormComputableDistribution> getDistributions()
{
return this.distributions;
}
/**
* Setter for distributions
* @param distributions
* Collection of Distributions to estimate the optimal parameters of
*/
public void setDistributions(
Collection extends ClosedFormComputableDistribution> distributions)
{
this.distributions = distributions;
}
@SuppressWarnings("unchecked")
public ClosedFormComputableDistribution learn(
Collection extends DataType> data)
{
ArrayList> tasks =
new ArrayList>( this.distributions.size() );
for( ClosedFormComputableDistribution distribution : this.getDistributions() )
{
tasks.add( new DistributionEstimationTask(
(ClosedFormComputableDistribution) distribution.clone(), data ) );
}
ArrayList>> results;
try
{
results = ParallelUtil.executeInParallel(tasks,this.getThreadPool());
}
catch (Exception e)
{
throw new RuntimeException(e);
}
double minCost = Double.POSITIVE_INFINITY;
ClosedFormComputableDistribution minDistribution = null;
for( Pair> result : results )
{
double cost = result.getFirst();
if( minCost > cost )
{
minCost = cost;
minDistribution = result.getSecond();
}
}
return minDistribution;
}
/**
* Estimates the optimal parameters of a single distribution
* @param
* Type of data emitted by the distribution
*/
public static class DistributionEstimationTask
extends AbstractCloneableSerializable
implements Callable>>
{
/**
* Distribution to estimate
*/
ClosedFormComputableDistribution distribution;
/**
* Data to use in the estimation
*/
Collection extends DataType> data;
/**
* Creates a new instance of DistributionEstimationTask
* @param distribution
* Distribution to estimate
* @param data
* Data to use in the estimation
*/
public DistributionEstimationTask(
ClosedFormComputableDistribution distribution,
Collection extends DataType> data)
{
this.distribution = distribution;
this.data = data;
}
@SuppressWarnings("unchecked")
public Pair> call()
throws Exception
{
try
{
ParallelNegativeLogLikelihood costFunction =
new ParallelNegativeLogLikelihood(this.data);
// final int N = this.data.size();
//// final double tolerance = LineBracketInterpolatorBrent.DEFAULT_TOLERANCE / N;
// final double tolerance = 1e-100;
// LineBracketInterpolatorBrent brent = new LineBracketInterpolatorBrent();
// brent.setTolerance(tolerance);
// brent.getGoldenInterpolator().setTolerance(tolerance);
// brent.getParabolicInterpolator().setTolerance(tolerance);
// LineMinimizerDerivativeFree liner = new LineMinimizerDerivativeFree( brent );
// liner.setTolerance(tolerance);
// FunctionMinimizerDirectionSetPowell minimizer =
//// new FunctionMinimizerDirectionSetPowell();
// new FunctionMinimizerDirectionSetPowell( liner );
// minimizer.setTolerance(tolerance);
// See if the initial parameterization is "in the ballpark"
ClosedFormComputableDistribution result1 =
ObjectUtil.cloneSafe( this.distribution );
double cost1 = costFunction.evaluate(result1);
// System.out.println( "Initial Cost: " + cost1 + ", Class: " + result1.getClass().getCanonicalName() + ", Parameters: " + result1.convertToVector() );
// The initial parameters don't work, so guess some more
if( Double.isInfinite(cost1) || Double.isNaN(cost1) )
{
ClosedFormComputableDistribution result2 =
ObjectUtil.cloneSafe( this.distribution );
boolean bruteForce = true;
int Nsub = Math.min( 1000, this.data.size()/1000 );
// int Nsub = (int) Math.ceil( this.data.size() / 1000 );
Collection extends DataType> subList =
CollectionUtil.asArrayList(this.data).subList(0, Nsub);
// We've got a closed-form estimator... use that next
if( this.distribution instanceof EstimableDistribution )
{
DistributionEstimator> solver =
((EstimableDistribution) this.distribution).getEstimator();
try
{
result2 = solver.learn( this.data );
double cost2 = costFunction.evaluate(result2);
// System.out.println( "Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );
bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2));
}
catch (Exception e)
{
// System.out.println( "Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e );
bruteForce = true;
result2 = ObjectUtil.cloneSafe(this.distribution);
}
if( bruteForce )
{
try
{
result2 = solver.learn( subList );
double cost2 = costFunction.evaluate(result2);
// System.out.println( "Sub-Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );
bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2));
}
catch (Exception e)
{
// System.out.println( "Sub-Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e );
result2 = ObjectUtil.cloneSafe(this.distribution);
}
}
}
// Nothing has worked so far, Use Nelder-Mead, which is
// slow but is less susceptible to numerical imprecision
if( bruteForce )
{
FunctionMinimizerNelderMead minimizer1 =
new FunctionMinimizerNelderMead();
minimizer1.setMaxIterations(10);
minimizer1.setTolerance(1.0);
DistributionParameterEstimator> estimator2 =
new DistributionParameterEstimator>(
ObjectUtil.cloneSafe(result2), costFunction, minimizer1 );
result2 = estimator2.learn(this.data);
double cost2 = costFunction.evaluate(result2);
// System.out.println( "Brute Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );
// Damn.. nothing has worked so far... subsample the
// data and re-estimate.
if( Double.isInfinite(cost2) || Double.isNaN(cost2) )
{
minimizer1.setMaxIterations(1000);
costFunction.setCostParameters(subList);
estimator2 = new DistributionParameterEstimator>(
ObjectUtil.cloneSafe(result2), costFunction, minimizer1 );
result2 = estimator2.learn(subList);
costFunction.setCostParameters(this.data);
double cost3 = costFunction.evaluate(result2);
// System.out.println( "Subsample Cost: " + cost3 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() );
}
}
result1 = result2;
}
FunctionMinimizerDirectionSetPowell minimizer3 =
new FunctionMinimizerDirectionSetPowell();
DistributionParameterEstimator> estimator3 =
new DistributionParameterEstimator>(
ObjectUtil.cloneSafe(result1), costFunction, minimizer3 );
ClosedFormComputableDistribution result3 =
estimator3.learn(this.data);
double cost3 = costFunction.evaluate(result3);
// System.out.println( "Final Cost: " + cost3 + ", Class: " + result3.getClass().getCanonicalName() + ", Parameters: " + result3.convertToVector() );
return DefaultPair.create( cost3, result3 );
}
catch (Exception e)
{
// System.out.println( this.distribution.getClass().getCanonicalName() + " barfed: " + e );
// e.printStackTrace();
return DefaultPair.create( Double.POSITIVE_INFINITY, (ClosedFormComputableDistribution) this.distribution.clone() );
}
}
}
/**
 * Estimates a continuous distribution by fitting every candidate returned
 * by {@code getDistributionClasses} for
 * {@code SmoothUnivariateDistribution} and keeping the best fit.
 *
 * @param data
 * The data to estimate a distribution for.
 * @return
 * The estimated distribution; null if the estimator selects no candidate.
 * @throws Exception
 * If there is an error in the estimation.
 */
public static SmoothUnivariateDistribution estimateContinuousDistribution(
    Collection<Double> data )
    throws Exception
{

    LinkedList<SmoothUnivariateDistribution> distributions =
        getDistributionClasses( SmoothUnivariateDistribution.class );

    MaximumLikelihoodDistributionEstimator<Double> estimator =
        new MaximumLikelihoodDistributionEstimator<Double>( distributions );

    return (SmoothUnivariateDistribution) estimator.learn( data );

}
/**
* Estimates a discrete distribution.
*
* @param data
* The data to estimate a distribution for.
* @return
* The estimated distribution.
* @throws Exception
* If there is an error in the estimation.
*/
@SuppressWarnings(value={"unchecked", "rawtypes"})
public static ClosedFormDiscreteUnivariateDistribution estimateDiscreteDistribution(
Collection extends Number> data )
throws Exception
{
LinkedList distributions =
getDistributionClasses( ClosedFormDiscreteUnivariateDistribution.class );
MaximumLikelihoodDistributionEstimator estimator =
new MaximumLikelihoodDistributionEstimator(
(Collection extends ClosedFormComputableDistribution>) distributions);
return (ClosedFormDiscreteUnivariateDistribution) estimator.learn(data);
}
/**
* Gets the distribution classes for the given base distribution.
*
* @param
* The type of distribution.
* @param baseDistribution
* The class of the base distribution.
* @return
* The list of implementations of that distribution in the statistics
* distribution package.
* @throws ClassNotFoundException
* @throws IOException
* @throws InstantiationException
* @throws IllegalAccessException
*/
@SuppressWarnings("unchecked")
protected static > LinkedList getDistributionClasses(
Class extends DistributionType> baseDistribution )
throws ClassNotFoundException, IOException, InstantiationException, IllegalAccessException
{
UnivariateGaussian g = new UnivariateGaussian();
Package p = g.getClass().getPackage();
LinkedList> cs = getClasses( p.getName() );
LinkedList instances =
new LinkedList();
for( Class> c : cs )
{
if( baseDistribution.isAssignableFrom( c ) )
{
if( ProbabilityFunction.class.isAssignableFrom(c) )
{
try
{
instances.add( (DistributionType) c.newInstance());
}
catch (Exception e)
{
// System.out.println( "Couldn't instantiate: " + c.getCanonicalName() );
}
}
}
}
return instances;
}
/**
* Scans all classes accessible from the context class loader which belong to the given package and subpackages.
*
* @param packageName The base package
* @return The classes
* @throws ClassNotFoundException
* @throws IOException
*/
private static LinkedList> getClasses(
String packageName)
throws ClassNotFoundException, IOException
{
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
assert classLoader != null;
String path = packageName.replace('.', '/');
Enumeration resources = classLoader.getResources(path);
List dirs = new ArrayList();
while (resources.hasMoreElements())
{
URL resource = resources.nextElement();
dirs.add(new File(resource.getFile()));
}
LinkedList> classes = new LinkedList>();
for (File directory : dirs)
{
classes.addAll(findClasses(directory, packageName));
}
return classes;
}
/**
* Recursive method used to find all classes in a given directory and subdirs.
*
* @param directory The base directory
* @param packageName The package name for classes found inside the base directory
* @return The classes
* @throws ClassNotFoundException
*/
private static LinkedList> findClasses(
File directory,
String packageName)
throws ClassNotFoundException
{
LinkedList> classes = new LinkedList>();
File[] files = directory.listFiles();
for (File file : files)
{
if (file.getName().endsWith(".class"))
{
classes.add(Class.forName(packageName + '.' + file.getName().substring(0, file.getName().length() - 6)));
}
}
return classes;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy