weka.classifiers.bayes.net.search.global.GlobalScoreSearchAlgorithm Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* GlobalScoreSearchAlgorithm.java
* Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.bayes.net.search.global;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
/**
* This Bayes Network learning algorithm uses cross
* validation to estimate classification accuracy.
*
*
*
* Valid options are:
*
*
*
* -mbc
* Applies a Markov Blanket correction to the network structure,
* after a network structure is learned. This ensures that all
* nodes in the network are part of the Markov blanket of the
* classifier node.
*
*
*
* -S [LOO-CV|k-Fold-CV|Cumulative-CV]
* Score type (LOO-CV,k-Fold-CV,Cumulative-CV)
*
*
*
* -Q
* Use probabilistic or 0/1 scoring.
* (default probabilistic scoring)
*
*
*
*
* @author Remco Bouckaert
* @version $Revision: 10154 $
*/
public class GlobalScoreSearchAlgorithm extends SearchAlgorithm {
/** version identifier used for serialization */
static final long serialVersionUID = 7341389867906199781L;
/**
* Bayes network whose structure is currently being scored. It is assigned on
* entry to leaveOneOutCV, cumulativeCV and kFoldCV, and read by
* accuracyIncrease and the calcScoreWith... methods.
**/
BayesNet m_BayesNet;
/**
* Selects how a single instance contributes to the score (see
* accuracyIncrease): when true, the probability the network assigns to the
* instance's actual class; when false, 0-1 loss, i.e. the instance counts
* only if it is classified correctly.
**/
boolean m_bUseProb = true;
/** number of folds handed to kFoldCV when the score type is KFOLDCV **/
int m_nNrOfFolds = 10;
/** constant for score type: LOO-CV (leave one out cross validation) */
final static int LOOCV = 0;
/** constant for score type: k-fold-CV */
final static int KFOLDCV = 1;
/** constant for score type: Cumulative-CV */
final static int CUMCV = 2;
/**
* The selectable score types, pairing each constant above with its display
* name. setCVType only accepts a SelectedTag whose tag array is this very
* array (compared by reference).
**/
public static final Tag[] TAGS_CV_TYPE = { new Tag(LOOCV, "LOO-CV"),
new Tag(KFOLDCV, "k-Fold-CV"), new Tag(CUMCV, "Cumulative-CV") };
/**
* Holds the cross validation strategy used to measure quality of network:
* one of LOOCV, KFOLDCV or CUMCV. calcScore dispatches on this value.
*/
int m_nCVType = LOOCV;
/**
 * Scores a network structure using the cross validation scheme currently
 * selected in m_nCVType. The data set used is the one associated with the
 * given Bayes network.
 *
 * @param bayesNet Bayes network containing the structure to evaluate
 * @return accuracy estimate produced by the selected cross validation scheme
 * @throws Exception when m_nCVType is not one of LOOCV, KFOLDCV or CUMCV, or
 *           when the selected cross validation routine fails
 */
public double calcScore(BayesNet bayesNet) throws Exception {
  if (m_nCVType == LOOCV) {
    return leaveOneOutCV(bayesNet);
  }
  if (m_nCVType == CUMCV) {
    return cumulativeCV(bayesNet);
  }
  if (m_nCVType == KFOLDCV) {
    return kFoldCV(bayesNet, m_nNrOfFolds);
  }
  // none of the known score types matched
  final String message = "Unrecognized cross validation type encountered: "
    + m_nCVType;
  throw new Exception(message);
} // calcScore
/**
 * Scores the network as it would be with nCandidateParent added as a parent
 * of nNode. The parent is added temporarily, the whole network is scored with
 * calcScore, and the parent is removed again.
 *
 * @param nNode node whose parent set is temporarily extended
 * @param nCandidateParent candidate parent to add to the existing parent set
 * @return score of the modified network, or -1e100 when nCandidateParent is
 *         already a parent of nNode
 * @throws Exception if scoring the network fails; the temporarily added
 *           parent is removed before the exception propagates
 */
public double calcScoreWithExtraParent(int nNode, int nCandidateParent)
  throws Exception {
  ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
  Instances instances = m_BayesNet.m_Instances;

  // sanity check: nCandidateParent should not be in parent set already
  if (oParentSet.contains(nCandidateParent)) {
    return -1e100;
  }

  // set up candidate parent
  oParentSet.addParent(nCandidateParent, instances);
  try {
    // calculate the score
    return calcScore(m_BayesNet);
  } finally {
    // delete temporarily added parent, also when scoring failed, so the
    // network structure is never left with the trial arc in place
    oParentSet.deleteLastParent(instances);
  }
} // calcScoreWithExtraParent
/**
 * Scores the network as it would be with nCandidateParent removed from the
 * parents of nNode. The parent is deleted temporarily, the whole network is
 * scored with calcScore, and the parent is re-inserted at its old position.
 *
 * @param nNode node whose parent set is temporarily reduced
 * @param nCandidateParent candidate parent to delete from the existing parent
 *          set
 * @return score of the modified network, or -1e100 when nCandidateParent is
 *         not a parent of nNode
 * @throws Exception if scoring the network fails; the temporarily deleted
 *           parent is re-inserted before the exception propagates
 */
public double calcScoreWithMissingParent(int nNode, int nCandidateParent)
  throws Exception {
  ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
  Instances instances = m_BayesNet.m_Instances;

  // sanity check: nCandidateParent should be in parent set already
  if (!oParentSet.contains(nCandidateParent)) {
    return -1e100;
  }

  // temporarily remove the parent, remembering where it was
  int iParent = oParentSet.deleteParent(nCandidateParent, instances);
  try {
    // calculate the score
    return calcScore(m_BayesNet);
  } finally {
    // reinsert temporarily deleted parent, also when scoring failed, so the
    // network structure is never left without the arc
    oParentSet.addParent(nCandidateParent, iParent, instances);
  }
} // calcScoreWithMissingParent
/**
 * Scores the network as it would be with the arc from nCandidateParent to
 * nNode reversed. nCandidateParent is temporarily removed from the parents of
 * nNode and nNode is temporarily added to the parents of nCandidateParent;
 * the whole network is scored with calcScore and the original arc is put
 * back.
 *
 * @param nNode node at the head of the arc to reverse
 * @param nCandidateParent current parent of nNode, at the tail of the arc to
 *          reverse
 * @return score of the modified network, or -1e100 when nCandidateParent is
 *         not a parent of nNode
 * @throws Exception if scoring the network fails; the original arc is
 *           restored before the exception propagates
 */
public double calcScoreWithReversedParent(int nNode, int nCandidateParent)
  throws Exception {
  ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
  ParentSet oParentSet2 = m_BayesNet.getParentSet(nCandidateParent);
  Instances instances = m_BayesNet.m_Instances;

  // sanity check: nCandidateParent should be in parent set already
  if (!oParentSet.contains(nCandidateParent)) {
    return -1e100;
  }

  // temporarily reverse the arc
  int iParent = oParentSet.deleteParent(nCandidateParent, instances);
  oParentSet2.addParent(nNode, instances);
  try {
    // calculate the score
    return calcScore(m_BayesNet);
  } finally {
    // restore the temporarily reversed arc, also when scoring failed, so
    // the network structure is never left in the trial state
    oParentSet2.deleteLastParent(instances);
    oParentSet.addParent(nCandidateParent, iParent, instances);
  }
} // calcScoreWithReversedParent
/**
 * LeaveOneOutCV returns the accuracy calculated using Leave One Out cross
 * validation. The dataset used is m_Instances associated with the Bayes
 * Network. The CPTs are estimated once on the full data set; each instance is
 * then taken out of the estimates, scored, and put back.
 *
 * If the data set contains no instances the result is NaN (0.0 / 0.0).
 *
 * @param bayesNet : Bayes Network containing structure to evaluate
 * @return accuracy measured using leave one out cv.
 * @throws Exception passed on by updateClassifier and accuracyIncrease; the
 *           weight of the instance being processed is restored before the
 *           exception propagates
 */
public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
  m_BayesNet = bayesNet;
  double fAccuracy = 0.0;
  double fWeight = 0.0;
  Instances instances = bayesNet.m_Instances;
  bayesNet.estimateCPTs();
  for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
    Instance instance = instances.instance(iInstance);
    final double fOriginalWeight = instance.weight();
    // updating with the negated weight takes the instance out of the
    // estimates
    instance.setWeight(-fOriginalWeight);
    try {
      bayesNet.updateClassifier(instance);
      // the weight is still negated here: numerator and denominator both
      // accumulate with the same sign, so the final ratio is unaffected
      fAccuracy += accuracyIncrease(instance);
      fWeight += instance.weight();
    } finally {
      // never leave a negated weight behind in the shared data set
      instance.setWeight(fOriginalWeight);
    }
    // put the instance back into the estimates
    bayesNet.updateClassifier(instance);
  }
  return fAccuracy / fWeight;
} // LeaveOneOutCV
/**
 * CumulativeCV returns the accuracy calculated using cumulative cross
 * validation: starting from freshly initialised CPTs, the data set is walked
 * in order, and each instance is first scored against what has been seen so
 * far and only then added to the estimates. The data set used is m_Instances
 * associated with the Bayes Network.
 *
 * @param bayesNet : Bayes Network containing structure to evaluate
 * @return accuracy measured using cumulative cv.
 * @throws Exception passed on by updateClassifier and accuracyIncrease
 */
public double cumulativeCV(BayesNet bayesNet) throws Exception {
  m_BayesNet = bayesNet;
  final Instances data = bayesNet.m_Instances;
  bayesNet.initCPTs();

  double scoreSum = 0.0;
  double weightSum = 0.0;
  int position = 0;
  while (position < data.numInstances()) {
    final Instance current = data.instance(position);
    // score first, learn afterwards
    scoreSum += accuracyIncrease(current);
    bayesNet.updateClassifier(current);
    weightSum += current.weight();
    position++;
  }
  return scoreSum / weightSum;
} // cumulativeCV
/**
 * kFoldCV uses k-fold cross validation to measure the accuracy of a Bayes
 * network classifier. The CPTs are estimated once on the complete data set;
 * for each fold of consecutive instances the fold is taken out of the
 * estimates, scored, and put back.
 *
 * If the data set contains no instances the result is NaN (0.0 / 0.0).
 *
 * @param bayesNet : Bayes Network containing structure to evaluate
 * @param nNrOfFolds : the number of folds k to perform k-fold cv; must be at
 *          least 1
 * @return accuracy measured using k-fold cv.
 * @throws IllegalArgumentException if nNrOfFolds is smaller than 1
 * @throws Exception passed on by updateClassifier and accuracyIncrease
 */
public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception {
  // zero folds would divide by zero below, and a negative count would make
  // the fold bounds walk backwards, so reject both up front
  if (nNrOfFolds < 1) {
    throw new IllegalArgumentException(
      "Number of folds must be at least 1, but was " + nNrOfFolds);
  }
  m_BayesNet = bayesNet;
  double fAccuracy = 0.0;
  double fWeight = 0.0;
  Instances instances = bayesNet.m_Instances;
  // estimate CPTs based on complete data set
  bayesNet.estimateCPTs();
  final int nNrOfInstances = instances.numInstances();
  int nFoldStart = 0;
  int nFoldEnd = nNrOfInstances / nNrOfFolds;
  int iFold = 1;
  while (nFoldStart < nNrOfInstances) {
    // remove influence of fold iFold from the probability distribution
    for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
      Instance instance = instances.instance(iInstance);
      final double fOriginalWeight = instance.weight();
      // the weight is negated only for the duration of the update, so a
      // failure cannot leave a negated weight behind in the data set
      instance.setWeight(-fOriginalWeight);
      try {
        bayesNet.updateClassifier(instance);
      } finally {
        instance.setWeight(fOriginalWeight);
      }
    }
    // measure accuracy on fold iFold
    for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
      Instance instance = instances.instance(iInstance);
      fAccuracy += accuracyIncrease(instance);
      fWeight += instance.weight();
    }
    // restore influence of fold iFold from the probability distribution
    for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
      bayesNet.updateClassifier(instances.instance(iInstance));
    }
    // go to next fold
    nFoldStart = nFoldEnd;
    iFold++;
    // multiply in long: iFold * nNrOfInstances can exceed the int range
    nFoldEnd = (int) ((long) iFold * nNrOfInstances / nNrOfFolds);
  }
  return fAccuracy / fWeight;
} // kFoldCV
/**
 * accuracyIncrease determines the contribution of a single instance to the
 * accuracy estimate. With probabilistic scoring (m_bUseProb) this is the
 * probability the network assigns to the instance's actual class, multiplied
 * by the instance weight. With 0-1 scoring it is the instance weight when the
 * predicted class equals the actual class and 0 otherwise.
 *
 * @param instance : instance for which to calculate the accuracy increase.
 * @return increase in accuracy due to given instance.
 * @throws Exception passed on by distributionForInstance and classifyInstance
 */
double accuracyIncrease(Instance instance) throws Exception {
  if (!m_bUseProb) {
    // 0-1 scoring: all or nothing
    final boolean isCorrect =
      m_BayesNet.classifyInstance(instance) == instance.classValue();
    return isCorrect ? instance.weight() : 0;
  }
  // probabilistic scoring: credit the mass put on the actual class
  final double[] classDistribution =
    m_BayesNet.distributionForInstance(instance);
  final int actualClass = (int) instance.classValue();
  return classDistribution[actualClass] * instance.weight();
} // accuracyIncrease
/**
* Reports which per-instance score is used in the accuracy estimate.
*
* @return true if the probability assigned to the actual class is used, false
*         if 0-1 scoring (prediction correct or not) is used
*/
public boolean getUseProb() {
return m_bUseProb;
} // getUseProb
/**
* Chooses which per-instance score is used in the accuracy estimate.
*
* @param useProb : true to score with the probability assigned to the actual
*          class, false to use 0-1 scoring (prediction correct or not)
*/
public void setUseProb(boolean useProb) {
m_bUseProb = useProb;
} // setUseProb
/**
 * Set cross validation strategy to be used in searching for networks. The
 * request is silently ignored unless the given tag was built on TAGS_CV_TYPE
 * itself (the tag arrays are compared by reference).
 *
 * @param newCVType : cross validation strategy
 */
public void setCVType(SelectedTag newCVType) {
  final boolean foreignTagSet = newCVType.getTags() != TAGS_CV_TYPE;
  if (foreignTagSet) {
    // not one of our score types: keep the current setting
    return;
  }
  m_nCVType = newCVType.getSelectedTag().getID();
} // setCVType
/**
* Get cross validation strategy to be used in searching for networks.
*
* @return a new SelectedTag wrapping the current score type (m_nCVType)
*         together with the TAGS_CV_TYPE tag set
*/
public SelectedTag getCVType() {
return new SelectedTag(m_nCVType, TAGS_CV_TYPE);
} // getCVType
/**
* Sets the Markov blanket classifier flag. This override only forwards the
* value to the superclass implementation.
* NOTE(review): presumably this is the flag behind the -mbc option described
* in the class documentation (Markov Blanket correction applied after the
* structure is learned) - confirm against SearchAlgorithm.
*
* @param bMarkovBlanketClassifier new value of the flag
*/
@Override
public void setMarkovBlanketClassifier(boolean bMarkovBlanketClassifier) {
super.setMarkovBlanketClassifier(bMarkovBlanketClassifier);
}
/**
* Gets the Markov blanket classifier flag. This override only forwards to the
* superclass implementation.
*
* @return the flag value reported by the superclass
*/
@Override
public boolean getMarkovBlanketClassifier() {
return super.getMarkovBlanketClassifier();
}
/**
* Returns an enumeration describing the available options
*
* @return an enumeration of all the available options
*/
@Override
public Enumeration