All downloads are FREE. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.statistics.method.DistributionParameterEstimator Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                DistributionParameterEstimator.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Jul 8, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.statistics.method;

import gov.sandia.cognition.algorithm.AnytimeAlgorithmWrapper;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizer;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerBFGS;
import gov.sandia.cognition.learning.function.cost.CostFunction;
import gov.sandia.cognition.math.DifferentiableEvaluator;
import gov.sandia.cognition.math.matrix.NumericalDifferentiator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ClosedFormDistribution;
import gov.sandia.cognition.statistics.method.DistributionParameterEstimator.DistributionWrapper;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * A method of estimating the parameters of a distribution using an arbitrary
 * CostFunction and FunctionMinimizer algorithm.
 * @param 
 * Type of data generated by the distribution
 * @param 
 * Type of distribution to estimate the parameters of.
 * @author Kevin R. Dixon
 * @since 3.1
 */
public class DistributionParameterEstimator>
    extends AnytimeAlgorithmWrapper.DistributionWrapper>>
    implements BatchLearner,DistributionType>,
    MeasurablePerformanceAlgorithm
{

    /**
     * Function that maps a Distribution onto a Vector/Scalar function.
     */
    private DistributionWrapper distributionWrapper;

    /**
     * Distribution that minimizes the cost function.
     */
    private DistributionType result;

    /**
     * Creates a new instance of DistributionParameterEstimator
     * @param distribution
     * Distribution to estimate the parameters of
     * @param costFunction
     * Cost function to use in the minimization procedure
     */
    public DistributionParameterEstimator(
        DistributionType distribution,
        CostFunction> costFunction )
    {
        this( distribution, costFunction, new FunctionMinimizerBFGS() );
    }
    
    /** 
     * Creates a new instance of DistributionParameterEstimator 
     * @param distribution
     * Distribution to estimate the parameters of
     * @param costFunction
     * Cost function to use in the minimization procedure
     * @param algorithm
     * Minimization algorithm to use, such as FunctionMinimizerBFGS,
     * FunctionMinimizerDirectionSetPowell, etc.
     */
    public DistributionParameterEstimator(
        DistributionType distribution,
        CostFunction> costFunction,
        FunctionMinimizer.DistributionWrapper> algorithm )
    {
        super( algorithm );
        this.distributionWrapper =
            new DistributionWrapper( distribution, costFunction );
    }

    @Override
    public DistributionParameterEstimator clone()
    {
        @SuppressWarnings("unchecked")
        DistributionParameterEstimator clone =
            (DistributionParameterEstimator) super.clone();
        clone.distributionWrapper = ObjectUtil.cloneSafe( this.distributionWrapper );
        clone.result = ObjectUtil.cloneSafe( this.getResult() );
        return clone;
    }

    public DistributionType learn(
        Collection minimizationParameters)
    {
        DistributionWrapper wrapperClone = this.distributionWrapper.clone();
        wrapperClone.costFunction.setCostParameters(minimizationParameters);
        this.getAlgorithm().setInitialGuess( wrapperClone.distribution.convertToVector() );
        this.getAlgorithm().learn( wrapperClone );
        this.result = wrapperClone.distribution;
        return this.getResult();
    }

    public DistributionType getResult()
    {
        return this.result;
    }

    public NamedValue getPerformance()
    {
        double cost = (this.getAlgorithm().getResult() == null) ? 0.0 : this.getAlgorithm().getResult().getOutput();
        return new DefaultNamedValue( "Cost", cost );
    }


    /**
     * Maps the parameters of a Distribution and a CostFunction into a
     * Vector/Double Evaluator.
     */
    protected class DistributionWrapper
        extends AbstractCloneableSerializable
        implements Evaluator,
        DifferentiableEvaluator
    {


        /**
         * Distribution to estimate the parameters of
         */
        protected DistributionType distribution;

        /**
         * Cost function to use in the minimization procedure
         */
        protected CostFunction> costFunction;

        /**
         * Creates a new instance of DistributionWrapper
         * @param distribution
         * Distribution to estimate the parameters of
         * @param costFunction
         * Cost function to use in the minimization procedure
         */
        public DistributionWrapper(
            DistributionType distribution,
            CostFunction> costFunction)
        {
            this.distribution = distribution;
            this.costFunction = costFunction;
        }

        @Override
        public DistributionWrapper clone()
        {
            @SuppressWarnings("unchecked")
            DistributionWrapper clone = (DistributionWrapper) super.clone();
            clone.distribution = ObjectUtil.cloneSafe( this.distribution );
            clone.costFunction = ObjectUtil.cloneSafe( this.costFunction );
            return clone;
        }

        public Double evaluate(
            Vector input)
        {

            try
            {
                distribution.convertFromVector(input);
                return this.costFunction.evaluate(this.distribution);
            }
            catch (Exception e)
            {
                // Leave the distribution unchanged...
//                return this.costFunction.evaluate(this.distribution);
                return Double.POSITIVE_INFINITY;
//                return Double.MAX_VALUE;
            }
        }

        public Vector differentiate(
            Vector input)
        {
            return NumericalDifferentiator.VectorJacobian.differentiate(input,this);
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy