weka.classifiers.bayes.net.search.global.GeneticSearch Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* GeneticSearch.java
* Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.bayes.net.search.global;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
/**
* This Bayes Network learning algorithm uses genetic
* search to find a well-scoring Bayes network structure. Genetic search
* maintains a population of Bayes network structures and lets them
* mutate and apply cross-over to produce offspring. The best network structure
* found during the process is returned.
*
*
*
* Valid options are:
*
*
*
* -L <integer>
* Population size
*
*
*
* -A <integer>
* Descendant population size
*
*
*
* -U <integer>
* Number of runs
*
*
*
* -M
* Use mutation.
* (default true)
*
*
*
* -C
* Use cross-over.
* (default true)
*
*
*
* -O
* Use tournament selection (true) or maximum subpopulation (false).
* (default false)
*
*
*
* -R <seed>
* Random number seed
*
*
*
* -mbc
* Applies a Markov Blanket correction to the network structure,
* after a network structure is learned. This ensures that all
* nodes in the network are part of the Markov blanket of the
* classifier node.
*
*
*
* -S [LOO-CV|k-Fold-CV|Cumulative-CV]
* Score type (LOO-CV,k-Fold-CV,Cumulative-CV)
*
*
*
* -Q
* Use probabilistic or 0/1 scoring.
* (default probabilistic scoring)
*
*
*
*
* @author Remco Bouckaert ([email protected])
* @version $Revision: 11247 $
*/
public class GeneticSearch extends GlobalScoreSearchAlgorithm {
/** for serialization */
static final long serialVersionUID = 4236165533882462203L;
/** number of runs (generations: breed descendants, then select a new population) **/
int m_nRuns = 10;
/** size of population kept between runs **/
int m_nPopulationSize = 10;
/** size of descendant population bred in each run **/
int m_nDescendantPopulationSize = 100;
/** use cross-over? **/
boolean m_bUseCrossOver = true;
/** use mutation? **/
boolean m_bUseMutation = true;
/**
 * use tournament selection (true) or repeatedly take the best not yet
 * selected descendant (false) when forming the next population
 **/
boolean m_bUseTournamentSelection = false;
/** random number seed; m_random is re-created from it at the start of each search **/
int m_nSeed = 1;
/** random number generator used by the search and by BayesNetRepresentation **/
Random m_random = null;
/**
 * used in BayesNetRepresentation for efficiently determining whether a number
 * is "square", i.e. a diagonal position i * n + i of the n x n arc bit matrix
 * (a self-arc). Built lazily by BayesNetRepresentation.isSquare and set to
 * null at the end of search.
 */
boolean[] g_bIsSquare;
/**
 * Bit-matrix encoding of a Bayes network structure, used as an individual in
 * the genetic search. Scoring works by loading the structure into the
 * enclosing search's m_BayesNet. Every operation that (re)computes the score
 * therefore also leaves m_BayesNet's parent sets equal to this structure.
 */
class BayesNetRepresentation implements RevisionHandler {
  /** number of nodes in network **/
  int m_nNodes = 0;
  /**
   * bit representation of parent sets: m_bits[iTail + iHead * m_nNodes]
   * represents arc iTail->iHead
   */
  boolean[] m_bits;
  /**
   * score of represented network structure; Double.NEGATIVE_INFINITY if
   * scoring failed
   **/
  double m_fScore = 0.0;

  /**
   * return score of represented network structure
   *
   * @return the score, or Double.NEGATIVE_INFINITY if it could not be computed
   */
  public double getScore() {
    return m_fScore;
  } // getScore

  /**
   * c'tor
   *
   * @param nNodes the number of nodes
   */
  BayesNetRepresentation(int nNodes) {
    m_nNodes = nNodes;
  } // c'tor

  /**
   * Initializes with a random acyclic structure. Up to m_nNodes arcs are
   * placed at random non-self-arc positions (the same position may be drawn
   * more than once). The whole draw is repeated until the structure has no
   * cycles. The structure is then scored.
   */
  public void randomInit() {
    do {
      m_bits = new boolean[m_nNodes * m_nNodes];
      // With fewer than two nodes every position is a self-arc, so there is
      // no position to draw and the inner search would never terminate.
      if (m_nNodes > 1) {
        for (int i = 0; i < m_nNodes; i++) {
          int iPos;
          do {
            iPos = m_random.nextInt(m_nNodes * m_nNodes);
          } while (isSquare(iPos));
          m_bits[iPos] = true;
        }
      }
    } while (hasCycles());
    calcGlobalScore();
  }

  /**
   * Calculates the score of the current network representation. As a side
   * effect, m_BayesNet's parent sets are replaced by this structure: all
   * parents are deleted, then one parent is added per set bit.
   * If scoring throws, the score becomes Double.NEGATIVE_INFINITY. This is
   * done instead of keeping the score of some earlier structure, which could
   * make an unscored structure look like an improvement.
   */
  void calcGlobalScore() {
    // clear current network
    for (int iNode = 0; iNode < m_nNodes; iNode++) {
      ParentSet parentSet = m_BayesNet.getParentSet(iNode);
      while (parentSet.getNrOfParents() > 0) {
        parentSet.deleteLastParent(m_BayesNet.m_Instances);
      }
    }
    // insert arrows: bit iNode2 + iNode * m_nNodes is arc iNode2 -> iNode
    for (int iNode = 0; iNode < m_nNodes; iNode++) {
      ParentSet parentSet = m_BayesNet.getParentSet(iNode);
      for (int iNode2 = 0; iNode2 < m_nNodes; iNode2++) {
        if (m_bits[iNode2 + iNode * m_nNodes]) {
          parentSet.addParent(iNode2, m_BayesNet.m_Instances);
        }
      }
    }
    // calc score
    try {
      m_fScore = calcScore(m_BayesNet);
    } catch (Exception e) {
      // best effort: mark this structure as the worst possible rather than
      // leaving a stale score that belongs to a different structure
      m_fScore = Double.NEGATIVE_INFINITY;
    }
  } // calcGlobalScore

  /**
   * Checks whether there are cycles in the network. Nodes whose parents are
   * all done are repeatedly marked done. If in some round no such node
   * exists, the remaining nodes contain a cycle; a self-arc counts as one.
   *
   * @return true if a cycle is found, false otherwise
   */
  public boolean hasCycles() {
    boolean[] bDone = new boolean[m_nNodes];
    for (int iNode = 0; iNode < m_nNodes; iNode++) {
      // find a node for which all parents are 'done'
      boolean bFound = false;
      for (int iNode2 = 0; !bFound && iNode2 < m_nNodes; iNode2++) {
        if (!bDone[iNode2]) {
          boolean bHasNoParents = true;
          for (int iParent = 0; iParent < m_nNodes; iParent++) {
            if (m_bits[iParent + iNode2 * m_nNodes] && !bDone[iParent]) {
              bHasNoParents = false;
              break;
            }
          }
          if (bHasNoParents) {
            bDone[iNode2] = true;
            bFound = true;
          }
        }
      }
      if (!bFound) {
        return true;
      }
    }
    return false;
  } // hasCycles

  /**
   * create clone of current object (bits and score; nothing is rescored)
   *
   * @return cloned object
   */
  BayesNetRepresentation copy() {
    BayesNetRepresentation b = new BayesNetRepresentation(m_nNodes);
    b.m_bits = m_bits.clone();
    b.m_fScore = m_fScore;
    return b;
  } // copy

  /**
   * Applies the mutation operation. Random non-self-arc bits are flipped, one
   * at a time, until the structure is acyclic. A flip that introduced a cycle
   * is not undone, so more than one arc may change. Then calculates the score
   * and, as a side effect, sets m_BayesNet's parent sets. With fewer than two
   * nodes there is no arc to flip and the structure is only rescored.
   */
  void mutate() {
    if (m_nNodes > 1) {
      // flip a bit
      do {
        int iBit;
        do {
          iBit = m_random.nextInt(m_nNodes * m_nNodes);
        } while (isSquare(iBit));
        m_bits[iBit] = !m_bits[iBit];
      } while (hasCycles());
    }
    calcGlobalScore();
  } // mutate

  /**
   * Applies the cross-over operation. All bits from a random cross-over point
   * onwards are taken from other. A new point is drawn (after restoring the
   * original tail) until the result is acyclic. Then calculates the score
   * and, as a side effect, sets m_BayesNet's parent sets.
   *
   * @param other BayesNetRepresentation to cross over with
   */
  void crossOver(BayesNetRepresentation other) {
    boolean[] bits = m_bits.clone();
    int iCrossOverPoint = m_bits.length;
    do {
      // restore the tail overwritten by the previous (rejected) attempt
      System.arraycopy(bits, iCrossOverPoint, m_bits, iCrossOverPoint,
        m_bits.length - iCrossOverPoint);
      // take all bits from cross-over point onwards
      iCrossOverPoint = m_random.nextInt(m_bits.length);
      System.arraycopy(other.m_bits, iCrossOverPoint, m_bits, iCrossOverPoint,
        m_bits.length - iCrossOverPoint);
    } while (hasCycles());
    calcGlobalScore();
  } // crossOver

  /**
   * Checks whether position nNum lies on the diagonal of the
   * m_nNodes x m_nNodes arc matrix, i.e. encodes a self-arc. Builds the
   * shared g_bIsSquare cache if it is missing or was built for a different
   * node count.
   *
   * @param nNum number to check (should be below m_nNodes * m_nNodes)
   * @return true if number is square
   */
  boolean isSquare(int nNum) {
    // Rebuild on any size mismatch: a cache for fewer nodes is too short
    // (index out of bounds), and one for more nodes marks the wrong positions.
    if (g_bIsSquare == null || g_bIsSquare.length != m_nNodes * m_nNodes) {
      g_bIsSquare = new boolean[m_nNodes * m_nNodes];
      for (int i = 0; i < m_nNodes; i++) {
        g_bIsSquare[i * m_nNodes + i] = true;
      }
    }
    return g_bIsSquare[nNum];
  } // isSquare

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 11247 $");
  }
} // class BayesNetRepresentation
/**
 * Searches for the structure of the network with a genetic search algorithm.
 * <p>
 * A population of getPopulationSize() random acyclic structures is evolved
 * for m_nRuns generations. In each generation getDescendantPopulationSize()
 * offspring are bred from randomly chosen members. Then getPopulationSize()
 * distinct offspring are selected (by tournament or by best score) as the
 * next population. The best-scoring structure seen at any point, starting
 * with the initial structure of bayesNet, is written back into bayesNet.
 *
 * @param bayesNet the network to search
 * @param instances the instances to use
 * @throws Exception if the descendant population size is smaller than the
 *           population size, if neither cross-over nor mutation was chosen,
 *           or if scoring the initial network fails
 */
@Override
protected void search(BayesNet bayesNet, Instances instances)
  throws Exception {
  // sanity check
  if (getDescendantPopulationSize() < getPopulationSize()) {
    throw new Exception(
      "Descendant PopulationSize should be at least Population Size");
  }
  if (!getUseCrossOver() && !getUseMutation()) {
    throw new Exception(
      "At least one of mutation or cross-over should be used");
  }
  m_random = new Random(m_nSeed);
  try {
    // keeps track of score of best structure found so far
    double fBestScore = calcScore(bayesNet);
    // keeps a copy of the parent sets of the best structure found so far
    BayesNet bestBayesNet = new BayesNet();
    bestBayesNet.m_Instances = instances;
    bestBayesNet.initStructure();
    copyParentSets(bestBayesNet, bayesNet);

    // NOTE(review): scoring a BayesNetRepresentation loads its structure into
    // m_BayesNet; copying the "current" structure from bayesNet below assumes
    // m_BayesNet is bayesNet (presumably set by the superclass) - confirm.

    // initialize population
    BayesNetRepresentation[] population =
      new BayesNetRepresentation[getPopulationSize()];
    for (int i = 0; i < population.length; i++) {
      population[i] = new BayesNetRepresentation(instances.numAttributes());
      population[i].randomInit();
      if (population[i].getScore() > fBestScore) {
        copyParentSets(bestBayesNet, bayesNet);
        fBestScore = population[i].getScore();
      }
    }

    // go do the search
    for (int iRun = 0; iRun < m_nRuns; iRun++) {
      // create descendants
      BayesNetRepresentation[] descendantPopulation =
        new BayesNetRepresentation[getDescendantPopulationSize()];
      for (int i = 0; i < descendantPopulation.length; i++) {
        descendantPopulation[i] = breedDescendant(population);
        if (descendantPopulation[i].getScore() > fBestScore) {
          copyParentSets(bestBayesNet, bayesNet);
          fBestScore = descendantPopulation[i].getScore();
        }
      }
      // select new population; every descendant is selected at most once
      boolean[] bSelected = new boolean[descendantPopulation.length];
      for (int i = 0; i < population.length; i++) {
        int iSelected = selectSurvivor(descendantPopulation, bSelected);
        population[i] = descendantPopulation[iSelected];
        bSelected[iSelected] = true;
      }
    }

    // restore current network to best network
    copyParentSets(bayesNet, bestBayesNet);
  } finally {
    // free the cache built by BayesNetRepresentation.isSquare, also when the
    // search is aborted by an exception, so no stale cache survives
    g_bIsSquare = null;
  }
} // search

/**
 * Breeds one offspring from a copy of a randomly chosen population member.
 * If only one operator is enabled, it is always used. If both are enabled, a
 * fair coin decides between cross-over (with another randomly chosen member)
 * and mutation. The offspring is scored, which also loads its structure into
 * m_BayesNet.
 *
 * @param population the current population to pick parents from
 * @return the new, scored offspring
 */
private BayesNetRepresentation breedDescendant(
  BayesNetRepresentation[] population) {
  BayesNetRepresentation descendant =
    population[m_random.nextInt(population.length)].copy();
  boolean bCrossOver;
  if (getUseMutation()) {
    // the coin is only tossed when cross-over is enabled as well
    bCrossOver = getUseCrossOver() && m_random.nextBoolean();
  } else {
    // search() guarantees cross-over is enabled when mutation is not
    bCrossOver = true;
  }
  if (bCrossOver) {
    descendant.crossOver(population[m_random.nextInt(population.length)]);
  } else {
    descendant.mutate();
  }
  return descendant;
} // breedDescendant

/**
 * Picks the index of a descendant that has not been selected yet.
 * <p>
 * With tournament selection, two unselected descendants are drawn at random.
 * Each draw starts at a random index and moves forward cyclically to the
 * first unselected one. The higher-scoring of the two wins, and the first
 * draw wins ties. Otherwise, the best-scoring unselected descendant is
 * taken, with the lowest index winning ties.
 *
 * @param descendants the descendant population
 * @param bSelected which descendants are already selected; at least one
 *          entry must be false
 * @return index of the chosen descendant (bSelected is not updated)
 */
private int selectSurvivor(BayesNetRepresentation[] descendants,
  boolean[] bSelected) {
  int nDescendants = descendants.length;
  if (m_bUseTournamentSelection) {
    int iSelected = m_random.nextInt(nDescendants);
    while (bSelected[iSelected]) {
      iSelected = (iSelected + 1) % nDescendants;
    }
    int iSelected2 = m_random.nextInt(nDescendants);
    while (bSelected[iSelected2]) {
      iSelected2 = (iSelected2 + 1) % nDescendants;
    }
    return descendants[iSelected2].getScore() > descendants[iSelected]
      .getScore() ? iSelected2 : iSelected;
  }
  // find best scoring network among the unselected descendants
  int iSelected = 0;
  while (bSelected[iSelected]) {
    iSelected++;
  }
  double fScore = descendants[iSelected].getScore();
  for (int j = 0; j < nDescendants; j++) {
    if (!bSelected[j] && descendants[j].getScore() > fScore) {
      fScore = descendants[j].getScore();
      iSelected = j;
    }
  }
  return iSelected;
} // selectSurvivor
/**
 * Copies the parent set of every node of source into the parent set of the
 * same node in dest, using ParentSet.copy. The node count is taken from
 * source. NOTE(review): this assumes dest has at least as many nodes as
 * source.
 *
 * @param dest destination network
 * @param source source network
 */
void copyParentSets(BayesNet dest, BayesNet source) {
  final int nrOfNodes = source.getNrOfNodes();
  int node = 0;
  while (node < nrOfNodes) {
    ParentSet target = dest.getParentSet(node);
    target.copy(source.getParentSet(node));
    node++;
  }
} // copyParentSets
/**
 * Gets the number of runs (generations) the genetic search performs.
 *
 * @return number of runs
 */
public int getRuns() {
  return this.m_nRuns;
} // getRuns
/**
 * Sets the number of runs (generations) the genetic search performs.
 * The value is stored as given.
 *
 * @param nRuns The number of runs to set
 */
public void setRuns(int nRuns) {
  this.m_nRuns = nRuns;
} // setRuns
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2024 Weber Informatics LLC | Privacy Policy