
org.numenta.nupic.algorithms.MovingAverage Maven / Gradle / Ivy
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
/**
* Helper class for computing moving average and sliding window
*
* @author Numenta
* @author David Ray
*/
public class MovingAverage {
    /** Window contents, running total and last average; updated in place by {@link #next(double)}. */
    private final Calculation calc;
    /** Maximum number of values kept in the sliding window; always &gt; 0. */
    private final int windowSize;

    /**
     * Constructs a new {@code MovingAverage} whose running total is computed
     * from {@code historicalValues}.
     *
     * @param historicalValues values used to seed the window. If {@code null} or empty, a
     *                         new empty list is used. A non-empty list is used directly
     *                         (not copied), so it is modified by {@link #next(double)}.
     * @param windowSize       length over which to take the average; must be &gt; 0
     * @throws IllegalArgumentException if {@code windowSize <= 0}
     */
    public MovingAverage(TDoubleList historicalValues, int windowSize) {
        this(historicalValues, -1, windowSize);
    }

    /**
     * Constructs a new {@code MovingAverage}.
     *
     * @param historicalValues values used to seed the window. If {@code null} or empty, a
     *                         new empty list is used. A non-empty list is used directly
     *                         (not copied), so it is modified by {@link #next(double)}.
     *                         If it holds more than {@code windowSize} values, the oldest
     *                         ones are evicted on the next call to {@link #next(double)}.
     * @param total            the sum of {@code historicalValues}, or {@code -1} to have the
     *                         sum computed from the list. Because {@code -1} is the sentinel,
     *                         an explicit total of exactly {@code -1} is also recomputed.
     * @param windowSize       length over which to take the average; must be &gt; 0
     * @throws IllegalArgumentException if {@code windowSize <= 0}
     */
    public MovingAverage(TDoubleList historicalValues, double total, int windowSize) {
        if(windowSize <= 0) {
            throw new IllegalArgumentException("Window size must be > 0");
        }
        this.windowSize = windowSize;
        calc = new Calculation();
        calc.historicalValues =
            historicalValues == null || historicalValues.size() < 1 ?
                new TDoubleArrayList(windowSize) : historicalValues;
        // -1 is a sentinel meaning "derive the total from the window contents".
        calc.total = total != -1 ? total : calc.historicalValues.sum();
    }

    /**
     * Routine for computing a moving average.
     *
     * @param slidingWindow a list of previous values to use in the computation; it is
     *                      modified in place (oldest values evicted, {@code newVal} appended)
     *                      and returned inside the result
     * @param total         the sum of the values in {@code slidingWindow}
     * @param newVal        a new number to include in the windowed average
     * @param windowSize    how many values to keep in the moving window; must be &gt; 0
     * @return a new {@link Calculation} holding the updated window, average and total
     * @throws IllegalArgumentException if {@code slidingWindow} is {@code null} or
     *                                  {@code windowSize <= 0}
     */
    public static Calculation compute(TDoubleList slidingWindow, double total, double newVal, int windowSize) {
        return compute(null, slidingWindow, total, newVal, windowSize);
    }

    /**
     * Internal method which does the actual calculation.
     *
     * @param calc          calculation object to update in place, or {@code null} to
     *                      allocate a new one
     * @param slidingWindow a list of previous values; modified in place
     * @param total         the sum of the values in {@code slidingWindow}
     * @param newVal        a new number to include in the windowed average
     * @param windowSize    how many values to keep in the moving window; must be &gt; 0
     * @return {@code calc} (updated) if non-null, otherwise a new {@link Calculation}
     * @throws IllegalArgumentException if {@code slidingWindow} is {@code null} or
     *                                  {@code windowSize <= 0}
     */
    private static Calculation compute(
        Calculation calc, TDoubleList slidingWindow, double total, double newVal, int windowSize) {
        if(slidingWindow == null) {
            throw new IllegalArgumentException("slidingWindow cannot be null.");
        }
        if(windowSize <= 0) {
            throw new IllegalArgumentException("Window size must be > 0");
        }
        // Evict the oldest values until there is room for newVal. This is a loop rather
        // than a single eviction so that a window seeded with more than windowSize values
        // shrinks back to windowSize instead of growing without bound.
        while(slidingWindow.size() >= windowSize) {
            total -= slidingWindow.removeAt(0);
        }
        slidingWindow.add(newVal);
        total += newVal;
        double average = total / slidingWindow.size();
        if(calc == null) {
            return new Calculation(slidingWindow, average, total);
        }
        return copyInto(calc, slidingWindow, average, total);
    }

    /**
     * Adds {@code newValue} to the window and computes the next moving average value.
     *
     * @param newValue new point data
     * @return the average of the values currently in the window
     */
    public double next(double newValue) {
        compute(calc, calc.historicalValues, calc.total, newValue, windowSize);
        return calc.average;
    }

    /**
     * Returns the sliding window buffer used to calculate the moving average.
     * This is the live internal list, not a copy.
     *
     * @return the current window contents, oldest first
     */
    public TDoubleList getSlidingWindow() {
        return calc.historicalValues;
    }

    /**
     * Returns the current running total.
     *
     * @return the running sum of the values in the window
     */
    public double getTotal() {
        return calc.total;
    }

    /**
     * Returns the size of the window over which the moving average is computed.
     *
     * @return the window size (always &gt; 0)
     */
    public int getWindowSize() {
        return windowSize;
    }

    /**
     * Hash code derived from the window contents, totals and window size. The value
     * changes as {@link #next(double)} mutates the state.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        // calc is final and assigned in every constructor, so it is never null.
        result = prime * result + calc.hashCode();
        result = prime * result + windowSize;
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if(this == obj) {
            return true;
        }
        if(obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MovingAverage other = (MovingAverage)obj;
        return windowSize == other.windowSize && calc.equals(other.calc);
    }

    /**
     * Internal method to store the results of a computation step into {@code c}.
     *
     * @param c             the calculation object to update
     * @param slidingWindow the updated window
     * @param average       the newly computed average
     * @param total         the updated running total
     * @return {@code c}, for chaining
     */
    private static Calculation copyInto(Calculation c, TDoubleList slidingWindow, double average, double total) {
        c.historicalValues = slidingWindow;
        c.average = average;
        c.total = total;
        return c;
    }

    /**
     * Container for calculated data.
     */
    public static class Calculation {
        private double average;
        private TDoubleList historicalValues;
        private double total;

        public Calculation() {
        }

        /**
         * @param historicalValues the window contents, oldest first
         * @param average          the moving average for this step
         * @param total            the running sum of {@code historicalValues}
         */
        public Calculation(TDoubleList historicalValues, double average, double total) {
            this.average = average;
            this.historicalValues = historicalValues;
            this.total = total;
        }

        /**
         * Returns the moving average computed at this step.
         *
         * @return the average of the window contents
         */
        public double getAverage() {
            return average;
        }

        /**
         * Returns the sliding window of input values in the order they were added
         * (oldest first).
         *
         * @return the window contents (live list, not a copy)
         */
        public TDoubleList getHistoricalValues() {
            return historicalValues;
        }

        /**
         * Returns the running total.
         *
         * @return the sum of the window contents
         */
        public double getTotal() {
            return total;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            long temp;
            temp = Double.doubleToLongBits(average);
            result = prime * result + (int)(temp ^ (temp >>> 32));
            result = prime * result + ((historicalValues == null) ? 0 : historicalValues.hashCode());
            temp = Double.doubleToLongBits(total);
            result = prime * result + (int)(temp ^ (temp >>> 32));
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if(this == obj) {
                return true;
            }
            if(obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Calculation other = (Calculation)obj;
            if(Double.doubleToLongBits(average) != Double.doubleToLongBits(other.average)) {
                return false;
            }
            if(historicalValues == null) {
                if(other.historicalValues != null) {
                    return false;
                }
            } else if(!historicalValues.equals(other.historicalValues)) {
                return false;
            }
            return Double.doubleToLongBits(total) == Double.doubleToLongBits(other.total);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy