All downloads are free. Search and download functionality uses the official Maven repository.

org.numenta.nupic.algorithms.MovingAverage Maven / Gradle / Ivy

/* ---------------------------------------------------------------------
 * Numenta Platform for Intelligent Computing (NuPIC)
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
 * with Numenta, Inc., for a separate license for this software code, the
 * following terms and conditions apply:
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU Affero Public License for more details.
 *
 * You should have received a copy of the GNU Affero Public License
 * along with this program.  If not, see http://www.gnu.org/licenses.
 *
 * http://numenta.org/licenses/
 * ---------------------------------------------------------------------
 */

package org.numenta.nupic.algorithms;

import org.numenta.nupic.model.Persistable;

import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;

/**
 * Helper class for computing moving average and sliding window
 * 
 * @author Numenta
 * @author David Ray
 */
public class MovingAverage implements Persistable {
    private static final long serialVersionUID = 1L;

    /** Holder for the sliding window, its running total and the latest average. */
    private Calculation calc;
    
    /** Maximum number of values kept in the sliding window; validated to be &gt; 0. */
    private int windowSize;
    
    /**
     * Constructs a new {@code MovingAverage} whose running total is derived
     * from the sum of the supplied historical values.
     * 
     * @param historicalValues  list of entry values; may be {@code null} or empty, in
     *                          which case a new empty window is created. A non-empty
     *                          list is used directly (not copied) and is modified by
     *                          subsequent calls to {@link #next(double)}.
     * @param windowSize        length over which to take the average
     * @throws IllegalArgumentException if {@code windowSize} is not greater than zero
     */
    public MovingAverage(TDoubleList historicalValues, int windowSize) {
        // -1 is the sentinel meaning "compute the total from the historical values".
        this(historicalValues, -1, windowSize);
    }
    
    /**
     * Constructs a new {@code MovingAverage}
     * 
     * @param historicalValues  list of entry values; may be {@code null} or empty, in
     *                          which case a new empty window is created. A non-empty
     *                          list is used directly (not copied) and is modified by
     *                          subsequent calls to {@link #next(double)}.
     * @param total             the sum of the values in {@code historicalValues}, or
     *                          {@code -1} to have it computed from the list
     * @param windowSize        length over which to take the average
     * @throws IllegalArgumentException if {@code windowSize} is not greater than zero
     */
    public MovingAverage(TDoubleList historicalValues, double total, int windowSize) {
        if(windowSize <= 0) {
            throw new IllegalArgumentException("Window size must be > 0");
        }
        this.windowSize = windowSize;
        
        calc = new Calculation();
        calc.historicalValues = 
            historicalValues == null || historicalValues.size() < 1 ?
                new TDoubleArrayList(windowSize) : historicalValues;
        calc.total = total != -1 ? total : calc.historicalValues.sum();
    }
    
    /**
     * Routine for computing a moving average
     * 
     * @param slidingWindow     a list of previous values to use in the computation that
     *                          will be modified and returned
     * @param total             the sum of the values in the slidingWindow to be used in the
     *                          calculation of the moving average
     * @param newVal            a new number to compute the new windowed average
     * @param windowSize        how many values to use in the moving window
     * @return  a new {@link Calculation} holding the updated window, average and total
     * @throws IllegalArgumentException if {@code slidingWindow} is {@code null} or
     *                                  {@code windowSize} is not greater than zero
     */
    public static Calculation compute(TDoubleList slidingWindow, double total, double newVal, int windowSize) {
        return compute(null, slidingWindow, total, newVal, windowSize);
    }
    
    /**
     * Internal method which does actual calculation
     * 
     * @param calc              Re-used calculation object, or {@code null} to have a
     *                          new one created
     * @param slidingWindow     a list of previous values to use in the computation that
     *                          will be modified and returned
     * @param total             the sum of the values in the slidingWindow to be used in the
     *                          calculation of the moving average
     * @param newVal            a new number to compute the new windowed average
     * @param windowSize        how many values to use in the moving window
     * @return  {@code calc} updated in place when it is non-null, otherwise a new
     *          {@link Calculation}
     * @throws IllegalArgumentException if {@code slidingWindow} is {@code null} or
     *                                  {@code windowSize} is not greater than zero
     */
    private static Calculation compute(
        Calculation calc, TDoubleList slidingWindow, double total, double newVal, int windowSize) {
        
        if(slidingWindow == null) {
            throw new IllegalArgumentException("slidingWindow cannot be null.");
        }
        // Same precondition (and message) as the constructor; without it a zero
        // window would try to evict from an empty list.
        if(windowSize <= 0) {
            throw new IllegalArgumentException("Window size must be > 0");
        }
        
        // Evict the oldest values until there is room for exactly one more. Using
        // ">=" rather than "==" also trims a window that was handed in longer than
        // windowSize, which would otherwise never be trimmed and grow on every call.
        while(slidingWindow.size() >= windowSize) {
            total -= slidingWindow.removeAt(0);
        }
        slidingWindow.add(newVal);
        total += newVal;
        
        double average = total / (double)slidingWindow.size();
        if(calc == null) {
            return new Calculation(slidingWindow, average, total);
        }
        
        return copyInto(calc, slidingWindow, average, total);
    }
    
    /**
     * Called to compute the next moving average value.
     * 
     * @param newValue  new point data
     * @return  the average over the window after {@code newValue} has been added
     */
    public double next(double newValue) {
        // Updates this instance's Calculation in place.
        compute(calc, calc.historicalValues, calc.total, newValue, windowSize);
        return calc.average;
    }
    
    /**
     * Returns the sliding window buffer used to calculate the moving average.
     * 
     * @return  the internal window list (not a copy)
     */
    public TDoubleList getSlidingWindow() {
        return calc.historicalValues;
    }
    
    /**
     * Returns the current running total
     * 
     * @return  the running total of the values tracked by the window
     */
    public double getTotal() {
        return calc.total;
    }
    
    /**
     * Returns the size of the window over which the 
     * moving average is computed.
     * 
     * @return  the configured window size
     */
    public int getWindowSize() {
        return windowSize;
    }
    
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + ((calc == null) ? 0 : calc.hashCode());
        result = prime * result + windowSize;
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if(this == obj) {
            return true;
        }
        if(obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MovingAverage other = (MovingAverage)obj;
        if(calc == null) {
            if(other.calc != null) {
                return false;
            }
        } else if(!calc.equals(other.calc)) {
            return false;
        }
        return windowSize == other.windowSize;
    }

    /**
     * Internal method to update running totals.
     * 
     * @param c                 the calculation object to overwrite
     * @param slidingWindow     the window list to store
     * @param average           the average to store
     * @param total             the running total to store
     * @return  the same {@code c} instance, after being updated
     */
    private static Calculation copyInto(Calculation c, TDoubleList slidingWindow, double average, double total) {
        c.historicalValues = slidingWindow;
        c.average = average;
        c.total = total;
        return c;
    }
    
    /**
     * Container for calculated data
     */
    public static class Calculation implements Persistable {
        private static final long serialVersionUID = 1L;
        
        /** Most recently computed average. */
        private double average;
        /** Sliding window of values the average was computed from. */
        private TDoubleList historicalValues;
        /** Running total associated with the window. */
        private double total;
        
        /**
         * Creates an empty calculation; fields keep their default values
         * until populated.
         */
        public Calculation() {
            
        }
        
        /**
         * Creates a calculation holding the given state.
         * 
         * @param historicalValues  the sliding window of values (stored, not copied)
         * @param currentValue      the current average
         * @param total             the running total
         */
        public Calculation(TDoubleList historicalValues, double currentValue, double total) {
            this.average = currentValue;
            this.historicalValues = historicalValues;
            this.total = total;
        }
        
        /**
         * Returns the current value at this point in the calculation.
         * 
         * @return  the current average
         */
        public double getAverage() {
            return average;
        }
        
        /**
         * Returns the list of values held by this calculation (the sliding
         * window), in the order in which they were added.
         * 
         * @return  the stored list (not a copy)
         */
        public TDoubleList getHistoricalValues() {
            return historicalValues;
        }
        
        /**
         * Returns the total
         * 
         * @return  the running total
         */
        public double getTotal() {
            return total;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            long temp;
            temp = Double.doubleToLongBits(average);
            result = prime * result + (int)(temp ^ (temp >>> 32));
            result = prime * result + ((historicalValues == null) ? 0 : historicalValues.hashCode());
            temp = Double.doubleToLongBits(total);
            result = prime * result + (int)(temp ^ (temp >>> 32));
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if(this == obj) {
                return true;
            }
            if(obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Calculation other = (Calculation)obj;
            // Bitwise comparison keeps equals consistent with hashCode.
            if(Double.doubleToLongBits(average) != Double.doubleToLongBits(other.average)) {
                return false;
            }
            if(historicalValues == null) {
                if(other.historicalValues != null) {
                    return false;
                }
            } else if(!historicalValues.equals(other.historicalValues)) {
                return false;
            }
            return Double.doubleToLongBits(total) == Double.doubleToLongBits(other.total);
        }
        
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy