weka.estimators.KernelEstimator
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version; apart from bug fixes, it
does not receive any other updates.
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* KernelEstimator.java
* Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
*
*/
package weka.estimators;
import weka.core.Capabilities.Capability;
import weka.core.Capabilities;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.Statistics;
/**
* Simple kernel density estimator. Uses one Gaussian kernel per observed
* data value.
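* <p>
* A minimal usage sketch (the precision and values below are illustrative
* only):
* <pre>{@code
* KernelEstimator est = new KernelEstimator(0.01);
* est.addValue(1.0, 1.0);            // one Gaussian kernel centred at 1.0
* est.addValue(1.5, 1.0);
* est.addValue(2.0, 2.0);            // this kernel carries weight 2
* double p = est.getProbability(1.2);
* }</pre>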
*
* @author Len Trigg ([email protected])
* @version $Revision: 5540 $
*/
public class KernelEstimator extends Estimator implements IncrementalEstimator {
/** for serialization */
private static final long serialVersionUID = 3646923563367683925L;
/** Array containing all of the values seen so far */
private double [] m_Values;
/** Array containing the associated weights */
private double [] m_Weights;
/** Number of values stored in m_Weights and m_Values so far */
private int m_NumValues;
/** The sum of the weights so far */
private double m_SumOfWeights;
/** The standard deviation */
private double m_StandardDev;
/** The precision of data values */
private double m_Precision;
/** Whether we can optimise the kernel summation */
private boolean m_AllWeightsOne;
/** Maximum percentage error permitted in probability calculations */
private static final double MAX_ERROR = 0.01;
/**
* Execute a binary search to locate the nearest data value
*
* @param key the data value to locate
* @return the index of the nearest data value, or the insertion point if the
* value is not present
*/
private int findNearestValue(double key) {
int low = 0;
int high = m_NumValues;
int middle = 0;
while (low < high) {
middle = (low + high) / 2;
double current = m_Values[middle];
if (current == key) {
return middle;
}
if (current > key) {
high = middle;
} else if (current < key) {
low = middle + 1;
}
}
return low;
}
/**
* Round a data value using the defined precision for this estimator
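* For example, with a precision of 0.1 both 0.27 and 0.32 map to the same
* rounded value (approximately 0.3).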
*
* @param data the value to round
* @return the rounded data value
*/
private double round(double data) {
return Math.rint(data / m_Precision) * m_Precision;
}
// ===============
// Public methods.
// ===============
/**
* Constructor that takes a precision argument.
*
* @param precision the precision to which numeric values are given. For
* example, if the precision is stated to be 0.1, the values in the
* interval (0.25,0.35] are all treated as 0.3.
*/
public KernelEstimator(double precision) {
m_Values = new double [50];
m_Weights = new double [50];
m_NumValues = 0;
m_SumOfWeights = 0;
m_AllWeightsOne = true;
m_Precision = precision;
// precision cannot be zero
if (m_Precision < Utils.SMALL) m_Precision = Utils.SMALL;
// m_StandardDev = 1e10 * m_Precision; // Set the standard deviation initially very wide
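// initially allow at most 3 sds within one precision interval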
m_StandardDev = m_Precision / (2 * 3);
}
/**
* Add a new data value to the current estimator.
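* <p>
* The value is first rounded to the estimator's precision; if a kernel with
* that rounded value already exists, the weight is added to it, otherwise a
* new kernel is inserted. A small illustrative sketch (precision and values
* chosen for this example only):
* <pre>{@code
* KernelEstimator est = new KernelEstimator(0.1);
* est.addValue(0.27, 1.0);      // rounds to ~0.3
* est.addValue(0.32, 1.0);      // also rounds to ~0.3, so the weights merge
* int k = est.getNumKernels();  // k == 1
* }</pre>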
*
* @param data the new data value
* @param weight the weight assigned to the data value
*/
public void addValue(double data, double weight) {
if (weight == 0) {
return;
}
data = round(data);
int insertIndex = findNearestValue(data);
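// no kernel for this (rounded) value yet: insert a new one, keeping m_Values sorted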
if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
if (m_NumValues < m_Values.length) {
int left = m_NumValues - insertIndex;
System.arraycopy(m_Values, insertIndex,
m_Values, insertIndex + 1, left);
System.arraycopy(m_Weights, insertIndex,
m_Weights, insertIndex + 1, left);
m_Values[insertIndex] = data;
m_Weights[insertIndex] = weight;
m_NumValues++;
} else {
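// arrays are full: copy everything into new arrays of twice the capacity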
double [] newValues = new double [m_Values.length * 2];
double [] newWeights = new double [m_Values.length * 2];
int left = m_NumValues - insertIndex;
System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
newValues[insertIndex] = data;
newWeights[insertIndex] = weight;
System.arraycopy(m_Values, insertIndex,
newValues, insertIndex + 1, left);
System.arraycopy(m_Weights, insertIndex,
newWeights, insertIndex + 1, left);
m_NumValues++;
m_Values = newValues;
m_Weights = newWeights;
}
if (weight != 1) {
m_AllWeightsOne = false;
}
} else {
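// a kernel with this (rounded) value already exists: just accumulate its weight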
m_Weights[insertIndex] += weight;
m_AllWeightsOne = false;
}
m_SumOfWeights += weight;
double range = m_Values[m_NumValues - 1] - m_Values[0];
if (range > 0) {
m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
// allow at most 3 sds within one interval
m_Precision / (2 * 3));
}
}
/**
* Get a probability estimate for a value.
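* <p>
* In sketch form, the estimate is the weighted sum of per-kernel interval
* probabilities, normalised by the total weight:
* <pre>
* P(x) = (1 / sumOfWeights) * SUM_i w_i * [ Phi((x_i - x + precision/2) / sd)
*                                         - Phi((x_i - x - precision/2) / sd) ]
* </pre>
* where Phi is the standard normal CDF, x_i are the stored kernel means, w_i
* their weights and sd the current standard deviation. Kernels are visited
* nearest first and the sum is cut off once the remaining kernels can
* contribute less than a MAX_ERROR fraction of the sum accumulated so far.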
*
* @param data the value to estimate the probability of
* @return the estimated probability of the supplied value
*/
public double getProbability(double data) {
double delta = 0, sum = 0, currentProb = 0;
double zLower = 0, zUpper = 0;
if (m_NumValues == 0) {
zLower = (data - (m_Precision / 2)) / m_StandardDev;
zUpper = (data + (m_Precision / 2)) / m_StandardDev;
return (Statistics.normalProbability(zUpper)
- Statistics.normalProbability(zLower));
}
double weightSum = 0;
int start = findNearestValue(data);
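// kernels at or above the query value, nearest first; later kernels lie
// farther away, so stop once their maximum possible contribution falls
// below MAX_ERROR of the running sum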
for (int i = start; i < m_NumValues; i++) {
delta = m_Values[i] - data;
zLower = (delta - (m_Precision / 2)) / m_StandardDev;
zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
currentProb = (Statistics.normalProbability(zUpper)
- Statistics.normalProbability(zLower));
sum += currentProb * m_Weights[i];
/*
System.out.print("zL" + (i + 1) + ": " + zLower + " ");
System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
System.out.print("P" + (i + 1) + ": " + currentProb + " ");
System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
*/
weightSum += m_Weights[i];
if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
break;
}
}
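// kernels below the query value, nearest first, with the same early cut-off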
for (int i = start - 1; i >= 0; i--) {
delta = m_Values[i] - data;
zLower = (delta - (m_Precision / 2)) / m_StandardDev;
zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
currentProb = (Statistics.normalProbability(zUpper)
- Statistics.normalProbability(zLower));
sum += currentProb * m_Weights[i];
weightSum += m_Weights[i];
if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
break;
}
}
return sum / m_SumOfWeights;
}
/** Returns a string representation of this estimator */
public String toString() {
String result = m_NumValues + " Normal Kernels. \nStandardDev = "
+ Utils.doubleToString(m_StandardDev,6,4)
+ " Precision = " + m_Precision;
if (m_NumValues == 0) {
result += " \nMean = 0";
} else {
result += " \nMeans =";
for (int i = 0; i < m_NumValues; i++) {
result += " " + m_Values[i];
}
if (!m_AllWeightsOne) {
result += "\nWeights = ";
for (int i = 0; i < m_NumValues; i++) {
result += " " + m_Weights[i];
}
}
}
return result + "\n";
}
/**
* Return the number of kernels in this kernel estimator
*
* @return the number of kernels
*/
public int getNumKernels() {
return m_NumValues;
}
/**
* Return the means of the kernels.
*
* @return the means of the kernels
*/
public double[] getMeans() {
return m_Values;
}
/**
* Return the weights of the kernels.
*
* @return the weights of the kernels
*/
public double[] getWeights() {
return m_Weights;
}
/**
* Return the precision of this kernel estimator.
*
* @return the precision
*/
public double getPrecision() {
return m_Precision;
}
/**
* Return the standard deviation of this kernel estimator.
*
* @return the standard deviation
*/
public double getStdDev() {
return m_StandardDev;
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// class
if (!m_noClass) {
result.enable(Capability.NOMINAL_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
} else {
result.enable(Capability.NO_CLASS);
}
// attributes
result.enable(Capability.NUMERIC_ATTRIBUTES);
return result;
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 5540 $");
}
/**
* Main method for testing this class.
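* Expects (value, weight) pairs followed by a start and an end point for the
* probability printout, e.g. (illustrative values):
* <pre>
* java weka.estimators.KernelEstimator 1 1 2 1 3 2 0 5
* </pre>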
*
* @param argv a sequence of (value, weight) pairs followed by the start and
* end of the range over which probabilities are printed
*/
public static void main(String [] argv) {
try {
if (argv.length < 2) {
System.out.println("Please specify a set of instances.");
return;
}
KernelEstimator newEst = new KernelEstimator(0.01);
for (int i = 0; i < argv.length - 3; i += 2) {
newEst.addValue(Double.parseDouble(argv[i]),
Double.parseDouble(argv[i + 1]));
}
System.out.println(newEst);
double start = Double.parseDouble(argv[argv.length - 2]);
double finish = Double.parseDouble(argv[argv.length - 1]);
for (double current = start; current < finish;
current += (finish - start) / 50) {
System.out.println("Data: " + current + " "
+ newEst.getProbability(current));
}
} catch (Exception e) {
System.out.println(e.getMessage());
}
}
}