All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.bayes.net.search.local.HillClimber Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * HillClimber.java
 * Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
 * 
 */

package weka.classifiers.bayes.net.search.local;

import java.io.Serializable;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * This Bayes Network learning algorithm uses a hill climbing algorithm that
 * adds, deletes and reverses arcs. The search is not restricted by an order
 * on the variables (unlike K2). The difference with B and B2 is that this
 * hill climber also considers arrows that are part of the naive Bayes
 * structure for deletion.
 * <p>
 * Valid options are:
 * <p>
 * 
 * <pre>
 * -P &lt;nr of parents&gt;
 *  Maximum number of parents
 * </pre>
 * 
 * <pre>
 * -R
 *  Use arc reversal operation.
 *  (default false)
 * </pre>
 * 
 * <pre>
 * -N
 *  Initial structure is empty (instead of Naive Bayes)
 * </pre>
 * 
 * <pre>
 * -mbc
 *  Applies a Markov Blanket correction to the network structure, 
 *  after a network structure is learned. This ensures that all 
 *  nodes in the network are part of the Markov blanket of the 
 *  classifier node.
 * </pre>
 * 
 * <pre>
 * -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
 *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)
 * </pre>
 * 
* * * * @author Remco Bouckaert ([email protected]) * @version $Revision: 10154 $ */ public class HillClimber extends LocalScoreSearchAlgorithm { /** for serialization */ static final long serialVersionUID = 4322783593818122403L; /** * the Operation class contains info on operations performed on the current * Bayesian network. */ class Operation implements Serializable, RevisionHandler { /** for serialization */ static final long serialVersionUID = -4880888790432547895L; // constants indicating the type of an operation final static int OPERATION_ADD = 0; final static int OPERATION_DEL = 1; final static int OPERATION_REVERSE = 2; /** * c'tor */ public Operation() { } /** * c'tor + initializers * * @param nTail * @param nHead * @param nOperation */ public Operation(int nTail, int nHead, int nOperation) { m_nHead = nHead; m_nTail = nTail; m_nOperation = nOperation; } /** * compare this operation with another * * @param other operation to compare with * @return true if operation is the same */ public boolean equals(Operation other) { if (other == null) { return false; } return ((m_nOperation == other.m_nOperation) && (m_nHead == other.m_nHead) && (m_nTail == other.m_nTail)); } // equals /** number of the tail node **/ public int m_nTail; /** number of the head node **/ public int m_nHead; /** type of operation (ADD, DEL, REVERSE) **/ public int m_nOperation; /** change of score due to this operation **/ public double m_fDeltaScore = -1E100; /** * Returns the revision string. 
* * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 10154 $"); } } // class Operation /** * cache for remembering the change in score for steps in the search space */ class Cache implements RevisionHandler { /** change in score due to adding an arc **/ double[][] m_fDeltaScoreAdd; /** change in score due to deleting an arc **/ double[][] m_fDeltaScoreDel; /** * c'tor * * @param nNrOfNodes number of nodes in network, used to determine memory * size to reserve */ Cache(int nNrOfNodes) { m_fDeltaScoreAdd = new double[nNrOfNodes][nNrOfNodes]; m_fDeltaScoreDel = new double[nNrOfNodes][nNrOfNodes]; } /** * set cache entry * * @param oOperation operation to perform * @param fValue value to put in cache */ public void put(Operation oOperation, double fValue) { if (oOperation.m_nOperation == Operation.OPERATION_ADD) { m_fDeltaScoreAdd[oOperation.m_nTail][oOperation.m_nHead] = fValue; } else { m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead] = fValue; } } // put /** * get cache entry * * @param oOperation operation to perform * @return cache value */ public double get(Operation oOperation) { switch (oOperation.m_nOperation) { case Operation.OPERATION_ADD: return m_fDeltaScoreAdd[oOperation.m_nTail][oOperation.m_nHead]; case Operation.OPERATION_DEL: return m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead]; case Operation.OPERATION_REVERSE: return m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead] + m_fDeltaScoreAdd[oOperation.m_nHead][oOperation.m_nTail]; } // should never get here return 0; } // get /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 10154 $"); } } // class Cache /** cache for storing score differences **/ Cache m_Cache = null; /** use the arc reversal operator **/ boolean m_bUseArcReversal = false; /** * search determines the network structure/graph of the network with the Taby * algorithm. 
* * @param bayesNet the network to use * @param instances the data to use * @throws Exception if something goes wrong */ @Override protected void search(BayesNet bayesNet, Instances instances) throws Exception { initCache(bayesNet, instances); // go do the search Operation oOperation = getOptimalOperation(bayesNet, instances); while ((oOperation != null) && (oOperation.m_fDeltaScore > 0)) { performOperation(bayesNet, instances, oOperation); oOperation = getOptimalOperation(bayesNet, instances); } // free up memory m_Cache = null; } // search /** * initCache initializes the cache * * @param bayesNet Bayes network to be learned * @param instances data set to learn from * @throws Exception if something goes wrong */ void initCache(BayesNet bayesNet, Instances instances) throws Exception { // determine base scores double[] fBaseScores = new double[instances.numAttributes()]; int nNrOfAtts = instances.numAttributes(); m_Cache = new Cache(nNrOfAtts); for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) { updateCache(iAttribute, nNrOfAtts, bayesNet.getParentSet(iAttribute)); } for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) { fBaseScores[iAttribute] = calcNodeScore(iAttribute); } for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) { for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { if (iAttributeHead != iAttributeTail) { Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD); m_Cache.put(oOperation, calcScoreWithExtraParent(iAttributeHead, iAttributeTail) - fBaseScores[iAttributeHead]); } } } } // initCache /** * check whether the operation is not in the forbidden. For base hill climber, * there are no restrictions on operations, so we always return true. 
* * @param oOperation operation to be checked * @return true if operation is not in the tabu list */ boolean isNotTabu(Operation oOperation) { return true; } // isNotTabu /** * getOptimalOperation finds the optimal operation that can be performed on * the Bayes network that is not in the tabu list. * * @param bayesNet Bayes network to apply operation on * @param instances data set to learn from * @return optimal operation found * @throws Exception if something goes wrong */ Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception { Operation oBestOperation = new Operation(); // Add??? oBestOperation = findBestArcToAdd(bayesNet, instances, oBestOperation); // Delete??? oBestOperation = findBestArcToDelete(bayesNet, instances, oBestOperation); // Reverse??? if (getUseArcReversal()) { oBestOperation = findBestArcToReverse(bayesNet, instances, oBestOperation); } // did we find something? if (oBestOperation.m_fDeltaScore == -1E100) { return null; } return oBestOperation; } // getOptimalOperation /** * performOperation applies an operation on the Bayes network and update the * cache. 
* * @param bayesNet Bayes network to apply operation on * @param instances data set to learn from * @param oOperation operation to perform * @throws Exception if something goes wrong */ void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception { // perform operation switch (oOperation.m_nOperation) { case Operation.OPERATION_ADD: applyArcAddition(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances); if (bayesNet.getDebug()) { System.out.print("Add " + oOperation.m_nHead + " -> " + oOperation.m_nTail); } break; case Operation.OPERATION_DEL: applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances); if (bayesNet.getDebug()) { System.out.print("Del " + oOperation.m_nHead + " -> " + oOperation.m_nTail); } break; case Operation.OPERATION_REVERSE: applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances); applyArcAddition(bayesNet, oOperation.m_nTail, oOperation.m_nHead, instances); if (bayesNet.getDebug()) { System.out.print("Rev " + oOperation.m_nHead + " -> " + oOperation.m_nTail); } break; } } // performOperation /** * * @param bayesNet * @param iHead * @param iTail * @param instances */ void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) { ParentSet bestParentSet = bayesNet.getParentSet(iHead); bestParentSet.addParent(iTail, instances); updateCache(iHead, instances.numAttributes(), bestParentSet); } // applyArcAddition /** * * @param bayesNet * @param iHead * @param iTail * @param instances */ void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) { ParentSet bestParentSet = bayesNet.getParentSet(iHead); bestParentSet.deleteParent(iTail, instances); updateCache(iHead, instances.numAttributes(), bestParentSet); } // applyArcAddition /** * find best (or least bad) arc addition operation * * @param bayesNet Bayes network to add arc to * @param instances data set * @param oBestOperation * @return Operation containing 
best arc to add, or null if no arc addition is * allowed (this can happen if any arc addition introduces a cycle, or * all parent sets are filled up to the maximum nr of parents). */ Operation findBestArcToAdd(BayesNet bayesNet, Instances instances, Operation oBestOperation) { int nNrOfAtts = instances.numAttributes(); // find best arc to add for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) { if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) { for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { if (addArcMakesSense(bayesNet, instances, iAttributeHead, iAttributeTail)) { Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD); if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) { if (isNotTabu(oOperation)) { oBestOperation = oOperation; oBestOperation.m_fDeltaScore = m_Cache.get(oOperation); } } } } } } return oBestOperation; } // findBestArcToAdd /** * find best (or least bad) arc deletion operation * * @param bayesNet Bayes network to delete arc from * @param instances data set * @param oBestOperation * @return Operation containing best arc to delete, or null if no deletion can * be made (happens when there is no arc in the network yet). 
*/ Operation findBestArcToDelete(BayesNet bayesNet, Instances instances, Operation oBestOperation) { int nNrOfAtts = instances.numAttributes(); // find best arc to delete for (int iNode = 0; iNode < nNrOfAtts; iNode++) { ParentSet parentSet = bayesNet.getParentSet(iNode); for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) { Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_DEL); if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) { if (isNotTabu(oOperation)) { oBestOperation = oOperation; oBestOperation.m_fDeltaScore = m_Cache.get(oOperation); } } } } return oBestOperation; } // findBestArcToDelete /** * find best (or least bad) arc reversal operation * * @param bayesNet Bayes network to reverse arc in * @param instances data set * @param oBestOperation * @return Operation containing best arc to reverse, or null if no reversal is * allowed (happens if there is no arc in the network yet, or when any * such reversal introduces a cycle). */ Operation findBestArcToReverse(BayesNet bayesNet, Instances instances, Operation oBestOperation) { int nNrOfAtts = instances.numAttributes(); // find best arc to reverse for (int iNode = 0; iNode < nNrOfAtts; iNode++) { ParentSet parentSet = bayesNet.getParentSet(iNode); for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) { int iTail = parentSet.getParent(iParent); // is reversal allowed? 
if (reverseArcMakesSense(bayesNet, instances, iNode, iTail) && bayesNet.getParentSet(iTail).getNrOfParents() < m_nMaxNrOfParents) { // go check if reversal results in the best step forward Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_REVERSE); if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) { if (isNotTabu(oOperation)) { oBestOperation = oOperation; oBestOperation.m_fDeltaScore = m_Cache.get(oOperation); } } } } } return oBestOperation; } // findBestArcToReverse /** * update the cache due to change of parent set of a node * * @param iAttributeHead node that has its parent set changed * @param nNrOfAtts number of nodes/attributes in data set * @param parentSet new parents set of node iAttributeHead */ void updateCache(int iAttributeHead, int nNrOfAtts, ParentSet parentSet) { // update cache entries for arrows heading towards iAttributeHead double fBaseScore = calcNodeScore(iAttributeHead); int nNrOfParents = parentSet.getNrOfParents(); for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { if (iAttributeTail != iAttributeHead) { if (!parentSet.contains(iAttributeTail)) { // add entries to cache for adding arcs if (nNrOfParents < m_nMaxNrOfParents) { Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD); m_Cache.put(oOperation, calcScoreWithExtraParent(iAttributeHead, iAttributeTail) - fBaseScore); } } else { // add entries to cache for deleting arcs Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_DEL); m_Cache.put(oOperation, calcScoreWithMissingParent(iAttributeHead, iAttributeTail) - fBaseScore); } } } } // updateCache /** * Sets the max number of parents * * @param nMaxNrOfParents the max number of parents */ public void setMaxNrOfParents(int nMaxNrOfParents) { m_nMaxNrOfParents = nMaxNrOfParents; } /** * Gets the max number of parents. 
* * @return the max number of parents */ public int getMaxNrOfParents() { return m_nMaxNrOfParents; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration




© 2015 - 2025 Weber Informatics LLC | Privacy Policy