// weka.classifiers.trees.ht.GaussianConditionalSufficientStats (Maven / Gradle / Ivy)
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* GaussianConditionalSufficientStats.java
* Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees.ht;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import weka.core.Utils;
import weka.estimators.UnivariateNormalEstimator;
/**
* Maintains sufficient stats for a Gaussian distribution for a numeric
* attribute
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 9705 $
*/
public class GaussianConditionalSufficientStats extends
ConditionalSufficientStats implements Serializable {
/**
* Version identifier used by Java serialization for this class.
*/
private static final long serialVersionUID = -1527915607201784762L;
/**
 * Gaussian estimator kept for a single class value, built on
 * {@link UnivariateNormalEstimator}. It exposes the accumulated weight, a
 * point density, and an estimate of how the accumulated weight divides around
 * a candidate split value.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 */
protected class GaussianEstimator extends UnivariateNormalEstimator implements
  Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = 4756032800685001315L;

  /**
   * Returns the total weight accumulated by this estimator.
   *
   * @return the current value of the inherited {@code m_SumOfWeights} field
   */
  public double getSumOfWeights() {
    return m_SumOfWeights;
  }

  /**
   * Evaluates the normal density described by this estimator at a point.
   *
   * @param value the point at which to evaluate the density
   * @return 0 if no positive weight has been accumulated; otherwise the
   *         Gaussian density at {@code value}, or, when the standard
   *         deviation is not positive, 1 if {@code value} equals the mean and
   *         0 if it does not
   */
  public double probabilityDensity(double value) {
    // Mean and variance are refreshed before they are read.
    updateMeanAndVariance();

    if (m_SumOfWeights > 0) {
      double stdDev = Math.sqrt(m_Variance);
      if (stdDev > 0) {
        double diff = value - m_Mean;
        // The Gaussian normaliser sqrt(2*pi) is written out instead of using
        // the inherited CONST, so the result does not depend on how that
        // constant is defined.
        // NOTE(review): CONST appears to hold log(2*pi) in
        // UnivariateNormalEstimator, which would be the wrong factor here -
        // confirm against the superclass.
        double normaliser = Math.sqrt(2.0 * Math.PI) * stdDev;
        return (1.0 / normaliser)
          * Math.exp(-(diff * diff / (2.0 * m_Variance)));
      }
      // Degenerate spread: all mass sits on the mean.
      return value == m_Mean ? 1.0 : 0.0;
    }

    return 0.0;
  }

  /**
   * Estimates how the accumulated weight is divided relative to a value.
   * The "equal" share is the density at the value multiplied by the total
   * weight; the "less than" share is the cumulative normal probability
   * multiplied by the total weight, minus the "equal" share; the "greater
   * than" share is whatever remains of the total weight.
   *
   * @param value the value to partition the weight around
   * @return a three-element array holding, in order, the weight less than,
   *         equal to, and greater than {@code value}
   */
  public double[] weightLessThanEqualAndGreaterThan(double value) {
    // Refresh before reading m_Variance, so the standard deviation used for
    // the cumulative probability comes from the same up-to-date statistics
    // that probabilityDensity() uses for the "equal" share.
    updateMeanAndVariance();
    double stdDev = Math.sqrt(m_Variance);
    double equalW = probabilityDensity(value) * m_SumOfWeights;

    double lessW;
    if (stdDev > 0) {
      lessW = weka.core.Statistics.normalProbability((value - m_Mean) / stdDev)
        * m_SumOfWeights - equalW;
    } else if (value < m_Mean) {
      lessW = m_SumOfWeights - equalW;
    } else {
      lessW = 0.0;
    }

    double greaterW = m_SumOfWeights - equalW - lessW;
    return new double[] { lessW, equalW, greaterW };
  }
}
/**
 * Smallest attribute value observed so far for each class value, keyed by the
 * class value. Typed so that looked-up values can be compared with primitive
 * doubles.
 */
protected Map<String, Double> m_minValObservedPerClass =
  new HashMap<String, Double>();

/**
 * Largest attribute value observed so far for each class value, keyed by the
 * class value.
 */
protected Map<String, Double> m_maxValObservedPerClass =
  new HashMap<String, Double>();

/**
 * Number of candidate split points to try: the observed value range is cut
 * into {@code m_numBins + 1} equal-width intervals and the interior
 * boundaries become candidates.
 */
protected int m_numBins = 10;
/**
 * Sets the number of bins used when generating candidate split points.
 *
 * @param b the number of bins to use
 */
public void setNumBins(int b) {
  this.m_numBins = b;
}
/**
 * Gets the number of bins used when generating candidate split points.
 *
 * @return the number of bins currently configured
 */
public int getNumBins() {
  return this.m_numBins;
}
/**
 * Folds one observation into the statistics kept for its class value: the
 * class's Gaussian estimator and the smallest/largest attribute value seen
 * for that class. Observations with a missing attribute value are ignored.
 *
 * @param attVal the observed attribute value
 * @param classVal the class value the observation belongs to
 * @param weight the weight of the observation
 */
@Override
public void update(double attVal, String classVal, double weight) {
  if (Utils.isMissingValue(attVal)) {
    return;
  }

  GaussianEstimator norm = (GaussianEstimator) m_classLookup.get(classVal);
  if (norm == null) {
    // First observation for this class: create its estimator and seed both
    // bounds with the observed value.
    norm = new GaussianEstimator();
    m_classLookup.put(classVal, norm);
    m_minValObservedPerClass.put(classVal, Double.valueOf(attVal));
    m_maxValObservedPerClass.put(classVal, Double.valueOf(attVal));
  } else {
    // The stored bounds are read through explicit Double locals so the
    // comparisons are well-typed, and a missing entry is treated as "no bound
    // recorded yet" instead of failing on unboxing.
    Double minSoFar = (Double) m_minValObservedPerClass.get(classVal);
    if (minSoFar == null || attVal < minSoFar.doubleValue()) {
      m_minValObservedPerClass.put(classVal, Double.valueOf(attVal));
    }

    Double maxSoFar = (Double) m_maxValObservedPerClass.get(classVal);
    if (maxSoFar == null || attVal > maxSoFar.doubleValue()) {
      m_maxValObservedPerClass.put(classVal, Double.valueOf(attVal));
    }
  }

  norm.addValue(attVal, weight);
}
/**
 * Returns the density of an attribute value under the Gaussian estimator
 * kept for the given class value.
 *
 * @param attVal the attribute value to evaluate
 * @param classVal the class value to condition on
 * @return the estimator's density at {@code attVal}, or 0 when no estimator
 *         is stored for {@code classVal}
 */
@Override
public double probabilityOfAttValConditionedOnClass(double attVal,
  String classVal) {
  GaussianEstimator estimator =
    (GaussianEstimator) m_classLookup.get(classVal);
  return (estimator != null) ? estimator.probabilityDensity(attVal) : 0;
}
/**
 * Generates candidate split points for the attribute. The overall range of
 * observed values (across all class values) is divided into
 * {@code m_numBins + 1} equal-width intervals, and every interior boundary
 * that lies strictly between the overall minimum and maximum is a candidate.
 *
 * @return the candidate split points in ascending order; empty when no
 *         values have been observed or the observed range has zero width
 */
protected TreeSet<Double> getSplitPointCandidates() {
  TreeSet<Double> splits = new TreeSet<Double>();
  double min = Double.POSITIVE_INFINITY;
  double max = Double.NEGATIVE_INFINITY;

  for (String classVal : m_classLookup.keySet()) {
    // Each bound is looked up once and read through a typed local; classes
    // without a recorded bound are skipped.
    Double classMin = (Double) m_minValObservedPerClass.get(classVal);
    if (classMin != null && classMin.doubleValue() < min) {
      min = classMin.doubleValue();
    }

    Double classMax = (Double) m_maxValObservedPerClass.get(classVal);
    if (classMax != null && classMax.doubleValue() > max) {
      max = classMax.doubleValue();
    }
  }

  // min is still +infinity when nothing has been observed.
  if (min < Double.POSITIVE_INFINITY) {
    double binWidth = (max - min) / (m_numBins + 1);
    for (int i = 0; i < m_numBins; i++) {
      double split = min + (binWidth * (i + 1));
      // Only keep boundaries strictly inside the observed range.
      if (split > min && split < max) {
        splits.add(Double.valueOf(split));
      }
    }
  }

  return splits;
}
protected List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy