weka.estimators.NormalEstimator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the developer version: the "bleeding edge"
of development, to which new functionality is added.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* NormalEstimator.java
* Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.estimators;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.RevisionUtils;
import weka.core.Statistics;
import weka.core.Utils;
/**
 * Simple probability estimator that places a single normal distribution over
 * the observed values.
 * <p>
 * Every value is first rounded to a multiple of the estimator's precision. The
 * estimator keeps weighted running sums of the rounded values and of their
 * squares, and derives the mean and standard deviation from those sums.
 * Because only sums are stored, two estimators can be merged with
 * {@link #aggregate(NormalEstimator)}.
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 9785 $
 */
public class NormalEstimator extends Estimator implements IncrementalEstimator,
  Aggregateable<NormalEstimator> {

  /** for serialization */
  private static final long serialVersionUID = 93584379632315841L;

  /** The sum of the weights */
  private double m_SumOfWeights;

  /** The sum of the (rounded) values seen, each multiplied by its weight */
  private double m_SumOfValues;

  /** The sum of the squared (rounded) values, each multiplied by its weight */
  private double m_SumOfValuesSq;

  /** The current mean */
  private double m_Mean;

  /** The current standard deviation */
  private double m_StandardDev;

  /** The precision of numeric values ( = minimum std dev permitted) */
  private double m_Precision;

  /**
   * Returns the smallest standard deviation this estimator will use, chosen
   * so that at most three standard deviations fit within one precision
   * interval. It is also the default used before any spread has been observed.
   *
   * @return m_Precision / 6
   */
  private double minStdDev() {
    return m_Precision / (2 * 3);
  }

  /**
   * Round a data value to the nearest multiple of this estimator's precision.
   * Exact midpoints are resolved by {@link Math#rint(double)} (ties to even).
   *
   * @param data the value to round
   * @return the rounded data value
   */
  private double round(double data) {
    return Math.rint(data / m_Precision) * m_Precision;
  }

  // ===============
  // Public methods.
  // ===============

  /**
   * Constructor that takes a precision argument.
   *
   * @param precision the precision to which numeric values are given. For
   *          example, if the precision is 0.1, values within half a step of
   *          0.3 (such as 0.27 or 0.33) are all treated as 0.3 (up to
   *          floating-point error).
   * @throws IllegalArgumentException if precision is not a positive, finite
   *           number (such a value would make every rounded value NaN)
   */
  public NormalEstimator(double precision) {
    // !(precision > 0) also rejects NaN.
    if (!(precision > 0) || Double.isInfinite(precision)) {
      throw new IllegalArgumentException(
        "Precision must be a positive, finite number, got " + precision);
    }
    m_Precision = precision;
    // Allow at most 3 sd's within one interval
    m_StandardDev = minStdDev();
  }

  /**
   * Add a new data value to the current estimator. The value is rounded to
   * the precision before being accumulated. A zero weight is ignored.
   *
   * @param data the new data value
   * @param weight the weight assigned to the data value
   */
  @Override
  public void addValue(double data, double weight) {
    if (weight == 0) {
      return;
    }
    double rounded = round(data);
    m_SumOfWeights += weight;
    m_SumOfValues += rounded * weight;
    m_SumOfValuesSq += rounded * rounded * weight;
    computeParameters();
  }

  /**
   * Compute the mean and standard deviation from the accumulated sums. Does
   * nothing while the sum of weights is not positive.
   */
  protected void computeParameters() {
    if (m_SumOfWeights > 0) {
      m_Mean = m_SumOfValues / m_SumOfWeights;
      // Weighted (population) variance = (sum w*x^2 - mean * sum w*x) / sum w.
      // Math.abs keeps sqrt's argument non-negative when floating-point
      // cancellation drives the difference slightly below zero.
      double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq - m_Mean
        * m_SumOfValues)
        / m_SumOfWeights);
      // If the stdDev ~= 0, we really have no idea of scale yet, so keep the
      // current standard deviation (initially the precision-based default).
      if (stdDev > 1e-10) {
        // allow at most 3sd's within one interval
        m_StandardDev = Math.max(minStdDev(), stdDev);
      }
    }
  }

  /**
   * Get a probability estimate for a value.
   * <p>
   * The value is rounded to the precision, and the returned estimate is the
   * mass of the fitted normal over the precision-wide interval centred on the
   * rounded value. That mass is computed as the difference of
   * {@code Statistics.normalProbability} (the standard normal cumulative
   * distribution) at the standardized interval bounds.
   *
   * @param data the value to estimate the probability of
   * @return the estimated probability of the supplied value
   */
  @Override
  public double getProbability(double data) {
    double rounded = round(data);
    double halfStep = m_Precision / 2;
    double zLower = (rounded - m_Mean - halfStep) / m_StandardDev;
    double zUpper = (rounded - m_Mean + halfStep) / m_StandardDev;
    return Statistics.normalProbability(zUpper)
      - Statistics.normalProbability(zLower);
  }

  /**
   * Display a representation of this estimator
   *
   * @return a one-line description of mean, standard deviation, weight sum
   *         and precision, terminated by a newline
   */
  @Override
  public String toString() {
    return "Normal Distribution. Mean = " + Utils.doubleToString(m_Mean, 4)
      + " StandardDev = " + Utils.doubleToString(m_StandardDev, 4)
      + " WeightSum = " + Utils.doubleToString(m_SumOfWeights, 4)
      + " Precision = " + m_Precision + "\n";
  }

  /**
   * Returns default capabilities of the classifier.
   * <p>
   * The estimator handles numeric attributes. Depending on {@code m_noClass}
   * (declared in a superclass), it also handles either a nominal class with
   * missing class values, or no class at all.
   *
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // class
    if (!m_noClass) {
      result.enable(Capability.NOMINAL_CLASS);
      result.enable(Capability.MISSING_CLASS_VALUES);
    } else {
      result.enable(Capability.NO_CLASS);
    }

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    return result;
  }

  /**
   * Return the value of the mean of this normal estimator.
   *
   * @return the mean
   */
  public double getMean() {
    return m_Mean;
  }

  /**
   * Return the value of the standard deviation of this normal estimator.
   *
   * @return the standard deviation
   */
  public double getStdDev() {
    return m_StandardDev;
  }

  /**
   * Return the value of the precision of this normal estimator.
   *
   * @return the precision
   */
  public double getPrecision() {
    return m_Precision;
  }

  /**
   * Return the sum of the weights for this normal estimator.
   *
   * @return the sum of the weights
   */
  public double getSumOfWeights() {
    return m_SumOfWeights;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 9785 $");
  }

  /**
   * Merges another normal estimator into this one by adding its weighted
   * sums. The finer (smaller) of the two precisions is kept. Values already
   * accumulated at a coarser precision are not re-rounded.
   *
   * @param toAggregate the estimator to merge into this one (not modified)
   * @return this estimator, after the merge
   * @throws Exception declared by the Aggregateable interface; not thrown here
   */
  @Override
  public NormalEstimator aggregate(NormalEstimator toAggregate)
    throws Exception {
    m_SumOfWeights += toAggregate.m_SumOfWeights;
    m_SumOfValues += toAggregate.m_SumOfValues;
    m_SumOfValuesSq += toAggregate.m_SumOfValuesSq;

    if (toAggregate.m_Precision < m_Precision) {
      m_Precision = toAggregate.m_Precision;
      // The no-spread default standard deviation depends on the precision.
      // Reset it, so that if the combined spread is ~0 the estimator does not
      // keep a default derived from the old, coarser precision.
      m_StandardDev = minStdDev();
    }

    computeParameters();
    return this;
  }

  /**
   * Completes an aggregation. Nothing is required here because
   * {@link #aggregate(NormalEstimator)} already recomputes the parameters.
   *
   * @throws Exception declared by the Aggregateable interface; not thrown here
   */
  @Override
  public void finalizeAggregation() throws Exception {
    // nothing to do
  }

  /**
   * Demonstrates aggregation. It fits one estimator to 100 pseudo-random
   * values in [0, 1) (seed 1) and two estimators to the first and last 50 of
   * those values, then prints all three and the result of merging the halves.
   */
  public static void testAggregation() {
    NormalEstimator ne = new NormalEstimator(0.01);
    NormalEstimator one = new NormalEstimator(0.01);
    NormalEstimator two = new NormalEstimator(0.01);

    java.util.Random r = new java.util.Random(1);
    for (int i = 0; i < 100; i++) {
      double z = r.nextDouble();
      ne.addValue(z, 1);
      if (i < 50) {
        one.addValue(z, 1);
      } else {
        two.addValue(z, 1);
      }
    }

    try {
      System.out.println("\n\nFull\n");
      System.out.println(ne.toString());
      System.out.println("Prob (0): " + ne.getProbability(0));

      System.out.println("\nOne\n" + one.toString());
      System.out.println("Prob (0): " + one.getProbability(0));

      System.out.println("\nTwo\n" + two.toString());
      System.out.println("Prob (0): " + two.getProbability(0));

      one = one.aggregate(two);

      System.out.println("\nAggregated\n");
      System.out.println(one.toString());
      System.out.println("Prob (0): " + one.getProbability(0));
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }

  /**
   * Main method for testing this class.
   * <p>
   * For each argument in turn, prints the current estimator and its
   * prediction for the value, then adds the value with weight 1. It finally
   * runs {@link #testAggregation()}.
   *
   * @param argv should contain a sequence of numeric values
   */
  public static void main(String[] argv) {
    try {
      if (argv.length == 0) {
        System.out.println("Please specify a set of instances.");
        return;
      }
      NormalEstimator newEst = new NormalEstimator(0.01);
      for (String arg : argv) {
        double current = Double.parseDouble(arg);
        System.out.println(newEst);
        System.out.println("Prediction for " + current + " = "
          + newEst.getProbability(current));
        newEst.addValue(current, 1);
      }
      NormalEstimator.testAggregation();
    } catch (Exception e) {
      // Errors (e.g. a non-numeric argument) are reported on stderr so they
      // are not mixed into the estimator output.
      System.err.println(e.getMessage());
    }
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy