All downloads are free. Search and download functionality uses the official Maven repository.

aima.core.probability.bayes.exact.EnumerationAsk Maven / Gradle / Ivy

package aima.core.probability.bayes.exact;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import aima.core.probability.CategoricalDistribution;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.BayesInference;
import aima.core.probability.bayes.BayesianNetwork;
import aima.core.probability.bayes.FiniteNode;
import aima.core.probability.bayes.Node;
import aima.core.probability.domain.FiniteDomain;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.probability.util.ProbabilityTable;
import aima.core.util.Util;

/**
 * Artificial Intelligence A Modern Approach (3rd Edition): Figure 14.9, page
 * 525.
*
* *
 * function ENUMERATION-ASK(X, e, bn) returns a distribution over X
 *   inputs: X, the query variable
 *           e, observed values for variables E
 *           bn, a Bayes net with variables {X} ∪ E ∪ Y /* Y = hidden variables //
 *           
 *   Q(X) <- a distribution over X, initially empty
 *   for each value xi of X do
 *       Q(xi) <- ENUMERATE-ALL(bn.VARS, exi)
 *          where exi is e extended with X = xi
 *   return NORMALIZE(Q(X))
 *   
 * ---------------------------------------------------------------------------------------------------
 * 
 * function ENUMERATE-ALL(vars, e) returns a real number
 *   if EMPTY?(vars) then return 1.0
 *   Y <- FIRST(vars)
 *   if Y has value y in e
 *       then return P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), e)
 *       else return ∑y P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), ey)
 *           where ey is e extended with Y = y
 * 
* * Figure 14.9 The enumeration algorithm for answering queries on Bayesian * networks.
*
* Note: The implementation has been extended to handle queries with * multiple variables.
* * @author Ciaran O'Reilly */ public class EnumerationAsk implements BayesInference { public EnumerationAsk() { } // function ENUMERATION-ASK(X, e, bn) returns a distribution over X /** * The ENUMERATION-ASK algorithm in Figure 14.9 evaluates expression trees * (Figure 14.8) using depth-first recursion. * * @param X * the query variables. * @param observedEvidence * observed values for variables E. * @param bn * a Bayes net with variables {X} ∪ E ∪ Y /* Y = hidden * variables // * @return a distribution over the query variables. */ public CategoricalDistribution enumerationAsk(final RandomVariable[] X, final AssignmentProposition[] observedEvidence, final BayesianNetwork bn) { // Q(X) <- a distribution over X, initially empty final ProbabilityTable Q = new ProbabilityTable(X); final ObservedEvidence e = new ObservedEvidence(X, observedEvidence, bn); // for each value xi of X do ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() { int cnt = 0; /** *
			 * Q(xi) <- ENUMERATE-ALL(bn.VARS, exi)
			 *   where exi is e extended with X = xi
			 * 
*/ public void iterate(Map possibleWorld, double probability) { for (int i = 0; i < X.length; i++) { e.setExtendedValue(X[i], possibleWorld.get(X[i])); } Q.setValue(cnt, enumerateAll(bn.getVariablesInTopologicalOrder(), e)); cnt++; } }; Q.iterateOverTable(di); // return NORMALIZE(Q(X)) return Q.normalize(); } // // START-BayesInference public CategoricalDistribution ask(final RandomVariable[] X, final AssignmentProposition[] observedEvidence, final BayesianNetwork bn) { return this.enumerationAsk(X, observedEvidence, bn); } // END-BayesInference // // // PROTECTED METHODS // // function ENUMERATE-ALL(vars, e) returns a real number protected double enumerateAll(List vars, ObservedEvidence e) { // if EMPTY?(vars) then return 1.0 if (0 == vars.size()) { return 1; } // Y <- FIRST(vars) RandomVariable Y = Util.first(vars); // if Y has value y in e if (e.containsValue(Y)) { // then return P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), e) return e.posteriorForParents(Y) * enumerateAll(Util.rest(vars), e); } /** *
		 *  else return ∑y P(y | parents(Y)) * ENUMERATE-ALL(REST(vars), ey)
		 *       where ey is e extended with Y = y
		 * 
*/ double sum = 0; for (Object y : ((FiniteDomain) Y.getDomain()).getPossibleValues()) { e.setExtendedValue(Y, y); sum += e.posteriorForParents(Y) * enumerateAll(Util.rest(vars), e); } return sum; } protected class ObservedEvidence { private BayesianNetwork bn = null; private Object[] extendedValues = null; private int hiddenStart = 0; private int extendedIdx = 0; private RandomVariable[] var = null; private Map varIdxs = new HashMap(); public ObservedEvidence(RandomVariable[] queryVariables, AssignmentProposition[] e, BayesianNetwork bn) { this.bn = bn; int maxSize = bn.getVariablesInTopologicalOrder().size(); extendedValues = new Object[maxSize]; var = new RandomVariable[maxSize]; // query variables go first int idx = 0; for (int i = 0; i < queryVariables.length; i++) { var[idx] = queryVariables[i]; varIdxs.put(var[idx], idx); idx++; } // initial evidence variables go next for (int i = 0; i < e.length; i++) { var[idx] = e[i].getTermVariable(); varIdxs.put(var[idx], idx); extendedValues[idx] = e[i].getValue(); idx++; } extendedIdx = idx - 1; hiddenStart = idx; // the remaining slots are left open for the hidden variables for (RandomVariable rv : bn.getVariablesInTopologicalOrder()) { if (!varIdxs.containsKey(rv)) { var[idx] = rv; varIdxs.put(var[idx], idx); idx++; } } } public void setExtendedValue(RandomVariable rv, Object value) { int idx = varIdxs.get(rv); extendedValues[idx] = value; if (idx >= hiddenStart) { extendedIdx = idx; } else { extendedIdx = hiddenStart - 1; } } public boolean containsValue(RandomVariable rv) { return varIdxs.get(rv) <= extendedIdx; } public double posteriorForParents(RandomVariable rv) { Node n = bn.getNode(rv); if (!(n instanceof FiniteNode)) { throw new IllegalArgumentException( "Enumeration-Ask only works with finite Nodes."); } FiniteNode fn = (FiniteNode) n; Object[] vals = new Object[1 + fn.getParents().size()]; int idx = 0; for (Node pn : n.getParents()) { vals[idx] = extendedValues[varIdxs.get(pn.getRandomVariable())]; idx++; 
} vals[idx] = extendedValues[varIdxs.get(rv)]; return fn.getCPT().getValue(vals); } } // // PRIVATE METHODS // }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy