All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.bayes.net.search.global.GeneticSearch Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * GeneticSearch.java
 * Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
 * 
 */

package weka.classifiers.bayes.net.search.global;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * This Bayes Network learning algorithm uses genetic search to find a
 * well-scoring Bayes network structure. Genetic search maintains a population
 * of Bayes network structures and lets them mutate and cross over to produce
 * offspring. The best network structure found during the process is returned.
 * 

 * <p>
 * Valid options are:
 * </p>
 *
 * <pre>
 * -L &lt;integer&gt;
 *  Population size
 * </pre>
 *
 * <pre>
 * -A &lt;integer&gt;
 *  Descendant population size
 * </pre>
 *
 * <pre>
 * -U &lt;integer&gt;
 *  Number of runs
 * </pre>
 *
 * <pre>
 * -M
 *  Use mutation.
 *  (default true)
 * </pre>
 *
 * <pre>
 * -C
 *  Use cross-over.
 *  (default true)
 * </pre>
 *
 * <pre>
 * -O
 *  Use tournament selection (true) or maximum subpopulation (false).
 *  (default false)
 * </pre>
 *
 * <pre>
 * -R &lt;seed&gt;
 *  Random number seed
 * </pre>
 *
 * <pre>
 * -mbc
 *  Applies a Markov Blanket correction to the network structure,
 *  after a network structure is learned. This ensures that all
 *  nodes in the network are part of the Markov blanket of the
 *  classifier node.
 * </pre>
 *
 * <pre>
 * -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)
 * </pre>
 *
 * <pre>
 * -Q
 *  Use probabilistic or 0/1 scoring.
 *  (default probabilistic scoring)
 * </pre>
 *
* * * * @author Remco Bouckaert ([email protected]) * @version $Revision: 11247 $ */ public class GeneticSearch extends GlobalScoreSearchAlgorithm { /** for serialization */ static final long serialVersionUID = 4236165533882462203L; /** number of runs **/ int m_nRuns = 10; /** size of population **/ int m_nPopulationSize = 10; /** size of descendant population **/ int m_nDescendantPopulationSize = 100; /** use cross-over? **/ boolean m_bUseCrossOver = true; /** use mutation? **/ boolean m_bUseMutation = true; /** use tournament selection or take best sub-population **/ boolean m_bUseTournamentSelection = false; /** random number seed **/ int m_nSeed = 1; /** random number generator **/ Random m_random = null; /** * used in BayesNetRepresentation for efficiently determining whether a number * is square */ boolean[] g_bIsSquare; class BayesNetRepresentation implements RevisionHandler { /** number of nodes in network **/ int m_nNodes = 0; /** * bit representation of parent sets m_bits[iTail + iHead * m_nNodes] * represents arc iTail->iHead */ boolean[] m_bits; /** score of represented network structure **/ double m_fScore = 0.0f; /** * return score of represented network structure * * @return the score */ public double getScore() { return m_fScore; } // getScore /** * c'tor * * @param nNodes the number of nodes */ BayesNetRepresentation(int nNodes) { m_nNodes = nNodes; } // c'tor /** * initialize with a random structure by randomly placing m_nNodes arcs. 
*/ public void randomInit() { do { m_bits = new boolean[m_nNodes * m_nNodes]; for (int i = 0; i < m_nNodes; i++) { int iPos; do { iPos = m_random.nextInt(m_nNodes * m_nNodes); } while (isSquare(iPos)); m_bits[iPos] = true; } } while (hasCycles()); calcGlobalScore(); } /** * calculate score of current network representation As a side effect, the * parent sets are set */ void calcGlobalScore() { // clear current network for (int iNode = 0; iNode < m_nNodes; iNode++) { ParentSet parentSet = m_BayesNet.getParentSet(iNode); while (parentSet.getNrOfParents() > 0) { parentSet.deleteLastParent(m_BayesNet.m_Instances); } } // insert arrows for (int iNode = 0; iNode < m_nNodes; iNode++) { ParentSet parentSet = m_BayesNet.getParentSet(iNode); for (int iNode2 = 0; iNode2 < m_nNodes; iNode2++) { if (m_bits[iNode2 + iNode * m_nNodes]) { parentSet.addParent(iNode2, m_BayesNet.m_Instances); } } } // calc score try { m_fScore = calcScore(m_BayesNet); } catch (Exception e) { // ignore } } // calcScore /** * check whether there are cycles in the network * * @return true if a cycle is found, false otherwise */ public boolean hasCycles() { // check for cycles boolean[] bDone = new boolean[m_nNodes]; for (int iNode = 0; iNode < m_nNodes; iNode++) { // find a node for which all parents are 'done' boolean bFound = false; for (int iNode2 = 0; !bFound && iNode2 < m_nNodes; iNode2++) { if (!bDone[iNode2]) { boolean bHasNoParents = true; for (int iParent = 0; iParent < m_nNodes; iParent++) { if (m_bits[iParent + iNode2 * m_nNodes] && !bDone[iParent]) { bHasNoParents = false; } } if (bHasNoParents) { bDone[iNode2] = true; bFound = true; } } } if (!bFound) { return true; } } return false; } // hasCycles /** * create clone of current object * * @return cloned object */ BayesNetRepresentation copy() { BayesNetRepresentation b = new BayesNetRepresentation(m_nNodes); b.m_bits = new boolean[m_bits.length]; for (int i = 0; i < m_nNodes * m_nNodes; i++) { b.m_bits[i] = m_bits[i]; } b.m_fScore = 
m_fScore; return b; } // copy /** * Apply mutation operation to BayesNet Calculate score and as a side effect * sets BayesNet parent sets. */ void mutate() { // flip a bit do { int iBit; do { iBit = m_random.nextInt(m_nNodes * m_nNodes); } while (isSquare(iBit)); m_bits[iBit] = !m_bits[iBit]; } while (hasCycles()); calcGlobalScore(); } // mutate /** * Apply cross-over operation to BayesNet Calculate score and as a side * effect sets BayesNet parent sets. * * @param other BayesNetRepresentation to cross over with */ void crossOver(BayesNetRepresentation other) { boolean[] bits = new boolean[m_bits.length]; for (int i = 0; i < m_bits.length; i++) { bits[i] = m_bits[i]; } int iCrossOverPoint = m_bits.length; do { // restore to original state for (int i = iCrossOverPoint; i < m_bits.length; i++) { m_bits[i] = bits[i]; } // take all bits from cross-over points onwards iCrossOverPoint = m_random.nextInt(m_bits.length); for (int i = iCrossOverPoint; i < m_bits.length; i++) { m_bits[i] = other.m_bits[i]; } } while (hasCycles()); calcGlobalScore(); } // crossOver /** * check if number is square and initialize g_bIsSquare structure if * necessary * * @param nNum number to check (should be below m_nNodes * m_nNodes) * @return true if number is square */ boolean isSquare(int nNum) { if (g_bIsSquare == null || g_bIsSquare.length < nNum) { g_bIsSquare = new boolean[m_nNodes * m_nNodes]; for (int i = 0; i < m_nNodes; i++) { g_bIsSquare[i * m_nNodes + i] = true; } } return g_bIsSquare[nNum]; } // isSquare /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 11247 $"); } } // class BayesNetRepresentation /** * search determines the network structure/graph of the network with a genetic * search algorithm. 
* * @param bayesNet the network to search * @param instances the instances to use * @throws Exception if population size doesn fit or neither cross-over or * mutation was chosen */ @Override protected void search(BayesNet bayesNet, Instances instances) throws Exception { // sanity check if (getDescendantPopulationSize() < getPopulationSize()) { throw new Exception( "Descendant PopulationSize should be at least Population Size"); } if (!getUseCrossOver() && !getUseMutation()) { throw new Exception( "At least one of mutation or cross-over should be used"); } m_random = new Random(m_nSeed); // keeps track of best structure found so far BayesNet bestBayesNet; // keeps track of score pf best structure found so far double fBestScore = calcScore(bayesNet); // initialize bestBayesNet bestBayesNet = new BayesNet(); bestBayesNet.m_Instances = instances; bestBayesNet.initStructure(); copyParentSets(bestBayesNet, bayesNet); // initialize population BayesNetRepresentation[] population = new BayesNetRepresentation[getPopulationSize()]; for (int i = 0; i < getPopulationSize(); i++) { population[i] = new BayesNetRepresentation(instances.numAttributes()); population[i].randomInit(); if (population[i].getScore() > fBestScore) { copyParentSets(bestBayesNet, bayesNet); fBestScore = population[i].getScore(); } } // go do the search for (int iRun = 0; iRun < m_nRuns; iRun++) { // create descendants BayesNetRepresentation[] descendantPopulation = new BayesNetRepresentation[getDescendantPopulationSize()]; for (int i = 0; i < getDescendantPopulationSize(); i++) { descendantPopulation[i] = population[m_random .nextInt(getPopulationSize())].copy(); if (getUseMutation()) { if (getUseCrossOver() && m_random.nextBoolean()) { descendantPopulation[i].crossOver(population[m_random .nextInt(getPopulationSize())]); } else { descendantPopulation[i].mutate(); } } else { // use crossover descendantPopulation[i].crossOver(population[m_random .nextInt(getPopulationSize())]); } if 
(descendantPopulation[i].getScore() > fBestScore) { copyParentSets(bestBayesNet, bayesNet); fBestScore = descendantPopulation[i].getScore(); } } // select new population boolean[] bSelected = new boolean[getDescendantPopulationSize()]; for (int i = 0; i < getPopulationSize(); i++) { int iSelected = 0; if (m_bUseTournamentSelection) { // use tournament selection iSelected = m_random.nextInt(getDescendantPopulationSize()); while (bSelected[iSelected]) { iSelected = (iSelected + 1) % getDescendantPopulationSize(); } int iSelected2 = m_random.nextInt(getDescendantPopulationSize()); while (bSelected[iSelected2]) { iSelected2 = (iSelected2 + 1) % getDescendantPopulationSize(); } if (descendantPopulation[iSelected2].getScore() > descendantPopulation[iSelected] .getScore()) { iSelected = iSelected2; } } else { // find best scoring network in population while (bSelected[iSelected]) { iSelected++; } double fScore = descendantPopulation[iSelected].getScore(); for (int j = 0; j < getDescendantPopulationSize(); j++) { if (!bSelected[j] && descendantPopulation[j].getScore() > fScore) { fScore = descendantPopulation[j].getScore(); iSelected = j; } } } population[i] = descendantPopulation[iSelected]; bSelected[iSelected] = true; } } // restore current network to best network copyParentSets(bayesNet, bestBayesNet); // free up memory bestBayesNet = null; g_bIsSquare = null; } // search /** * copyParentSets copies parent sets of source to dest BayesNet * * @param dest destination network * @param source source network */ void copyParentSets(BayesNet dest, BayesNet source) { int nNodes = source.getNrOfNodes(); // clear parent set first for (int iNode = 0; iNode < nNodes; iNode++) { dest.getParentSet(iNode).copy(source.getParentSet(iNode)); } } // CopyParentSets /** * @return number of runs */ public int getRuns() { return m_nRuns; } // getRuns /** * Sets the number of runs * * @param nRuns The number of runs to set */ public void setRuns(int nRuns) { m_nRuns = nRuns; } // setRuns /** 
* Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration




© 2015 - 2025 Weber Informatics LLC | Privacy Policy