
// weka.associations.PriorEstimation Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* PriorEstimation.java
* Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
*
*/
package weka.associations;
import java.io.Serializable;
import java.util.Hashtable;
import java.util.Random;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.SpecialFunctions;
import weka.core.Utils;
/**
* Class implementing the prior estimation of the predictive apriori algorithm
* for mining association rules.
*
* Reference: T. Scheffer (2001). Finding Association Rules That Trade
* Support Optimally against Confidence. Proc of the 5th European Conf. on
* Principles and Practice of Knowledge Discovery in Databases (PKDD'01), pp.
* 424-435. Freiburg, Germany: Springer-Verlag.
*
*
* @author Stefan Mutter ([email protected])
* @version $Revision: 10201 $
*/
public class PriorEstimation implements Serializable, RevisionHandler {
/** for serialization */
private static final long serialVersionUID = 5570863216522496271L;

/** The number of random rules drawn for every rule length. */
protected int m_numRandRules;

/** The number of intervals into which [0,1] is divided. */
protected int m_numIntervals;

/** The random seed used for the random rule generation step. */
protected static final int SEED = 0;

/**
 * Upper limit on the number of attributes: prior estimation is refused for
 * datasets with this many attributes or more.
 */
protected static final int MAX_N = 1024;

/** The random number generator. */
protected Random m_randNum;

/** The instances for which association rules are mined. */
protected Instances m_instances;

/**
 * Flag indicating whether standard association rules or class association
 * rules are mined.
 */
protected boolean m_CARs;

/**
 * Confidence mass of the randomly generated rules. The key is the string form
 * of an interval mid point immediately followed by the string form of the
 * rule length (as a double); the value is the confidence accumulated for that
 * interval and rule length.
 */
protected Hashtable<String, Double> m_distribution;

/** Table for the estimated prior probabilities, keyed by interval mid point. */
protected Hashtable<Double, Double> m_priors;

/**
 * Running sum of the confidences recorded for the rule length that is
 * currently being processed.
 */
protected double m_sum;

/**
 * The mid points of the discrete intervals in which the interval [0,1] is
 * divided.
 */
protected double[] m_midPoints;
/**
 * Creates an estimator and stores its configuration. The random number
 * generator is requested from the supplied instances using the constant
 * {@link #SEED}.
 *
 * @param instances the instances to be used for generating the associations
 * @param numRules the number of random rules used for generating the prior
 * @param numIntervals the number of intervals to discretise [0,1]
 * @param car flag indicating whether standard or class association rules are
 *          mined
 */
public PriorEstimation(Instances instances, int numRules, int numIntervals,
  boolean car) {
  this.m_instances = instances;
  this.m_numRandRules = numRules;
  this.m_numIntervals = numIntervals;
  this.m_CARs = car;
  this.m_randNum = instances.getRandomNumberGenerator(SEED);
}
/**
 * Calculates the distribution of confidence values of random rules.
 *
 * For every rule length from 1 to the number of attributes,
 * {@code m_numRandRules} random rules are drawn. Each rule whose item set
 * occurs in at least one instance contributes its confidence (support of the
 * whole rule divided by support of the premise) to the interval of [0,1] the
 * confidence falls into. Rules without support are drawn and counted but
 * contribute nothing.
 *
 * After sampling, the values of one rule length are normalised by the summed
 * confidences. Intervals that received no rule are first given
 * {@code 1 / m_numIntervals}. If no rule of a length had any support, every
 * interval of that length is set to {@code 1 / m_numIntervals}.
 *
 * @exception Exception if prior can't be estimated successfully: no
 *              attributes, too many attributes, no instances, a numeric
 *              attribute, a non-positive number of intervals or random
 *              rules, or class association rules without a usable class
 *              index
 */
public final void generateDistribution() throws Exception {
  final int maxLength = m_instances.numAttributes();

  if (maxLength == 0) {
    throw new Exception("Dataset has no attributes!");
  }
  if (maxLength >= MAX_N) {
    throw new Exception(
      "Dataset has to many attributes for prior estimation!");
  }
  if (m_instances.numInstances() == 0) {
    throw new Exception("Dataset has no instances!");
  }
  for (int h = 0; h < maxLength; h++) {
    if (m_instances.attribute(h).isNumeric()) {
      throw new Exception("Can't handle numeric attributes!");
    }
  }
  // negative values are rejected as well: they would otherwise fail later
  // with an array-size error or silently yield a uniform distribution
  if (m_numIntervals <= 0 || m_numRandRules <= 0) {
    throw new Exception("Prior initialisation impossible");
  }
  if (m_CARs) {
    int classIndex = m_instances.classIndex();
    if (classIndex < 0 || classIndex >= maxLength) {
      throw new Exception(
        "Class association rules need a class attribute!");
    }
  }

  // calculate mid points for the intervals
  midPoints();
  m_distribution = new Hashtable<String, Double>(maxLength * m_numIntervals);
  final double uniform = 1.0 / m_numIntervals;

  // create random rules of every length and measure their support and, if
  // the support is > 0, their confidence
  for (int length = 1; length <= maxLength; length++) {
    m_sum = 0;
    for (int j = 0; j < m_numRandRules; j++) {
      int[] itemArray;
      RuleItem current;
      if (!m_CARs) {
        itemArray = randomRule(maxLength, length, m_randNum);
        current = splitItemSet(m_randNum.nextInt(length), itemArray);
      } else {
        itemArray = randomCARule(maxLength, length, m_randNum);
        current = addCons(itemArray);
      }

      // merge premise and consequence into the item set of the whole rule
      int[] ruleItem = new int[maxLength];
      for (int k = 0; k < itemArray.length; k++) {
        if (current.m_premise.m_items[k] != -1) {
          ruleItem[k] = current.m_premise.m_items[k];
        } else if (current.m_consequence.m_items[k] != -1) {
          ruleItem[k] = current.m_consequence.m_items[k];
        } else {
          ruleItem[k] = -1;
        }
      }
      ItemSet rule = new ItemSet(ruleItem);
      updateCounters(rule);

      // the premise support is only needed for rules that occur at all
      if (rule.m_counter > 0) {
        updateCounters(current.m_premise);
        buildDistribution((double) rule.m_counter
          / (double) current.m_premise.m_counter, length);
      }
    }

    // normalize
    if (m_sum > 0) {
      for (double midPoint : m_midPoints) {
        String key = distributionKey(midPoint, length);
        if (!m_distribution.containsKey(key)) {
          m_distribution.put(key, Double.valueOf(uniform));
          m_sum += uniform;
        }
      }
      for (double midPoint : m_midPoints) {
        String key = distributionKey(midPoint, length);
        Double oldValue = (Double) m_distribution.get(key);
        if (oldValue != null) {
          m_distribution.put(key,
            Double.valueOf(oldValue.doubleValue() / m_sum));
        }
      }
    } else {
      for (double midPoint : m_midPoints) {
        m_distribution.put(distributionKey(midPoint, length),
          Double.valueOf(uniform));
      }
    }
  }
}

/**
 * Builds the distribution table key for an interval mid point and a rule
 * length: the string form of the mid point followed by the string form of the
 * length as a double.
 */
private static String distributionKey(double midPoint, int length) {
  return String.valueOf(midPoint).concat(String.valueOf((double) length));
}
/**
 * Constructs an item set of certain length randomly. This method is used for
 * standard association rule mining.
 *
 * The result has one entry per attribute; -1 marks an attribute that is not
 * part of the item set, any other entry is a randomly drawn value index of
 * that attribute. All random numbers are taken from {@code randNum}.
 *
 * @param maxLength the number of attributes of the instances
 * @param actualLength the number of attributes that should be present in the
 *          item set
 * @param randNum the random number generator
 * @return a randomly constructed item set in form of an int array
 * @throws IllegalArgumentException if more items are requested than there
 *           are attributes (the sampling loop could never finish)
 */
public final int[] randomRule(int maxLength, int actualLength, Random randNum) {
  if (actualLength > maxLength) {
    throw new IllegalArgumentException("Cannot place " + actualLength
      + " items on " + maxLength + " attributes");
  }

  int[] itemArray = new int[maxLength];
  for (int k = 0; k < itemArray.length; k++) {
    itemArray[k] = -1;
  }

  if (actualLength == maxLength) {
    // every attribute is used, so no positions have to be sampled
    for (int h = 0; h < itemArray.length; h++) {
      itemArray[h] = randNum.nextInt(m_instances.attribute(h).numValues());
    }
    return itemArray;
  }

  // draw positions until enough distinct attributes have been filled
  int remaining = actualLength;
  while (remaining > 0) {
    int mark = randNum.nextInt(maxLength);
    if (itemArray[mark] == -1) {
      remaining--;
      itemArray[mark] = randNum
        .nextInt(m_instances.attribute(mark).numValues());
    }
  }
  return itemArray;
}
/**
 * Constructs an item set of certain length randomly. This method is used for
 * class association rule mining.
 *
 * The result has one entry per attribute; -1 marks an unused attribute. The
 * class attribute is never filled here, so the item set holds
 * {@code actualLength - 1} items (none when {@code actualLength} is 1). All
 * random numbers are taken from {@code randNum}.
 *
 * @param maxLength the number of attributes of the instances
 * @param actualLength the number of attributes that should be present in the
 *          item set
 * @param randNum the random number generator
 * @return a randomly constructed item set in form of an int array
 * @throws IllegalArgumentException if more premise items are requested than
 *           there are non-class attributes (the sampling loop could never
 *           finish)
 */
public final int[] randomCARule(int maxLength, int actualLength,
  Random randNum) {
  int[] itemArray = new int[maxLength];
  for (int k = 0; k < itemArray.length; k++) {
    itemArray[k] = -1;
  }
  if (actualLength == 1) {
    return itemArray;
  }

  final int classIndex = m_instances.classIndex();
  final int premiseItems = actualLength - 1;
  final boolean classInRange = classIndex >= 0 && classIndex < maxLength;
  final int eligible = classInRange ? maxLength - 1 : maxLength;
  if (premiseItems > eligible) {
    throw new IllegalArgumentException("Cannot place " + premiseItems
      + " premise items on " + eligible + " non-class attributes");
  }

  if (premiseItems == maxLength - 1) {
    // every non-class attribute is used, so no positions have to be sampled
    for (int h = 0; h < itemArray.length; h++) {
      if (h != classIndex) {
        itemArray[h] = randNum.nextInt(m_instances.attribute(h).numValues());
      }
    }
    return itemArray;
  }

  // draw positions until enough distinct non-class attributes are filled
  int remaining = premiseItems;
  while (remaining > 0) {
    int mark = randNum.nextInt(maxLength);
    if (itemArray[mark] == -1 && mark != classIndex) {
      remaining--;
      itemArray[mark] = randNum
        .nextInt(m_instances.attribute(mark).numValues());
    }
  }
  return itemArray;
}
/**
 * Updates the distribution of the confidence values. The interval the
 * confidence belongs to is looked up and the confidence is added to the value
 * already stored for that interval and rule length. The confidence is also
 * added to the running sum {@code m_sum}.
 *
 * @param conf the confidence of the randomly created rule
 * @param length the length of the randomly created rule
 */
public final void buildDistribution(double conf, double length) {
  // key layout: mid point text immediately followed by the length text
  String key = String.valueOf(findIntervall(conf)).concat(
    String.valueOf(length));
  m_sum += conf;

  Double previous = (Double) m_distribution.get(key);
  double total = conf;
  if (previous != null) {
    total = conf + previous.doubleValue();
  }
  m_distribution.put(key, Double.valueOf(total));
}
/**
 * Searches the mid point of the interval a given confidence value falls
 * into, i.e. the mid point closest to the confidence. A value exactly
 * half-way between two mid points is assigned to the lower one.
 *
 * Requires {@code m_midPoints} to be non-empty and sorted in ascending order,
 * which is how {@code midPoints()} fills it.
 *
 * @param conf the confidence of a rule
 * @return the mid point of the interval the confidence belongs to
 */
public final double findIntervall(double conf) {
  if (conf == 1.0) {
    return m_midPoints[m_midPoints.length - 1];
  }

  // Narrow [low, high] while keeping the closest mid point inside the range.
  // The probed index must stay in the range: the value may still be closest
  // to it even though it lies above or below it.
  int low = 0;
  int high = m_midPoints.length - 1;
  while (high - low > 1) {
    int mid = (low + high) >>> 1;
    if (conf == m_midPoints[mid]) {
      return m_midPoints[mid];
    }
    if (conf > m_midPoints[mid]) {
      low = mid;
    } else {
      high = mid;
    }
  }

  if (Math.abs(conf - m_midPoints[low]) <= Math.abs(conf
    - m_midPoints[high])) {
    return m_midPoints[low];
  }
  return m_midPoints[high];
}
/**
 * Calculates the numerator or the denominator of the prior equation.
 *
 * For every rule length i from 1 to the number of attributes n the term
 * {@code 2^(log2(2^i - 1) + log2(C(n, i)) - max)} is summed, where
 * {@code max} is the log binomial coefficient of n over n/2. In the weighted
 * case each term is additionally multiplied by the distribution value stored
 * for the given mid point and length i; lengths with no stored value or a
 * value of 0 are skipped. The weighted case reads {@code m_distribution}, so
 * the distribution has to be generated first.
 *
 * @param weighted true to calculate the weighted sum (numerator), false for
 *          the unweighted sum (denominator)
 * @param mPoint the mid point of an interval; only used when weighted
 * @return the numerator or denominator of the prior equation
 */
public final double calculatePriorSum(boolean weighted, double mPoint) {
  final int numAttributes = m_instances.numAttributes();
  // the same offset is subtracted from the exponent of every term
  final double max = logbinomialCoefficient(numAttributes, numAttributes / 2);
  double sum = 0;

  for (int i = 1; i <= numAttributes; i++) {
    if (weighted) {
      String key = String.valueOf(mPoint).concat(String.valueOf((double) i));
      Double hashValue = (Double) m_distribution.get(key);
      double distr = (hashValue != null) ? hashValue.doubleValue() : 0;
      if (distr != 0) {
        double addend = Utils.log2(distr) - max
          + Utils.log2((Math.pow(2, i) - 1))
          + logbinomialCoefficient(numAttributes, i);
        sum = sum + Math.pow(2, addend);
      }
    } else {
      double addend = Utils.log2((Math.pow(2, i) - 1)) - max
        + logbinomialCoefficient(numAttributes, i);
      sum = sum + Math.pow(2, addend);
    }
  }
  return sum;
}
/**
 * Method that calculates the base 2 logarithm of a binomial coefficient.
 *
 * The coefficients C(n, 0) and C(n, n) equal 1, so their logarithm is 0;
 * these cases are answered directly. All other cases are delegated to
 * {@code SpecialFunctions.log2Binomial}.
 *
 * @param upperIndex upper index of the binomial coefficient
 * @param lowerIndex lower index of the binomial coefficient
 * @return the base 2 logarithm of the binomial coefficient
 */
public static final double logbinomialCoefficient(int upperIndex,
  int lowerIndex) {
  if (upperIndex == lowerIndex || lowerIndex == 0) {
    // log2(1) is 0
    return 0.0;
  }
  return SpecialFunctions.log2Binomial(upperIndex, lowerIndex);
}
/**
 * Method to estimate the prior probabilities.
 *
 * Generates the confidence distribution and then, for every interval mid
 * point, divides the weighted prior sum by the unweighted prior sum. The
 * resulting table is stored in {@code m_priors} and returned.
 *
 * @throws Exception throws exception if the prior cannot be calculated
 * @return a hashtable mapping each interval mid point to its prior
 *         probability
 */
public final Hashtable<Double, Double> estimatePrior() throws Exception {
  final double denominator = calculatePriorSum(false, 1.0);
  generateDistribution();

  final Hashtable<Double, Double> priors = new Hashtable<Double, Double>(
    m_numIntervals);
  for (int i = 0; i < m_numIntervals; i++) {
    double mPoint = m_midPoints[i];
    double prior = calculatePriorSum(true, mPoint) / denominator;
    priors.put(Double.valueOf(mPoint), Double.valueOf(prior));
  }

  // keep the result in the field as well; a local variable of the same name
  // used to hide the field, which therefore was never assigned
  m_priors = priors;
  return priors;
}
/**
 * Divides [0,1] into {@code m_numIntervals} intervals of equal width and
 * stores the mid point of each interval in {@code m_midPoints}, in ascending
 * order of interval number.
 */
public final void midPoints() {
  final double width = 1.0 / m_numIntervals;
  final double[] points = new double[m_numIntervals];
  m_midPoints = points;

  int index = 0;
  while (index < points.length) {
    points[index] = midPoint(width, index);
    index++;
  }
}
/**
 * Computes the mid point of one interval: the lower bound of the interval
 * ({@code size * number}) plus half the interval size.
 *
 * @param size the size of each interval
 * @param number the number of the interval, counted from 0
 * @return the mid point of the interval
 */
public double midPoint(double size, int number) {
  final double lowerBound = size * number;
  final double halfSize = size / 2.0;
  return lowerBound + halfSize;
}
/**
 * Gives access to the interval mid points. The internal array itself is
 * returned, not a copy.
 *
 * @return the array of doubles containing all mid points
 */
public final double[] getMidPoints() {
  return this.m_midPoints;
}
/**
 * Splits an item set into premise and consequence and therefore constructs
 * an association rule. The length of the premise is given. The attributes for
 * premise and consequence are chosen randomly. The result is a RuleItem.
 *
 * Note that {@code itemArray} is modified in place: the items that end up in
 * the consequence are set to -1 and the array is then handed to the premise
 * item set. The consequence gets an array of its own.
 *
 * @param premiseLength the length of the premise
 * @param itemArray a (randomly generated) item set; -1 marks unused
 *          attributes
 * @return a randomly generated association rule stored in a RuleItem
 * @throws IllegalArgumentException if the premise is to hold more items than
 *           the item set contains (the sampling loop could never finish)
 */
public final RuleItem splitItemSet(int premiseLength, int[] itemArray) {
  int present = 0;
  for (int item : itemArray) {
    if (item != -1) {
      present++;
    }
  }
  if (premiseLength > present) {
    throw new IllegalArgumentException("Premise length " + premiseLength
      + " exceeds the " + present + " items of the item set");
  }

  // start with all items in the consequence and randomly take out the ones
  // that form the premise
  int[] cons = itemArray.clone();
  int remaining = premiseLength;
  while (remaining > 0) {
    int mark = m_randNum.nextInt(itemArray.length);
    if (cons[mark] != -1) {
      remaining--;
      cons[mark] = -1;
    }
  }

  // whatever stayed in the consequence must not appear in the premise; for
  // an empty premise this clears every item
  for (int i = 0; i < itemArray.length; i++) {
    if (cons[i] != -1) {
      itemArray[i] = -1;
    }
  }

  RuleItem current = new RuleItem();
  current.m_premise = new ItemSet(itemArray);
  current.m_consequence = new ItemSet(cons);
  return current;
}
/**
 * Turns a premise into a class association rule. The consequence is an item
 * set of the same length as the premise in which every entry is -1 except the
 * one at the class index, which receives a randomly drawn value index of the
 * class attribute.
 *
 * @param itemArray the (randomly constructed) premise of the class
 *          association rule
 * @return a class association rule stored in a RuleItem
 */
public final RuleItem addCons(int[] itemArray) {
  final ItemSet premise = new ItemSet(itemArray);

  final int[] consItems = new int[itemArray.length];
  int pos = 0;
  while (pos < consItems.length) {
    consItems[pos] = -1;
    pos++;
  }

  final int classIndex = m_instances.classIndex();
  consItems[classIndex] = m_randNum.nextInt(m_instances.attribute(classIndex)
    .numValues());

  final RuleItem rule = new RuleItem();
  rule.m_premise = premise;
  rule.m_consequence = new ItemSet(consItems);
  return rule;
}
/**
 * Passes every instance of the dataset, in order, to the item set's
 * {@code upDateCounter} method so that the item set can update its support
 * count.
 *
 * @param itemSet the item set
 */
public final void updateCounters(ItemSet itemSet) {
  int index = 0;
  while (index < m_instances.numInstances()) {
    itemSet.upDateCounter(m_instances.instance(index));
    index++;
  }
}
/**
 * Returns the revision string, as extracted by {@code RevisionUtils} from the
 * revision keyword of this file.
 *
 * @return the revision
 */
@Override
public String getRevision() {
  final String revisionKeyword = "$Revision: 10201 $";
  return RevisionUtils.extract(revisionKeyword);
}
}