All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.estimators.KernelEstimator Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    KernelEstimator.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.estimators;

import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.RevisionUtils;
import weka.core.Statistics;
import weka.core.Utils;

/**
 * Simple kernel density estimator. Uses one Gaussian kernel per distinct
 * observed data value, where values are first rounded to a multiple of the
 * estimator's precision. Kernels are kept in an array sorted by value so that
 * the nearest kernel can be found with a binary search.
 *
 * @author Len Trigg
 * @version $Revision: 15521 $
 */
public class KernelEstimator extends Estimator implements IncrementalEstimator,
    Aggregateable<KernelEstimator> {

  /**
   * for serialization
   */
  private static final long serialVersionUID = 3646923563367683925L;

  /**
   * Relative error bound used to cut the kernel summation short: summation in
   * one direction stops once the current kernel's probability multiplied by
   * the weight not yet visited is below this fraction of the running sum.
   */
  private static final double MAX_ERROR = 0.01;

  /**
   * The distinct (rounded) values seen, in ascending order. Only the first
   * m_NumValues entries are in use.
   */
  private double[] m_Values;

  /**
   * The accumulated weight of each value, parallel to m_Values
   */
  private double[] m_Weights;

  /**
   * Number of entries in use in m_Weights and m_Values
   */
  private int m_NumValues;

  /**
   * The sum of the weights so far
   */
  private double m_SumOfWeights;

  /**
   * The standard deviation shared by all kernels
   */
  private double m_StandardDev;

  /**
   * The precision of data values
   */
  private double m_Precision;

  /**
   * True while every kernel has been added exactly once with weight 1
   */
  private boolean m_AllWeightsOne;

  /**
   * Execute a binary search over the sorted kernel values.
   *
   * @param key the data value to locate
   * @return the index holding {@code key} if it is present; otherwise the
   *         index at which {@code key} would have to be inserted to keep the
   *         array sorted (which may be equal to the number of stored values)
   */
  private int findNearestValue(double key) {

    int low = 0;
    int high = m_NumValues;
    while (low < high) {
      // unsigned shift keeps the midpoint correct even if low + high overflows
      int middle = (low + high) >>> 1;
      double current = m_Values[middle];
      if (current == key) {
        return middle;
      }
      if (current > key) {
        high = middle;
      } else {
        // Deliberately unconditional: every comparison with NaN is false, so
        // testing "current < key" here would leave both bounds untouched and
        // the loop would never terminate.
        low = middle + 1;
      }
    }
    return low;
  }

  /**
   * Round a data value using the defined precision for this estimator
   *
   * @param data the value to round
   * @return the multiple of the precision nearest to {@code data}
   */
  private double round(double data) {

    return Math.rint(data / m_Precision) * m_Precision;
  }

  /**
   * Insert a new kernel at the given position, shifting later kernels one
   * slot to the right and doubling the backing arrays when they are full.
   *
   * @param index  the position at which to insert
   * @param value  the (rounded) kernel value
   * @param weight the kernel weight
   */
  private void insertKernel(int index, double value, double weight) {

    int tail = m_NumValues - index;
    if (m_NumValues < m_Values.length) {
      // room left: open a gap in place
      System.arraycopy(m_Values, index, m_Values, index + 1, tail);
      System.arraycopy(m_Weights, index, m_Weights, index + 1, tail);
    } else {
      // full: copy head and tail into arrays of twice the size, leaving the gap
      double[] newValues = new double[m_Values.length * 2];
      double[] newWeights = new double[m_Values.length * 2];
      System.arraycopy(m_Values, 0, newValues, 0, index);
      System.arraycopy(m_Weights, 0, newWeights, 0, index);
      System.arraycopy(m_Values, index, newValues, index + 1, tail);
      System.arraycopy(m_Weights, index, newWeights, index + 1, tail);
      m_Values = newValues;
      m_Weights = newWeights;
    }
    m_Values[index] = value;
    m_Weights[index] = weight;
    m_NumValues++;
  }

  /**
   * Probability mass that a single kernel assigns to the interval of width
   * m_Precision centred at offset {@code delta} from the kernel mean.
   *
   * @param delta the offset between the kernel mean and the query value
   * @return the difference of the normal probabilities at the upper and lower
   *         interval bounds, both scaled by the current standard deviation
   */
  private double kernelProbability(double delta) {

    double zLower = (delta - (m_Precision / 2)) / m_StandardDev;
    double zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
    return Statistics.normalProbability(zUpper)
            - Statistics.normalProbability(zLower);
  }

  // ===============
  // Public methods.
  // ===============

  /**
   * No-arg constructor needed to make WEKA's forName() work. Uses precision of 0.01.
   */
  public KernelEstimator() {
    this(0.01);
  }

  /**
   * Constructor that takes a precision argument.
   *
   * @param precision the precision to which numeric values are given. For
   *                  example, if the precision is stated to be 0.1, the values in the
   *                  interval (0.25,0.35] are all treated as 0.3. Values below
   *                  {@code Utils.SMALL} (and NaN) are replaced by
   *                  {@code Utils.SMALL}.
   */
  public KernelEstimator(double precision) {

    m_Values = new double[50];
    m_Weights = new double[50];
    m_NumValues = 0;
    m_SumOfWeights = 0;
    m_AllWeightsOne = true;
    m_Precision = precision;
    // precision cannot be zero; the negated comparison also rejects NaN,
    // which would otherwise turn every rounded value into NaN
    if (!(m_Precision >= Utils.SMALL)) {
      m_Precision = Utils.SMALL;
    }
    // start with the narrowest permitted kernel (see addValue)
    m_StandardDev = m_Precision / (2 * 3);
  }

  /**
   * Add a new data value to the current estimator. The value is rounded to
   * the estimator's precision; if a kernel for the rounded value already
   * exists its weight is increased, otherwise a new kernel is inserted. The
   * shared standard deviation is then re-derived from the range of the stored
   * values and the total weight.
   * <p>
   * Calls with a weight of zero are ignored. Values whose rounded form is NaN
   * are ignored as well: NaN cannot be ordered, so storing it would break the
   * sorted kernel array that the binary search relies on.
   *
   * @param data   the new data value
   * @param weight the weight assigned to the data value
   */
  @Override
  public void addValue(double data, double weight) {

    if (weight == 0) {
      return;
    }
    data = round(data);
    if (Double.isNaN(data)) {
      return;
    }
    int insertIndex = findNearestValue(data);
    boolean alreadyPresent = (insertIndex < m_NumValues)
            && (m_Values[insertIndex] == data);
    if (alreadyPresent) {
      m_Weights[insertIndex] += weight;
      m_AllWeightsOne = false;
    } else {
      insertKernel(insertIndex, data, weight);
      if (weight != 1) {
        m_AllWeightsOne = false;
      }
    }
    m_SumOfWeights += weight;
    double range = m_Values[m_NumValues - 1] - m_Values[0];
    if (range > 0) {
      m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
              // allow at most 3 sds within one interval
              m_Precision / (2 * 3));
    }
  }

  /**
   * Get a probability estimate for a value.
   * <p>
   * The value is rounded to the estimator's precision. When no data has been
   * added, the result is the mass a single kernel centred at zero assigns to
   * the precision-wide interval around the value (not divided by the
   * precision). Otherwise the weighted kernel masses are summed outwards from
   * the nearest kernel, first towards larger and then towards smaller values,
   * each direction stopping early once the MAX_ERROR criterion is met, and the
   * sum is divided by (total weight * precision).
   *
   * @param data the value to estimate the probability of
   * @return the estimated probability of the supplied value
   */
  @Override
  public double getProbability(double data) {

    data = round(data);

    if (m_NumValues == 0) {
      return kernelProbability(data);
    }

    double sum = 0;
    double weightSum = 0;
    int start = findNearestValue(data);

    // walk upwards from the nearest kernel
    for (int i = start; i < m_NumValues; i++) {
      double currentProb = kernelProbability(m_Values[i] - data);
      sum += currentProb * m_Weights[i];
      weightSum += m_Weights[i];
      if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
        break;
      }
    }
    // walk downwards; weightSum carries over from the upward pass
    for (int i = start - 1; i >= 0; i--) {
      double currentProb = kernelProbability(m_Values[i] - data);
      sum += currentProb * m_Weights[i];
      weightSum += m_Weights[i];
      if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
        break;
      }
    }
    return sum / (m_SumOfWeights * m_Precision);
  }

  /**
   * Display a representation of this estimator: the number of kernels, the
   * standard deviation, the precision, the kernel means and, unless all
   * weights are one, the kernel weights.
   *
   * @return the textual description
   */
  @Override
  public String toString() {

    StringBuilder result = new StringBuilder();
    result.append(m_NumValues).append(" Normal Kernels. \nStandardDev = ")
            .append(Utils.doubleToString(m_StandardDev, 6, 4))
            .append(" Precision = ").append(m_Precision);
    if (m_NumValues == 0) {
      result.append("  \nMean = 0");
    } else {
      result.append("  \nMeans =");
      for (int i = 0; i < m_NumValues; i++) {
        result.append(" ").append(m_Values[i]);
      }
      if (!m_AllWeightsOne) {
        result.append("\nWeights = ");
        for (int i = 0; i < m_NumValues; i++) {
          result.append(" ").append(m_Weights[i]);
        }
      }
    }
    return result.append("\n").toString();
  }

  /**
   * Return the number of kernels in this kernel estimator
   *
   * @return the number of kernels
   */
  public int getNumKernels() {
    return m_NumValues;
  }

  /**
   * Return the means of the kernels.
   *
   * @return the internal array of kernel means (not a copy); only the first
   *         {@link #getNumKernels()} entries are meaningful and the array may
   *         be longer than that
   */
  public double[] getMeans() {
    return m_Values;
  }

  /**
   * Return the weights of the kernels.
   *
   * @return the internal array of kernel weights (not a copy); only the first
   *         {@link #getNumKernels()} entries are meaningful and the array may
   *         be longer than that
   */
  public double[] getWeights() {
    return m_Weights;
  }

  /**
   * Return the precision of this kernel estimator.
   *
   * @return the precision
   */
  public double getPrecision() {
    return m_Precision;
  }

  /**
   * Return the standard deviation of this kernel estimator.
   *
   * @return the standard deviation
   */
  public double getStdDev() {
    return m_StandardDev;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();
    // class
    if (!m_noClass) {
      result.enable(Capability.NOMINAL_CLASS);
      result.enable(Capability.MISSING_CLASS_VALUES);
    } else {
      result.enable(Capability.NO_CLASS);
    }

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    return result;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 15521 $");
  }

  /**
   * Merge another kernel estimator into this one by re-adding each of its
   * stored values with the corresponding weight. The values are rounded again
   * using this estimator's precision.
   *
   * @param toAggregate the estimator whose kernels are to be added
   * @return this estimator
   * @throws Exception declared by the interface; not thrown by the code here
   */
  @Override
  public KernelEstimator aggregate(KernelEstimator toAggregate)
          throws Exception {

    for (int i = 0; i < toAggregate.m_NumValues; i++) {
      addValue(toAggregate.m_Values[i], toAggregate.m_Weights[i]);
    }

    return this;
  }

  /**
   * Called when aggregation is complete. Nothing needs to be done for this
   * estimator.
   *
   * @throws Exception declared by the interface; not thrown by the code here
   */
  @Override
  public void finalizeAggregation() throws Exception {
    // nothing to do
  }

  /**
   * Prints a comparison of one estimator built from 100 pseudo-random values
   * against two estimators built from the first and second 50 values, and
   * against the result of aggregating those two.
   */
  public static void testAggregation() {
    KernelEstimator ke = new KernelEstimator(0.01);
    KernelEstimator one = new KernelEstimator(0.01);
    KernelEstimator two = new KernelEstimator(0.01);

    // fixed seed so that the printed output is reproducible
    java.util.Random r = new java.util.Random(1);

    for (int i = 0; i < 100; i++) {
      double z = r.nextDouble();

      ke.addValue(z, 1);
      if (i < 50) {
        one.addValue(z, 1);
      } else {
        two.addValue(z, 1);
      }
    }

    try {

      System.out.println("\n\nFull\n");
      System.out.println(ke.toString());
      System.out.println("Prob (0): " + ke.getProbability(0));

      System.out.println("\nOne\n" + one.toString());
      System.out.println("Prob (0): " + one.getProbability(0));

      System.out.println("\nTwo\n" + two.toString());
      System.out.println("Prob (0): " + two.getProbability(0));

      one = one.aggregate(two);

      System.out.println("Aggregated\n");
      System.out.println(one.toString());
      System.out.println("Prob (0): " + one.getProbability(0));
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv a sequence of numeric values: value/weight pairs to add,
   *             followed by the start and the finish of the range over which
   *             probabilities are printed
   */
  public static void main(String[] argv) {

    try {
      if (argv.length < 2) {
        System.out.println("Please specify a set of instances.");
        return;
      }
      KernelEstimator newEst = new KernelEstimator(0.01);
      // the last two arguments are the range, everything before is data
      for (int i = 0; i < argv.length - 3; i += 2) {
        newEst.addValue(Double.parseDouble(argv[i]),
                Double.parseDouble(argv[i + 1]));
      }
      System.out.println(newEst);

      double start = Double.parseDouble(argv[argv.length - 2]);
      double finish = Double.parseDouble(argv[argv.length - 1]);
      for (double current = start; current < finish; current += (finish - start) / 50) {
        System.out.println("Data: " + current + " "
                + newEst.getProbability(current));
      }

      KernelEstimator.testAggregation();
    } catch (Exception e) {
      System.out.println(e.getMessage());
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy