All downloads are FREE. Search and download functionality uses the official Maven repository.

weka.classifiers.bayes.BayesNet Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * BayesNet.java
 * Copyright (C) 2001-2012 University of Waikato, Hamilton, New Zealand
 * 
 */
package weka.classifiers.bayes;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.bayes.net.ADNode;
import weka.classifiers.bayes.net.BIFReader;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
import weka.classifiers.bayes.net.estimate.SimpleEstimator;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.classifiers.bayes.net.search.local.K2;
import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm;
import weka.classifiers.bayes.net.search.local.Scoreable;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.Estimator;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Bayes Network learning using various search algorithms and quality
 * measures.<br/>
 * Base class for a Bayes Network classifier. Provides datastructures (network
 * structure, conditional probability distributions, etc.) and facilities common
 * to Bayes Network learning algorithms like K2 and B.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * http://sourceforge.net/projects/weka/files/documentation/WekaManual-3-7-0.pdf/download
 * <p/>
 *
 * Valid options are:
 * <p/>
 *
 * <pre>
 * -D
 *  Do not use ADTree data structure
 * </pre>
 *
 * <pre>
 * -B &lt;BIF file&gt;
 *  BIF file to compare with
 * </pre>
 *
 * <pre>
 * -Q weka.classifiers.bayes.net.search.SearchAlgorithm
 *  Search algorithm
 * </pre>
 *
 * <pre>
 * -E weka.classifiers.bayes.net.estimate.SimpleEstimator
 *  Estimator algorithm
 * </pre>
 *
* * * * @author Remco Bouckaert ([email protected]) * @version $Revision: 13308 $ */
public class BayesNet extends AbstractClassifier implements OptionHandler,
  WeightedInstancesHandler, Drawable, AdditionalMeasureProducer {

  /** for serialization */
  static final long serialVersionUID = 746037443258775954L;

  /**
   * The parent sets. initStructure() allocates one ParentSet per attribute of
   * m_Instances, so the array is indexed by attribute index.
   */
  protected ParentSet[] m_ParentSets;

  /**
   * The attribute estimators containing CPTs. Indexed as
   * [attribute index][index of the parent configuration], see
   * countsForInstance().
   */
  public Estimator[][] m_Distributions;

  /**
   * filter used to quantize continuous variables, if any. Only created by
   * normalizeDataSet() when a non-nominal attribute was found.
   **/
  protected Discretize m_DiscretizeFilter = null;

  /**
   * attribute index of a non-nominal attribute, or -1 when normalizeDataSet()
   * found none
   */
  int m_nNonDiscreteAttribute = -1;

  /**
   * filter used to fill in missing values, if any. Created by
   * normalizeDataSet().
   **/
  protected ReplaceMissingValues m_MissingValuesFilter = null;

  /**
   * The number of classes, taken from the training data in buildClassifier().
   */
  protected int m_NumClasses;

  /**
   * The dataset header for the purposes of printing out a semi-intelligible
   * model. Holds the normalized training data while buildClassifier() runs and
   * is replaced by new Instances(m_Instances, 0) afterwards to save space.
   */
  public Instances m_Instances;

  /**
   * The number of instances the model was built from
   */
  private int m_NumInstances;

  /**
   * Datastructure containing ADTree representation of the database. This may
   * result in more efficient access to the data. Only built when m_bUseADTree
   * is set, and reset to null at the end of buildClassifier().
   */
  ADNode m_ADTree;

  /**
   * Bayes network to compare the structure with.
   */
  protected BIFReader m_otherBayesNet = null;

  /**
   * Use the experimental ADTree datastructure for calculating contingency
   * tables
   */
  boolean m_bUseADTree = false;

  /**
   * Search algorithm used for learning the structure of a network.
   */
  SearchAlgorithm m_SearchAlgorithm = new K2();

  /**
   * Estimator used for the conditional probability tables of the network.
   * estimateCPTs(), initCPTs(), updateClassifier() and
   * distributionForInstance() delegate to it.
   */
  BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator();

  /**
   * Returns default capabilities of the classifier.
* * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ @Override public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); // ensure we have a data set with discrete variables only and with no // missing values instances = normalizeDataSet(instances); // Copy the instances m_Instances = new Instances(instances); m_NumInstances = m_Instances.numInstances(); // sanity check: need more than 1 variable in datat set m_NumClasses = instances.numClasses(); // initialize ADTree if (m_bUseADTree) { m_ADTree = ADNode.makeADTree(instances); // System.out.println("Oef, done!"); } // build the network structure initStructure(); // build the network structure buildStructure(); // build the set of CPTs estimateCPTs(); // Save space m_Instances = new Instances(m_Instances, 0); m_ADTree = null; } // buildClassifier /** * Returns the number of instances the model was built from. 
*/ public int getNumInstances() { return m_NumInstances; } /** * ensure that all variables are nominal and that there are no missing values * * @param instances data set to check and quantize and/or fill in missing * values * @return filtered instances * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails */ protected Instances normalizeDataSet(Instances instances) throws Exception { m_nNonDiscreteAttribute = -1; Enumeration enu = instances.enumerateAttributes(); while (enu.hasMoreElements()) { Attribute attribute = enu.nextElement(); if (attribute.type() != Attribute.NOMINAL) { m_nNonDiscreteAttribute = attribute.index(); } } if ((m_nNonDiscreteAttribute > -1) && (instances.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) { m_DiscretizeFilter = new Discretize(); m_DiscretizeFilter.setInputFormat(instances); instances = Filter.useFilter(instances, m_DiscretizeFilter); } m_MissingValuesFilter = new ReplaceMissingValues(); m_MissingValuesFilter.setInputFormat(instances); instances = Filter.useFilter(instances, m_MissingValuesFilter); return instances; } // normalizeDataSet /** * ensure that all variables are nominal and that there are no missing values * * @param instance instance to check and quantize and/or fill in missing * values * @return filtered instance * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails */ protected Instance normalizeInstance(Instance instance) throws Exception { if ((m_nNonDiscreteAttribute > -1) && (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) { m_DiscretizeFilter.input(instance); instance = m_DiscretizeFilter.output(); } m_MissingValuesFilter.input(instance); instance = m_MissingValuesFilter.output(); return instance; } // normalizeInstance /** * Init structure initializes the structure to an empty graph or a Naive Bayes * graph (depending on the -N flag). 
* * @throws Exception in case of an error */ public void initStructure() throws Exception { // initialize topological ordering // m_nOrder = new int[m_Instances.numAttributes()]; // m_nOrder[0] = m_Instances.classIndex(); int nAttribute = 0; for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) { if (nAttribute == m_Instances.classIndex()) { nAttribute++; } // m_nOrder[iOrder] = nAttribute++; } // reserve memory m_ParentSets = new ParentSet[m_Instances.numAttributes()]; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { m_ParentSets[iAttribute] = new ParentSet(m_Instances.numAttributes()); } } // initStructure /** * buildStructure determines the network structure/graph of the network. The * default behavior is creating a network where all nodes have the first node * as its parent (i.e., a BayesNet that behaves like a naive Bayes * classifier). This method can be overridden by derived classes to restrict * the class of network structures that are acceptable. * * @throws Exception in case of an error */ public void buildStructure() throws Exception { m_SearchAlgorithm.buildStructure(this, m_Instances); } // buildStructure /** * estimateCPTs estimates the conditional probability tables for the Bayes Net * using the network structure. * * @throws Exception in case of an error */ public void estimateCPTs() throws Exception { m_BayesNetEstimator.estimateCPTs(this); } // estimateCPTs /** * initializes the conditional probabilities * * @throws Exception in case of an error */ public void initCPTs() throws Exception { m_BayesNetEstimator.initCPTs(this); } // estimateCPTs /** * Updates the classifier with the given instance. * * @param instance the new training instance to include in the model * @throws Exception if the instance could not be incorporated in the model. 
*/
  public void updateClassifier(Instance instance) throws Exception {
    Instance normalized = normalizeInstance(instance);
    m_BayesNetEstimator.updateClassifier(this, normalized);
  } // updateClassifier

  /**
   * Computes the class membership probabilities for a test instance. The
   * instance is normalized first and the work is then handed to the configured
   * estimator.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if there is a problem generating the prediction
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {
    Instance normalized = normalizeInstance(instance);
    return m_BayesNetEstimator.distributionForInstance(this, normalized);
  } // distributionForInstance

  /**
   * Computes, for every class value, the sum over all attributes of the count
   * stored in that attribute's CPT entry, where the class attribute is set to
   * the candidate class value and all other values are read from the given
   * instance. These are the counts for the Dirichlet distribution of the class
   * membership probabilities.
   *
   * @param instance the instance to be classified
   * @return counts for Dirichlet distribution for class probability
   * @throws Exception if there is a problem generating the prediction
   */
  public double[] countsForInstance(Instance instance) throws Exception {
    // a freshly allocated double[] is already filled with 0.0
    double[] counts = new double[m_NumClasses];

    for (int cls = 0; cls < m_NumClasses; cls++) {
      double total = 0;

      for (int att = 0; att < m_Instances.numAttributes(); att++) {
        ParentSet parents = m_ParentSets[att];

        // mixed-radix index of the parent configuration within this
        // attribute's CPT
        double cptIndex = 0;
        for (int p = 0; p < parents.getNrOfParents(); p++) {
          int parent = parents.getParent(p);
          cptIndex = (parent == m_Instances.classIndex())
            ? cptIndex * m_NumClasses + cls
            : cptIndex * m_Instances.attribute(parent).numValues()
              + instance.value(parent);
        }

        DiscreteEstimatorBayes cpt =
          (DiscreteEstimatorBayes) m_Distributions[att][(int) cptIndex];
        if (att != m_Instances.classIndex()) {
          total += cpt.getCount(instance.value(att));
        } else {
          // for the class attribute use the candidate class value itself
          total += cpt.getCount(cls);
        }
      }

      counts[cls] = total;
    }

    return counts;
  } //
countsForInstance /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ @Override public Enumeration




© 2015 - 2025 Weber Informatics LLC | Privacy Policy