
org.numenta.nupic.algorithms.MovingAverage Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of htm.java Show documentation
Show all versions of htm.java Show documentation
The Java version of Numenta's HTM technology
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;
import org.numenta.nupic.model.Persistable;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
/**
* Helper class for computing moving average and sliding window
*
* @author Numenta
* @author David Ray
*/
public class MovingAverage implements Persistable {
    private static final long serialVersionUID = 1L;

    /**
     * Sentinel value for the {@code total} constructor argument meaning
     * "compute the running total by summing the initial historical values".
     */
    private static final double COMPUTE_TOTAL = -1;

    /** Holds the sliding window, running total and most recent average. */
    private Calculation calc;
    /** Maximum number of values kept in the sliding window; always {@code > 0}. */
    private int windowSize;

    /**
     * Constructs a new {@code MovingAverage} whose running total is computed
     * by summing {@code historicalValues}.
     *
     * @param historicalValues initial contents of the sliding window; may be
     *                         {@code null} or empty to start with an empty window
     * @param windowSize       number of most recent values the average is taken over
     * @throws IllegalArgumentException if {@code windowSize <= 0}
     */
    public MovingAverage(TDoubleList historicalValues, int windowSize) {
        this(historicalValues, COMPUTE_TOTAL, windowSize);
    }

    /**
     * Constructs a new {@code MovingAverage}.
     * <p>
     * A non-empty {@code historicalValues} list is used directly, not copied.
     * Later calls to {@link #next(double)} therefore modify the caller's list.
     * If the list holds more than {@code windowSize} values, the oldest excess
     * values are evicted on the next call to {@link #next(double)}.
     *
     * @param historicalValues initial contents of the sliding window; may be
     *                         {@code null} or empty to start with an empty window
     * @param total            the sum of {@code historicalValues}, or {@code -1} to
     *                         have it computed from the list
     * @param windowSize       number of most recent values the average is taken over
     * @throws IllegalArgumentException if {@code windowSize <= 0}
     */
    public MovingAverage(TDoubleList historicalValues, double total, int windowSize) {
        if(windowSize <= 0) {
            throw new IllegalArgumentException("Window size must be > 0");
        }
        this.windowSize = windowSize;
        calc = new Calculation();
        calc.historicalValues =
            historicalValues == null || historicalValues.isEmpty() ?
                new TDoubleArrayList(windowSize) : historicalValues;
        calc.total = total != COMPUTE_TOTAL ? total : calc.historicalValues.sum();
    }

    /**
     * Computes one step of a moving average.
     * <p>
     * {@code slidingWindow} is modified in place. Oldest values are evicted
     * until there is room, and then {@code newVal} is appended.
     *
     * @param slidingWindow the previous values in the window; modified and
     *                      returned inside the result
     * @param total         the sum of the values currently in {@code slidingWindow}
     * @param newVal        the new value to add to the window
     * @param windowSize    maximum number of values kept in the window
     * @return a new {@link Calculation} holding the (same, modified) window,
     *         the updated average and the updated total
     * @throws IllegalArgumentException if {@code slidingWindow} is {@code null}
     *                                  or {@code windowSize <= 0}
     */
    public static Calculation compute(TDoubleList slidingWindow, double total, double newVal, int windowSize) {
        return compute(null, slidingWindow, total, newVal, windowSize);
    }

    /**
     * Internal method which performs the actual calculation.
     *
     * @param calc          calculation object to reuse, or {@code null} to allocate
     *                      a new one
     * @param slidingWindow the previous values in the window; modified in place
     * @param total         the sum of the values currently in {@code slidingWindow}
     * @param newVal        the new value to add to the window
     * @param windowSize    maximum number of values kept in the window
     * @return {@code calc} (updated) if non-null, otherwise a new {@link Calculation}
     * @throws IllegalArgumentException if {@code slidingWindow} is {@code null}
     *                                  or {@code windowSize <= 0}
     */
    private static Calculation compute(
        Calculation calc, TDoubleList slidingWindow, double total, double newVal, int windowSize) {
        if(slidingWindow == null) {
            throw new IllegalArgumentException("slidingWindow cannot be null.");
        }
        if(windowSize <= 0) {
            throw new IllegalArgumentException("Window size must be > 0");
        }
        // Evict oldest values until there is room for newVal. A loop is needed
        // (not a single removal) so that a window that started out larger than
        // windowSize is shrunk instead of growing forever.
        while(slidingWindow.size() >= windowSize) {
            total -= slidingWindow.removeAt(0);
        }
        slidingWindow.add(newVal);
        total += newVal;
        double average = total / (double)slidingWindow.size();
        if(calc == null) {
            return new Calculation(slidingWindow, average, total);
        }
        return copyInto(calc, slidingWindow, average, total);
    }

    /**
     * Adds a new value to the sliding window and computes the next moving average.
     *
     * @param newValue the new data point
     * @return the average of the values in the window after adding {@code newValue}
     */
    public double next(double newValue) {
        compute(calc, calc.historicalValues, calc.total, newValue, windowSize);
        return calc.average;
    }

    /**
     * Returns the sliding window buffer used to calculate the moving average.
     * This is the live internal list, not a copy.
     *
     * @return the current sliding window
     */
    public TDoubleList getSlidingWindow() {
        return calc.historicalValues;
    }

    /**
     * Returns the current running total of the values in the sliding window.
     *
     * @return the running total
     */
    public double getTotal() {
        return calc.total;
    }

    /**
     * Returns the size of the window over which the moving average is computed.
     *
     * @return the window size
     */
    public int getWindowSize() {
        return windowSize;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + ((calc == null) ? 0 : calc.hashCode());
        result = prime * result + windowSize;
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if(this == obj) {
            return true;
        }
        if(obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MovingAverage other = (MovingAverage)obj;
        if(calc == null) {
            if(other.calc != null) {
                return false;
            }
        } else if(!calc.equals(other.calc)) {
            return false;
        }
        return windowSize == other.windowSize;
    }

    /**
     * Internal method that stores updated results into an existing
     * {@link Calculation}.
     *
     * @param c             the calculation object to update
     * @param slidingWindow the current sliding window
     * @param average       the current moving average
     * @param total         the current running total
     * @return {@code c}, updated
     */
    private static Calculation copyInto(Calculation c, TDoubleList slidingWindow, double average, double total) {
        c.historicalValues = slidingWindow;
        c.average = average;
        c.total = total;
        return c;
    }

    /**
     * Container for calculated data.
     */
    public static class Calculation implements Persistable {
        private static final long serialVersionUID = 1L;
        private double average;
        private TDoubleList historicalValues;
        private double total;

        public Calculation() {
        }

        /**
         * @param historicalValues the sliding window of input values
         * @param currentValue     the current moving average
         * @param total            the sum of the values in {@code historicalValues}
         */
        public Calculation(TDoubleList historicalValues, double currentValue, double total) {
            this.average = currentValue;
            this.historicalValues = historicalValues;
            this.total = total;
        }

        /**
         * Returns the moving average at this point in the calculation.
         *
         * @return the current average
         */
        public double getAverage() {
            return average;
        }

        /**
         * Returns the sliding window of input values, oldest first.
         *
         * @return the sliding window
         */
        public TDoubleList getHistoricalValues() {
            return historicalValues;
        }

        /**
         * Returns the sum of the values in the sliding window.
         *
         * @return the running total
         */
        public double getTotal() {
            return total;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            long temp;
            temp = Double.doubleToLongBits(average);
            result = prime * result + (int)(temp ^ (temp >>> 32));
            result = prime * result + ((historicalValues == null) ? 0 : historicalValues.hashCode());
            temp = Double.doubleToLongBits(total);
            result = prime * result + (int)(temp ^ (temp >>> 32));
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if(this == obj) {
                return true;
            }
            if(obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Calculation other = (Calculation)obj;
            if(Double.doubleToLongBits(average) != Double.doubleToLongBits(other.average)) {
                return false;
            }
            if(historicalValues == null) {
                if(other.historicalValues != null) {
                    return false;
                }
            } else if(!historicalValues.equals(other.historicalValues)) {
                return false;
            }
            return Double.doubleToLongBits(total) == Double.doubleToLongBits(other.total);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy