All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.statistics.distribution.DefaultDataDistribution Maven / Gradle / Ivy

/*
 * File:                DefaultDataDistribution.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Incremental Learning Core
 * 
 * Copyright June 15, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government.
 *
 */

package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.AbstractBatchAndIncrementalLearner;
import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.statistics.AbstractDataDistribution;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.DistributionWeightedEstimator;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.WeightedValue;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * A default implementation of {@code ScalarDataDistribution} that uses a
 * backing map.
 * 
 * @param 
 * Type of Key in the distribution
 * @author  Justin Basilico
 * @since   3.1.2
 */
public class DefaultDataDistribution
    extends AbstractDataDistribution
{

    /**
     * Default initial capacity, {@value}.
     */
    public static final int DEFAULT_INITIAL_CAPACITY = 10;

    /**
     * Total of the counts in the distribution
     */
    protected double total;

    /**
     * Default constructor
     */
    public DefaultDataDistribution()
    {
        this( DEFAULT_INITIAL_CAPACITY );
    }

    /**
     * Creates a new instance of DefaultDataDistribution
     * @param initialCapacity
     * Initial capacity of the Map
     */
    public DefaultDataDistribution(
        int initialCapacity)
    {
        this( new LinkedHashMap( initialCapacity), 0.0 );
    }

    /**
     * Creates a new instance of DefaultDataDistribution
     * @param other
     * DataDistribution to copy
     */
    public DefaultDataDistribution(
        final DataDistribution other)
    {
        this(new LinkedHashMap(other.size()), 0.0);
        this.incrementAll(other);
    }

    /**
     * Creates a new instance of ScalarDataDistribution
     * @param data
     * Data to create the distribution
     */
    public DefaultDataDistribution(
        final Iterable data )
    {
        this();
        this.incrementAll(data);
    }

    /**
     * Creates a new instance of
     * @param map
     * Backing Map that stores the data
     * @param total
     * Sum of all values in the Map
     */
    protected DefaultDataDistribution(
        final Map map,
        final double total)
    {
        super( map );
        this.total = total;
    }

    @Override
    public DefaultDataDistribution clone()
    {
        DefaultDataDistribution clone =
            (DefaultDataDistribution) super.clone();
        
        // We have to manually reset "total" because super.super.clone
        // calls "incrementAll", which will, in turn, increment the total
        // So we'd end up with twice the total.
        clone.total = this.total;
        return clone;
    }

    @Override
    public double increment(
        KeyType key,
        final double value)
    {
        final MutableDouble entry = this.map.get(key);
        double newValue = 0.0;
        double delta;
        if( entry == null )
        {
            if( value > 0.0 )
            {
                // It's best to avoid this.set() here as it could mess up
                // our total tracker in some subclasses...
                // Also it's more efficient this way (avoid another get)
                this.map.put( key, new MutableDouble(value) );
                newValue = value;
                delta = value;
            }
            else
            {
                delta = 0.0;
            }
        }
        else
        {
            if( entry.value+value >= 0.0 )
            {
                delta = value;
                entry.value += value;
                newValue = entry.value;
            }
            else
            {
                delta = -entry.value;
                entry.value = 0.0;
            }
        }

        this.total += delta;
        return newValue;
    }

    @Override
    public void set(
        final KeyType key,
        final double value)
    {

        // I decided not to call super.set because it would result in me
        // having to perform an extra call to this.map.get
        final MutableDouble entry = this.map.get(key);
        if( entry == null )
        {
            // Only need to allocate if it's not null
            if( value > 0.0 )
            {
                this.map.put( key, new MutableDouble( value ) );
                this.total += value;
            }
        }
        else if( value > 0.0 )
        {
            this.total += value - entry.value;
            entry.value = value;
        }
        else
        {
            this.total -= entry.value;
            entry.value = 0.0;
        }
    }

    @Override
    public double getTotal()
    {
        return this.total;
    }
    
    @Override
    public void clear()
    {
        super.clear();
        this.total = 0.0;
    }

    @Override
    public DistributionEstimator> getEstimator()
    {
        return new DefaultDataDistribution.Estimator();
    }

    @Override
    public DataDistribution.PMF getProbabilityFunction()
    {
        return new DefaultDataDistribution.PMF(this);
    }

    /**
     * Gets the average value of all keys in the distribution, that is, the
     * total value divided by the number of keys (even zero-value keys)
     * @return
     * Average value of all keys in the distribution
     */
    public double getMeanValue()
    {
        final int ds = this.getDomainSize();
        if( ds > 0 )
        {
            return this.getTotal() / ds;
        }
        else
        {
            return 0.0;
        }
    }

    /**
     * PMF of the DefaultDataDistribution
     * @param 
     * Type of Key in the distribution
     */
    public static class PMF
        extends DefaultDataDistribution
        implements DataDistribution.PMF
    {

        /**
         * Default constructor
         */
        public PMF()
        {
            super();
        }

        /**
         * Copy constructor
         * @param other
         * ScalarDataDistribution to copy
         */
        public PMF(
            final DataDistribution other)
        {
            super(other);
        }

        /**
         * Creates a new instance of DefaultDataDistribution
         * @param initialCapacity
         * Initial capacity of the Map
         */
        public PMF(
            int initialCapacity)
        {
            super( initialCapacity );
        }

        /**
         * Creates a new instance of ScalarDataDistribution
         * @param data
         * Data to create the distribution
         */
        public PMF(
            final Iterable data )
        {
            super( data );
        }

        @Override
        public double logEvaluate(
            KeyType input)
        {
            return this.getLogFraction(input);
        }

        @Override
        public Double evaluate(
            KeyType input)
        {
            return this.getFraction(input);
        }

        @Override
        public DefaultDataDistribution.PMF getProbabilityFunction()
        {
            return this;
        }

    }

    /**
     * Estimator for a DefaultDataDistribution
     * @param 
     * Type of Key in the distribution
     */
    public static class Estimator
        extends AbstractBatchAndIncrementalLearner>
        implements DistributionEstimator>
    {

        /**
         * Default constructor
         */
        public Estimator()
        {
            super();
        }

        @Override
        public DefaultDataDistribution.PMF createInitialLearnedObject()
        {
            return new DefaultDataDistribution.PMF();
        }

        @Override
        public void update(
            final DefaultDataDistribution.PMF target,
            final KeyType data)
        {
            target.increment(data, 1.0);
        }

    }

    /**
     * A weighted estimator for a DefaultDataDistribution
     * @param 
     * Type of Key in the distribution
     */
    public static class WeightedEstimator
        extends AbstractBatchAndIncrementalLearner, DefaultDataDistribution.PMF>
        implements DistributionWeightedEstimator>
    {

        /**
         * Default constructor
         */
        public WeightedEstimator()
        {
            super();
        }

        @Override
        public DefaultDataDistribution.PMF createInitialLearnedObject()
        {
            return new DefaultDataDistribution.PMF();
        }

        @Override
        public void update(
            final DefaultDataDistribution.PMF target,
            final WeightedValue data)
        {
            target.increment( data.getValue(), data.getWeight() );
        }
        
    }

    /**
     * A factory for {@code DefaultDataDistribution} objects using some given
     * initial capacity for them.
     *
     * @param   
     *      The type of data for the factory.
     */
    public static class DefaultFactory
        extends AbstractCloneableSerializable
        implements Factory>
    {

        /** The initial domain capacity. */
        protected int initialDomainCapacity;

        /**
         * Creates a new {@code DefaultFactory} with a default
         * initial domain capacity.
         */
        public DefaultFactory()
        {
            this(DEFAULT_INITIAL_CAPACITY);
        }

        /**
         * Creates a new {@code DefaultFactory} with a given
         * initial domain capacity.
         *
         * @param   initialDomainCapacity
         *      The initial capacity for the domain. Must be positive.
         */
        public DefaultFactory(
            final int initialDomainCapacity)
        {
            super();

            this.setInitialDomainCapacity(initialDomainCapacity);
        }

        @Override
        public DefaultDataDistribution create()
        {
            // Create the histogram.
            return new DefaultDataDistribution(
                this.getInitialDomainCapacity());
        }

        /**
         * Gets the initial domain capacity.
         *
         * @return
         *      The initial domain capacity. Must be positive.
         */
        public int getInitialDomainCapacity()
        {
            return this.initialDomainCapacity;
        }

        /**
         * Sets the initial domain capacity.
         *
         * @param   initialDomainCapacity
         *      The initial domain capacity. Must be positive.
         */
        public void setInitialDomainCapacity(
            final int initialDomainCapacity)
        {
            ArgumentChecker.assertIsPositive("initialDomainCapacity",
                initialDomainCapacity);
            this.initialDomainCapacity = initialDomainCapacity;
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy