gov.sandia.cognition.statistics.method.DistributionParameterEstimator Maven / Gradle / Ivy
A single jar with all the Cognitive Foundry components.
/*
* File: DistributionParameterEstimator.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 8, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.method;
import gov.sandia.cognition.algorithm.AnytimeAlgorithmWrapper;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizer;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerBFGS;
import gov.sandia.cognition.learning.function.cost.CostFunction;
import gov.sandia.cognition.math.DifferentiableEvaluator;
import gov.sandia.cognition.math.matrix.NumericalDifferentiator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ClosedFormDistribution;
import gov.sandia.cognition.statistics.method.DistributionParameterEstimator.DistributionWrapper;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;
/**
* A method of estimating the parameters of a distribution using an arbitrary
* CostFunction and FunctionMinimizer algorithm.
* @param <DataType>
* Type of data generated by the distribution.
* @param <DistributionType>
* Type of distribution to estimate the parameters of.
* @author Kevin R. Dixon
* @since 3.1
*/
public class DistributionParameterEstimator<DataType,DistributionType extends ClosedFormDistribution<DataType>>
extends AnytimeAlgorithmWrapper<DistributionType,FunctionMinimizer<Vector,Double,? super DistributionParameterEstimator<DataType,DistributionType>.DistributionWrapper>>
implements BatchLearner<Collection<? extends DataType>,DistributionType>,
MeasurablePerformanceAlgorithm
{
/**
* Function that maps a Distribution onto a Vector/Scalar function.
*/
private DistributionWrapper distributionWrapper;
/**
* Distribution that minimizes the cost function.
*/
private DistributionType result;
/**
* Creates a new instance of DistributionParameterEstimator
* @param distribution
* Distribution to estimate the parameters of
* @param costFunction
* Cost function to use in the minimization procedure
*/
public DistributionParameterEstimator(
DistributionType distribution,
CostFunction<? super DistributionType,Collection<? extends DataType>> costFunction )
{
this( distribution, costFunction, new FunctionMinimizerBFGS() );
}
/**
* Creates a new instance of DistributionParameterEstimator
* @param distribution
* Distribution to estimate the parameters of
* @param costFunction
* Cost function to use in the minimization procedure
* @param algorithm
* Minimization algorithm to use, such as FunctionMinimizerBFGS,
* FunctionMinimizerDirectionSetPowell, etc.
*/
public DistributionParameterEstimator(
DistributionType distribution,
CostFunction<? super DistributionType,Collection<? extends DataType>> costFunction,
FunctionMinimizer<Vector,Double,? super DistributionParameterEstimator<DataType,DistributionType>.DistributionWrapper> algorithm )
{
super( algorithm );
this.distributionWrapper =
new DistributionWrapper( distribution, costFunction );
}
@Override
public DistributionParameterEstimator<DataType,DistributionType> clone()
{
@SuppressWarnings("unchecked")
DistributionParameterEstimator<DataType,DistributionType> clone =
(DistributionParameterEstimator<DataType,DistributionType>) super.clone();
clone.distributionWrapper = ObjectUtil.cloneSafe( this.distributionWrapper );
clone.result = ObjectUtil.cloneSafe( this.getResult() );
return clone;
}
public DistributionType learn(
Collection<? extends DataType> minimizationParameters)
{
// Work on a clone so this estimator's initial distribution is preserved.
DistributionWrapper wrapperClone = this.distributionWrapper.clone();
wrapperClone.costFunction.setCostParameters(minimizationParameters);
// Seed the minimizer with the current parameter vector, minimize the cost
// over that vector, then read the fitted parameters back out.
this.getAlgorithm().setInitialGuess( wrapperClone.distribution.convertToVector() );
this.getAlgorithm().learn( wrapperClone );
this.result = wrapperClone.distribution;
return this.getResult();
}
public DistributionType getResult()
{
return this.result;
}
public NamedValue<? extends Number> getPerformance()
{
// Report the final cost found by the minimizer, or 0.0 if it has not run.
double cost = (this.getAlgorithm().getResult() == null)
? 0.0 : this.getAlgorithm().getResult().getOutput();
return new DefaultNamedValue<Double>( "Cost", cost );
}
/**
* Maps the parameters of a Distribution and a CostFunction into a
* Vector/Double Evaluator.
*/
protected class DistributionWrapper
extends AbstractCloneableSerializable
implements Evaluator<Vector,Double>,
DifferentiableEvaluator<Vector,Double,Vector>
{
/**
* Distribution to estimate the parameters of
*/
protected DistributionType distribution;
/**
* Cost function to use in the minimization procedure
*/
protected CostFunction<? super DistributionType, ? super Collection<? extends DataType>> costFunction;
/**
* Creates a new instance of DistributionWrapper
* @param distribution
* Distribution to estimate the parameters of
* @param costFunction
* Cost function to use in the minimization procedure
*/
public DistributionWrapper(
DistributionType distribution,
CostFunction<? super DistributionType, ? super Collection<? extends DataType>> costFunction)
{
this.distribution = distribution;
this.costFunction = costFunction;
}
@Override
public DistributionWrapper clone()
{
@SuppressWarnings("unchecked")
DistributionWrapper clone = (DistributionWrapper) super.clone();
clone.distribution = ObjectUtil.cloneSafe( this.distribution );
clone.costFunction = ObjectUtil.cloneSafe( this.costFunction );
return clone;
}
public Double evaluate(
Vector input)
{
try
{
distribution.convertFromVector(input);
return this.costFunction.evaluate(this.distribution);
}
catch (Exception e)
{
// The parameter vector is invalid for this distribution, so leave the
// distribution unchanged and return an infinite cost to steer the
// minimizer away from this region of parameter space.
return Double.POSITIVE_INFINITY;
}
}
public Vector differentiate(
Vector input)
{
// Approximate the gradient numerically, since the cost function need not
// be differentiable in closed form.
return NumericalDifferentiator.VectorJacobian.differentiate(input, this);
}
}
}
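
A minimal usage sketch (not part of the source above): it fits a UnivariateGaussian to sample data by minimizing negative log-likelihood. It assumes NegativeLogLikelihood from gov.sandia.cognition.learning.function.cost is available in this Foundry version and accepts a Collection of samples as its cost parameters; the data values and the example class name are purely illustrative.

import gov.sandia.cognition.learning.function.cost.NegativeLogLikelihood;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.statistics.method.DistributionParameterEstimator;
import java.util.Arrays;
import java.util.Collection;

public class DistributionParameterEstimatorExample
{
    public static void main(String[] args)
    {
        // Samples whose generating Gaussian we want to recover.
        Collection<Double> data =
            Arrays.asList(1.1, 0.9, 1.3, 0.7, 1.0, 1.2);

        // Initial guess; its parameter vector seeds the minimizer.
        UnivariateGaussian initial = new UnivariateGaussian(0.0, 1.0);

        // Estimate the mean and variance by minimizing the negative
        // log-likelihood with the default FunctionMinimizerBFGS.
        DistributionParameterEstimator<Double, UnivariateGaussian> estimator =
            new DistributionParameterEstimator<Double, UnivariateGaussian>(
                initial, new NegativeLogLikelihood<Double>());

        UnivariateGaussian fit = estimator.learn(data);
        System.out.println("Estimated mean = " + fit.getMean()
            + ", variance = " + fit.getVariance());
    }
}

Any cost function and minimizer pair can be substituted via the three-argument constructor, for example FunctionMinimizerDirectionSetPowell when gradients are unreliable.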