All downloads are free. The search and download functionality uses the official Maven repository.

aima.core.probability.util.ProbabilityTable Maven / Gradle / Ivy

Go to download

AIMA-Java core algorithms from the book "Artificial Intelligence: A Modern Approach", 3rd Ed.

The newest version!
package aima.core.probability.util;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import aima.core.probability.CategoricalDistribution;
import aima.core.probability.Factor;
import aima.core.probability.RandomVariable;
import aima.core.probability.domain.FiniteDomain;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.util.SetOps;
import aima.core.util.math.MixedRadixNumber;

/**
 * A Utility Class for associating values with a set of finite Random Variables.
 * This is also the default implementation of the CategoricalDistribution and
 * Factor interfaces (as they are essentially dependent on the same underlying
 * data structures).
 * <p>
 * The values are kept in a single flat array. The position of an assignment
 * within that array is computed with a {@link MixedRadixNumber} that has one
 * numeral per random variable, the radix of a numeral being the size of the
 * variable's domain. The variables are kept in the order in which they were
 * supplied to the constructor.
 * <p>
 * A table over zero random variables is supported; its single entry lives at
 * index 0.
 * 
 * @author Ciaran O'Reilly
 */
public class ProbabilityTable implements CategoricalDistribution, Factor {
	// Flat storage for the table entries, addressed via mixed radix numbers.
	private double[] values = null;
	//
	// Insertion ordered: the iteration order defines the variable order of
	// the table.
	private final Map<RandomVariable, RVInfo> randomVarInfo = new LinkedHashMap<RandomVariable, RVInfo>();
	private int[] radices = null;
	// Only created when the table has at least one variable, see indexFor().
	private MixedRadixNumber queryMRN = null;
	//
	// Lazily computed values, reset by reinitLazyValues().
	private String toString = null;
	private double sum = -1;

	/**
	 * Interface to be implemented by an object/algorithm that wishes to iterate
	 * over the possible assignments for the random variables comprising this
	 * table.
	 * 
	 * @see ProbabilityTable#iterateOverTable(Iterator)
	 * @see ProbabilityTable#iterateOverTable(Iterator,
	 *      AssignmentProposition...)
	 */
	public interface Iterator {
		/**
		 * Called for each possible assignment for the Random Variables
		 * comprising this ProbabilityTable.
		 * 
		 * @param possibleAssignment
		 *            a possible assignment, &omega;, of variable/value pairs.
		 *            The same map instance is reused between calls, so copy it
		 *            if it needs to be retained.
		 * @param probability
		 *            the probability associated with &omega;
		 */
		void iterate(Map<RandomVariable, Object> possibleAssignment,
				double probability);
	}

	/**
	 * Creates a table over the given variables with all entries set to 0.
	 * 
	 * @param vars
	 *            the random variables, in table order.
	 */
	public ProbabilityTable(Collection<RandomVariable> vars) {
		this(vars.toArray(new RandomVariable[vars.size()]));
	}

	/**
	 * Creates a table over the given variables with all entries set to 0.
	 * 
	 * @param vars
	 *            the random variables, in table order.
	 */
	public ProbabilityTable(RandomVariable... vars) {
		this(new double[ProbUtil.expectedSizeOfProbabilityTable(vars)], vars);
	}

	/**
	 * Creates a table over the given variables, initialised with a copy of the
	 * given values.
	 * 
	 * @param vals
	 *            the initial entries; copied, so later changes to the passed
	 *            array do not affect this table.
	 * @param vars
	 *            the random variables, in table order.
	 * @throws IllegalArgumentException
	 *             if vals is null or its length differs from
	 *             ProbUtil.expectedSizeOfProbabilityTable(vars).
	 */
	public ProbabilityTable(double[] vals, RandomVariable... vars) {
		if (null == vals) {
			throw new IllegalArgumentException("Values must be specified");
		}
		if (vals.length != ProbUtil.expectedSizeOfProbabilityTable(vars)) {
			throw new IllegalArgumentException("ProbabilityTable of length "
					+ vals.length + " is not the correct size, should be "
					+ ProbUtil.expectedSizeOfProbabilityTable(vars)
					+ " in order to represent all possible combinations.");
		}
		if (null != vars) {
			for (RandomVariable rv : vars) {
				// Track index information relevant to each variable.
				randomVarInfo.put(rv, new RVInfo(rv));
			}
		}

		values = new double[vals.length];
		System.arraycopy(vals, 0, values, 0, vals.length);

		radices = createRadixs(randomVarInfo);

		// No mixed radix number is created for a table without variables;
		// index lookups fall back to position 0 in that case.
		if (radices.length > 0) {
			queryMRN = new MixedRadixNumber(0, radices);
		}
	}

	/**
	 * @return the number of entries in this table.
	 */
	public int size() {
		return values.length;
	}

	//
	// START-ProbabilityDistribution
	@Override
	public Set<RandomVariable> getFor() {
		return randomVarInfo.keySet();
	}

	@Override
	public boolean contains(RandomVariable rv) {
		return randomVarInfo.keySet().contains(rv);
	}

	@Override
	public double getValue(Object... assignments) {
		return values[getIndex(assignments)];
	}

	/**
	 * Looks up the entry for the given assignment propositions, which may be
	 * supplied in any order.
	 * 
	 * @throws IllegalArgumentException
	 *             if the number of assignments differs from the number of
	 *             variables in this table, or an assignment refers to a
	 *             variable that is not part of this table.
	 */
	@Override
	public double getValue(AssignmentProposition... assignments) {
		if (assignments.length != randomVarInfo.size()) {
			throw new IllegalArgumentException(
					"Assignments passed in is not the same size as variables making up probability table.");
		}
		int[] radixValues = new int[assignments.length];
		for (AssignmentProposition ap : assignments) {
			RVInfo rvInfo = randomVarInfo.get(ap.getTermVariable());
			if (null == rvInfo) {
				throw new IllegalArgumentException(
						"Assignment passed for a variable that is not part of this probability table:"
								+ ap.getTermVariable());
			}
			radixValues[rvInfo.getRadixIdx()] = rvInfo.getIdxForDomain(ap
					.getValue());
		}
		return values[indexFor(radixValues)];
	}

	// END-ProbabilityDistribution
	//

	//
	// START-CategoricalDistribution
	/**
	 * @return the internal array of entries (not a copy). Writing to it
	 *         directly bypasses the reset of the cached sum and string form
	 *         that {@link #setValue(int, double)} performs.
	 */
	public double[] getValues() {
		return values;
	}

	@Override
	public void setValue(int idx, double value) {
		values[idx] = value;
		reinitLazyValues();
	}

	/**
	 * @return the sum of all entries; computed on first use and cached until
	 *         the table is changed through setValue() or normalize().
	 */
	@Override
	public double getSum() {
		// -1 marks the sum as not yet computed.
		if (-1 == sum) {
			sum = 0;
			for (int i = 0; i < values.length; i++) {
				sum += values[i];
			}
		}
		return sum;
	}

	/**
	 * Divides every entry by the current sum, in place. Nothing is changed
	 * when the sum is 0 or already 1.
	 * 
	 * @return this table.
	 */
	@Override
	public ProbabilityTable normalize() {
		double s = getSum();
		if (s != 0 && s != 1.0) {
			for (int i = 0; i < values.length; i++) {
				values[i] = values[i] / s;
			}
			reinitLazyValues();
		}
		return this;
	}

	/**
	 * Computes the position in the values array for the given assignment.
	 * 
	 * @param assignments
	 *            one value per variable, in the variable order of this table.
	 * @return the index into the array returned by getValues(); 0 for a table
	 *         without variables.
	 * @throws IllegalArgumentException
	 *             if the number of assignments differs from the number of
	 *             variables in this table.
	 */
	@Override
	public int getIndex(Object... assignments) {
		if (assignments.length != randomVarInfo.size()) {
			throw new IllegalArgumentException(
					"Assignments passed in is not the same size as variables making up the table.");
		}
		int[] radixValues = new int[assignments.length];
		int i = 0;
		for (RVInfo rvInfo : randomVarInfo.values()) {
			radixValues[rvInfo.getRadixIdx()] = rvInfo
					.getIdxForDomain(assignments[i]);
			i++;
		}

		return indexFor(radixValues);
	}

	/**
	 * Delegates to {@link #sumOut(RandomVariable...)} with the given
	 * variables.
	 */
	@Override
	public CategoricalDistribution marginal(RandomVariable... vars) {
		return sumOut(vars);
	}

	@Override
	public CategoricalDistribution divideBy(CategoricalDistribution divisor) {
		return divideBy((ProbabilityTable) divisor);
	}

	@Override
	public CategoricalDistribution multiplyBy(CategoricalDistribution multiplier) {
		return pointwiseProduct((ProbabilityTable) multiplier);
	}

	@Override
	public CategoricalDistribution multiplyByPOS(
			CategoricalDistribution multiplier, RandomVariable... prodVarOrder) {
		return pointwiseProductPOS((ProbabilityTable) multiplier, prodVarOrder);
	}

	@Override
	public void iterateOver(CategoricalDistribution.Iterator cdi) {
		iterateOverTable(new CategoricalDistributionIteratorAdapter(cdi));
	}

	@Override
	public void iterateOver(CategoricalDistribution.Iterator cdi,
			AssignmentProposition... fixedValues) {
		iterateOverTable(new CategoricalDistributionIteratorAdapter(cdi),
				fixedValues);
	}

	// END-CategoricalDistribution
	//

	//
	// START-Factor
	@Override
	public Set<RandomVariable> getArgumentVariables() {
		return randomVarInfo.keySet();
	}

	/**
	 * Creates a new table over the variables of this table minus the given
	 * ones, each entry being the sum of all entries of this table that agree
	 * with it on the remaining variables.
	 * 
	 * @param vars
	 *            the variables to remove by summation.
	 * @return a new table; this table is not modified.
	 */
	@Override
	public ProbabilityTable sumOut(RandomVariable... vars) {
		Set<RandomVariable> soutVars = new LinkedHashSet<RandomVariable>(
				this.randomVarInfo.keySet());
		for (RandomVariable rv : vars) {
			soutVars.remove(rv);
		}
		final ProbabilityTable summedOut = new ProbabilityTable(soutVars);
		if (1 == summedOut.getValues().length) {
			// Only a single entry remains, which is the total of this table.
			summedOut.getValues()[0] = getSum();
		} else {
			// Otherwise need to iterate through this distribution
			// to calculate the summed out distribution.
			final Object[] termValues = new Object[summedOut.randomVarInfo
					.size()];
			ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() {
				@Override
				public void iterate(Map<RandomVariable, Object> possibleWorld,
						double probability) {

					// Project the possible world onto the remaining variables.
					int i = 0;
					for (RandomVariable rv : summedOut.randomVarInfo.keySet()) {
						termValues[i] = possibleWorld.get(rv);
						i++;
					}
					summedOut.getValues()[summedOut.getIndex(termValues)] += probability;
				}
			};
			iterateOverTable(di);
		}

		return summedOut;
	}

	@Override
	public Factor pointwiseProduct(Factor multiplier) {
		return pointwiseProduct((ProbabilityTable) multiplier);
	}

	@Override
	public Factor pointwiseProductPOS(Factor multiplier,
			RandomVariable... prodVarOrder) {
		return pointwiseProductPOS((ProbabilityTable) multiplier, prodVarOrder);
	}

	@Override
	public void iterateOver(Factor.Iterator fi) {
		iterateOverTable(new FactorIteratorAdapter(fi));
	}

	@Override
	public void iterateOver(Factor.Iterator fi,
			AssignmentProposition... fixedValues) {
		iterateOverTable(new FactorIteratorAdapter(fi), fixedValues);
	}

	// END-Factor
	//

	/**
	 * Iterate over all the possible value assignments for the Random Variables
	 * comprising this ProbabilityTable. For a table without variables the
	 * iterator is called once, with an empty assignment and the single entry.
	 * 
	 * @param pti
	 *            the ProbabilityTable Iterator to iterate.
	 */
	public void iterateOverTable(Iterator pti) {
		Map<RandomVariable, Object> possibleWorld = new LinkedHashMap<RandomVariable, Object>();
		if (0 == radices.length) {
			// No variables: the constructor creates no mixed radix number for
			// an empty set of radices either, so report the only entry
			// directly.
			pti.iterate(possibleWorld, values[0]);
			return;
		}
		MixedRadixNumber mrn = new MixedRadixNumber(0, radices);
		do {
			for (RVInfo rvInfo : randomVarInfo.values()) {
				possibleWorld.put(rvInfo.getVariable(), rvInfo
						.getDomainValueAt(mrn.getCurrentNumeralValue(rvInfo
								.getRadixIdx())));
			}
			pti.iterate(possibleWorld, values[mrn.intValue()]);

		} while (mrn.increment());
	}

	/**
	 * Iterate over all possible values assignments for the Random Variables
	 * comprising this ProbabilityTable that are not in the fixed set of values.
	 * This allows you to iterate over a subset of possible combinations.
	 * 
	 * @param pti
	 *            the ProbabilityTable Iterator to iterate
	 * @param fixedValues
	 *            Fixed values for a subset of the Random Variables comprising
	 *            this Probability Table.
	 * @throws IllegalArgumentException
	 *             if a fixed value refers to a variable that is not part of
	 *             this table.
	 */
	public void iterateOverTable(Iterator pti,
			AssignmentProposition... fixedValues) {
		Map<RandomVariable, Object> possibleWorld = new LinkedHashMap<RandomVariable, Object>();
		int[] tableRadixValues = new int[radices.length];

		// Assert that the Random Variables for the fixed values
		// are part of this probability table and assign
		// all the fixed values to the possible world.
		for (AssignmentProposition ap : fixedValues) {
			if (!randomVarInfo.containsKey(ap.getTermVariable())) {
				throw new IllegalArgumentException("Assignment proposition ["
						+ ap + "] does not belong to this probability table.");
			}
			possibleWorld.put(ap.getTermVariable(), ap.getValue());
			RVInfo fixedRVI = randomVarInfo.get(ap.getTermVariable());
			tableRadixValues[fixedRVI.getRadixIdx()] = fixedRVI
					.getIdxForDomain(ap.getValue());
		}
		// If have assignments for all the random variables
		// in this probability table
		if (fixedValues.length == randomVarInfo.size()) {
			// Then only 1 iteration call is required.
			pti.iterate(possibleWorld, getValue(fixedValues));
		} else {
			// Else iterate over the non-fixed values
			// The table MRN is only needed (and only created) on this path.
			MixedRadixNumber tableMRN = new MixedRadixNumber(0, radices);
			Set<RandomVariable> freeVariables = SetOps.difference(
					this.randomVarInfo.keySet(), possibleWorld.keySet());
			Map<RandomVariable, RVInfo> freeVarInfo = new LinkedHashMap<RandomVariable, RVInfo>();
			// Remove the fixed Variables
			for (RandomVariable fv : freeVariables) {
				freeVarInfo.put(fv, new RVInfo(fv));
			}
			int[] freeRadixValues = createRadixs(freeVarInfo);
			MixedRadixNumber freeMRN = new MixedRadixNumber(0, freeRadixValues);
			Object fval = null;
			// Iterate through all combinations of the free variables
			do {
				// Put the current assignments for the free variables
				// into the possible world and update
				// the current index in the table MRN
				for (RVInfo freeRVI : freeVarInfo.values()) {
					fval = freeRVI.getDomainValueAt(freeMRN
							.getCurrentNumeralValue(freeRVI.getRadixIdx()));
					possibleWorld.put(freeRVI.getVariable(), fval);

					tableRadixValues[randomVarInfo.get(freeRVI.getVariable())
							.getRadixIdx()] = freeRVI.getIdxForDomain(fval);
				}
				pti.iterate(possibleWorld, values[(int) tableMRN
						.getCurrentValueFor(tableRadixValues)]);

			} while (freeMRN.increment());
		}
	}

	/**
	 * Divides this table by the given one. The quotient is defined over the
	 * variables of this table; entries for which the divisor is 0 are set to
	 * 0.
	 * 
	 * @param divisor
	 *            a table whose variables are a subset of the variables of this
	 *            table.
	 * @return a new table holding the quotient; neither operand is modified.
	 * @throws IllegalArgumentException
	 *             if the divisor contains a variable that is not part of this
	 *             table.
	 */
	public ProbabilityTable divideBy(ProbabilityTable divisor) {
		if (!randomVarInfo.keySet().containsAll(divisor.randomVarInfo.keySet())) {
			throw new IllegalArgumentException(
					"Divisor must be a subset of the dividend.");
		}

		// Same variables in the same order as this table, so an offset into
		// the quotient is also a valid offset into this table.
		final ProbabilityTable quotient = new ProbabilityTable(randomVarInfo
				.keySet());

		if (1 == divisor.getValues().length) {
			double d = divisor.getValues()[0];
			for (int i = 0; i < quotient.getValues().length; i++) {
				if (0 == d) {
					quotient.getValues()[i] = 0;
				} else {
					quotient.getValues()[i] = getValues()[i] / d;
				}
			}
		} else {
			// Variables of the dividend that do not occur in the divisor; all
			// of their value combinations share the same divisor entry.
			Set<RandomVariable> dividendDivisorDiff = SetOps
					.difference(this.randomVarInfo.keySet(),
							divisor.randomVarInfo.keySet());
			Map<RandomVariable, RVInfo> tdiff = null;
			MixedRadixNumber tdMRN = null;
			if (dividendDivisorDiff.size() > 0) {
				tdiff = new LinkedHashMap<RandomVariable, RVInfo>();
				for (RandomVariable rv : dividendDivisorDiff) {
					tdiff.put(rv, new RVInfo(rv));
				}
				tdMRN = new MixedRadixNumber(0, createRadixs(tdiff));
			}
			final Map<RandomVariable, RVInfo> diff = tdiff;
			final MixedRadixNumber dMRN = tdMRN;
			final int[] qRVs = new int[quotient.radices.length];
			final MixedRadixNumber qMRN = new MixedRadixNumber(0,
					quotient.radices);
			ProbabilityTable.Iterator divisorIterator = new ProbabilityTable.Iterator() {
				@Override
				public void iterate(Map<RandomVariable, Object> possibleWorld,
						double probability) {
					// Fix the numerals of the variables shared with the
					// divisor.
					for (RandomVariable rv : possibleWorld.keySet()) {
						RVInfo rvInfo = quotient.randomVarInfo.get(rv);
						qRVs[rvInfo.getRadixIdx()] = rvInfo
								.getIdxForDomain(possibleWorld.get(rv));
					}
					if (null != diff) {
						// Start from 0 off the diff
						dMRN.setCurrentValueFor(new int[diff.size()]);
						do {
							for (RandomVariable rv : diff.keySet()) {
								RVInfo drvInfo = diff.get(rv);
								RVInfo qrvInfo = quotient.randomVarInfo.get(rv);
								qRVs[qrvInfo.getRadixIdx()] = dMRN
										.getCurrentNumeralValue(drvInfo
												.getRadixIdx());
							}
							updateQuotient(probability);
						} while (dMRN.increment());
					} else {
						updateQuotient(probability);
					}
				}

				//
				//
				private void updateQuotient(double probability) {
					int offset = (int) qMRN.getCurrentValueFor(qRVs);
					if (0 == probability) {
						quotient.getValues()[offset] = 0;
					} else {
						quotient.getValues()[offset] += getValues()[offset]
								/ probability;
					}
				}
			};

			divisor.iterateOverTable(divisorIterator);
		}

		return quotient;
	}

	/**
	 * Multiplies this table with the given one. The variables of the product
	 * are the union (as computed by SetOps.union) of the variables of both
	 * tables.
	 * 
	 * @param multiplier
	 *            the table to multiply with.
	 * @return a new table holding the product; neither operand is modified.
	 */
	public ProbabilityTable pointwiseProduct(final ProbabilityTable multiplier) {
		Set<RandomVariable> prodVars = SetOps.union(randomVarInfo.keySet(),
				multiplier.randomVarInfo.keySet());
		return pointwiseProductPOS(multiplier, prodVars
				.toArray(new RandomVariable[prodVars.size()]));
	}

	/**
	 * Multiplies this table with the given one, laying the product out in an
	 * explicitly specified variable order.
	 * 
	 * @param multiplier
	 *            the table to multiply with.
	 * @param prodVarOrder
	 *            the variables of the product, in the order wanted for the
	 *            resulting table. Must consist of exactly the variables of
	 *            this table and of the multiplier.
	 * @return a new table holding the product; neither operand is modified.
	 * @throws IllegalArgumentException
	 *             if prodVarOrder is not the union of the variables of both
	 *             tables.
	 */
	public ProbabilityTable pointwiseProductPOS(
			final ProbabilityTable multiplier, RandomVariable... prodVarOrder) {
		final ProbabilityTable product = new ProbabilityTable(prodVarOrder);
		if (!product.randomVarInfo.keySet().equals(
				SetOps.union(randomVarInfo.keySet(), multiplier.randomVarInfo
						.keySet()))) {
			throw new IllegalArgumentException(
					"Specified list deatailing order of mulitplier is inconsistent.");
		}

		// If no variables in the product
		if (1 == product.getValues().length) {
			product.getValues()[0] = getValues()[0] * multiplier.getValues()[0];
		} else {
			// Otherwise need to iterate through the product
			// to calculate its values based on the terms.
			final Object[] term1Values = new Object[randomVarInfo.size()];
			final Object[] term2Values = new Object[multiplier.randomVarInfo
					.size()];
			ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() {
				// Position of the product entry being filled in; this relies
				// on iterateOverTable() visiting the entries of the product in
				// ascending index order.
				private int idx = 0;

				@Override
				public void iterate(Map<RandomVariable, Object> possibleWorld,
						double probability) {
					int term1Idx = termIdx(term1Values, ProbabilityTable.this,
							possibleWorld);
					int term2Idx = termIdx(term2Values, multiplier,
							possibleWorld);

					product.getValues()[idx] = getValues()[term1Idx]
							* multiplier.getValues()[term2Idx];

					idx++;
				}

				private int termIdx(Object[] termValues, ProbabilityTable d,
						Map<RandomVariable, Object> possibleWorld) {
					if (0 == termValues.length) {
						// The term has no variables so always position 0.
						return 0;
					}

					int i = 0;
					for (RandomVariable rv : d.randomVarInfo.keySet()) {
						termValues[i] = possibleWorld.get(rv);
						i++;
					}

					return d.getIndex(termValues);
				}
			};
			product.iterateOverTable(di);
		}

		return product;
	}

	/**
	 * @return the entries of this table in index order, formatted as
	 *         {@code <v0, v1, ...>}. The string is cached until the table is
	 *         changed through setValue() or normalize().
	 */
	@Override
	public String toString() {
		if (null == toString) {
			StringBuilder sb = new StringBuilder();
			sb.append("<");
			for (int i = 0; i < values.length; i++) {
				if (i > 0) {
					sb.append(", ");
				}
				sb.append(values[i]);
			}
			sb.append(">");

			toString = sb.toString();
		}
		return toString;
	}

	//
	// PRIVATE METHODS
	//
	/**
	 * Discards the cached sum and string representation.
	 */
	private void reinitLazyValues() {
		sum = -1;
		toString = null;
	}

	/**
	 * Maps numeral values (one per variable of this table) to a position in
	 * the values array.
	 * 
	 * @return 0 if this table has no variables (no query MRN exists then),
	 *         otherwise the value the query MRN reports for the numerals.
	 */
	private int indexFor(int[] radixValues) {
		if (null == queryMRN) {
			return 0;
		}
		return (int) queryMRN.getCurrentValueFor(radixValues);
	}

	/**
	 * Builds the radices for the given variables and, as a side effect,
	 * records on each RVInfo which radix position it has been given.
	 * 
	 * @param mapRtoInfo
	 *            the variables (in table order) with their index information.
	 * @return the domain sizes, stored in reverse variable order.
	 */
	private static int[] createRadixs(Map<RandomVariable, RVInfo> mapRtoInfo) {
		int[] r = new int[mapRtoInfo.size()];
		// Read in reverse order so that the enumeration
		// through the distributions is of the following
		// order using a MixedRadixNumber, e.g. for two Booleans:
		// X Y
		// true true
		// true false
		// false true
		// false false
		// which corresponds with how displayed in book.
		int x = mapRtoInfo.size() - 1;
		for (RVInfo rvInfo : mapRtoInfo.values()) {
			r[x] = rvInfo.getDomainSize();
			rvInfo.setRadixIdx(x);
			x--;
		}
		return r;
	}

	/**
	 * Index information for a single random variable: its finite domain and
	 * the radix position assigned to it by createRadixs().
	 */
	private static class RVInfo {
		private final RandomVariable variable;
		private final FiniteDomain varDomain;
		private int radixIdx = 0;

		public RVInfo(RandomVariable rv) {
			variable = rv;
			// Only variables with a finite domain can be handled.
			varDomain = (FiniteDomain) variable.getDomain();
		}

		public RandomVariable getVariable() {
			return variable;
		}

		public int getDomainSize() {
			return varDomain.size();
		}

		public int getIdxForDomain(Object value) {
			return varDomain.getOffset(value);
		}

		public Object getDomainValueAt(int idx) {
			return varDomain.getValueAt(idx);
		}

		public void setRadixIdx(int idx) {
			radixIdx = idx;
		}

		public int getRadixIdx() {
			return radixIdx;
		}
	}

	/**
	 * Lets a CategoricalDistribution.Iterator be driven by the table
	 * iteration methods.
	 */
	private static class CategoricalDistributionIteratorAdapter implements
			Iterator {
		private final CategoricalDistribution.Iterator cdi;

		public CategoricalDistributionIteratorAdapter(
				CategoricalDistribution.Iterator cdi) {
			this.cdi = cdi;
		}

		@Override
		public void iterate(Map<RandomVariable, Object> possibleAssignment,
				double probability) {
			cdi.iterate(possibleAssignment, probability);
		}
	}

	/**
	 * Lets a Factor.Iterator be driven by the table iteration methods.
	 */
	private static class FactorIteratorAdapter implements Iterator {
		private final Factor.Iterator fi;

		public FactorIteratorAdapter(Factor.Iterator fi) {
			this.fi = fi;
		}

		@Override
		public void iterate(Map<RandomVariable, Object> possibleAssignment,
				double probability) {
			fi.iterate(possibleAssignment, probability);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy