All downloads are free. Search and download functionality uses the official Maven repository.

aima.core.probability.bayes.impl.CPT Maven / Gradle / Ivy

Go to download

AIMA-Java Core Algorithms from the book Artificial Intelligence: A Modern Approach, 3rd Ed.

The newest version!
package aima.core.probability.bayes.impl;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import aima.core.probability.CategoricalDistribution;
import aima.core.probability.Factor;
import aima.core.probability.ProbabilityModel;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.ConditionalProbabilityTable;
import aima.core.probability.domain.FiniteDomain;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.probability.util.ProbUtil;
import aima.core.probability.util.ProbabilityTable;

/**
 * Default implementation of the ConditionalProbabilityTable interface.
 * 
 * @author Ciaran O'Reilly
 * 
 */
public class CPT implements ConditionalProbabilityTable {
	private RandomVariable on = null;
	private LinkedHashSet parents = new LinkedHashSet();
	private ProbabilityTable table = null;
	private List onDomain = new ArrayList();

	public CPT(RandomVariable on, double[] values,
			RandomVariable... conditionedOn) {
		this.on = on;
		if (null == conditionedOn) {
			conditionedOn = new RandomVariable[0];
		}
		RandomVariable[] tableVars = new RandomVariable[conditionedOn.length + 1];
		for (int i = 0; i < conditionedOn.length; i++) {
			tableVars[i] = conditionedOn[i];
			parents.add(conditionedOn[i]);
		}
		tableVars[conditionedOn.length] = on;
		table = new ProbabilityTable(values, tableVars);
		onDomain.addAll(((FiniteDomain) on.getDomain()).getPossibleValues());

		checkEachRowTotalsOne();
	}

	public double probabilityFor(final Object... values) {
		return table.getValue(values);
	}

	//
	// START-ConditionalProbabilityDistribution

	@Override
	public RandomVariable getOn() {
		return on;
	}

	@Override
	public Set getParents() {
		return parents;
	}

	@Override
	public Set getFor() {
		return table.getFor();
	}

	@Override
	public boolean contains(RandomVariable rv) {
		return table.contains(rv);
	}

	@Override
	public double getValue(Object... eventValues) {
		return table.getValue(eventValues);
	}

	@Override
	public double getValue(AssignmentProposition... eventValues) {
		return table.getValue(eventValues);
	}

	@Override
	public Object getSample(double probabilityChoice, Object... parentValues) {
		return ProbUtil.sample(probabilityChoice, on,
				getConditioningCase(parentValues).getValues());
	}

	@Override
	public Object getSample(double probabilityChoice,
			AssignmentProposition... parentValues) {
		return ProbUtil.sample(probabilityChoice, on,
				getConditioningCase(parentValues).getValues());
	}

	// END-ConditionalProbabilityDistribution
	//

	//
	// START-ConditionalProbabilityTable
	@Override
	public CategoricalDistribution getConditioningCase(Object... parentValues) {
		if (parentValues.length != parents.size()) {
			throw new IllegalArgumentException(
					"The number of parent value arguments ["
							+ parentValues.length
							+ "] is not equal to the number of parents ["
							+ parents.size() + "] for this CPT.");
		}
		AssignmentProposition[] aps = new AssignmentProposition[parentValues.length];
		int idx = 0;
		for (RandomVariable parentRV : parents) {
			aps[idx] = new AssignmentProposition(parentRV, parentValues[idx]);
			idx++;
		}

		return getConditioningCase(aps);
	}

	@Override
	public CategoricalDistribution getConditioningCase(
			AssignmentProposition... parentValues) {
		if (parentValues.length != parents.size()) {
			throw new IllegalArgumentException(
					"The number of parent value arguments ["
							+ parentValues.length
							+ "] is not equal to the number of parents ["
							+ parents.size() + "] for this CPT.");
		}
		final ProbabilityTable cc = new ProbabilityTable(getOn());
		ProbabilityTable.Iterator pti = new ProbabilityTable.Iterator() {
			private int idx = 0;

			@Override
			public void iterate(Map possibleAssignment,
					double probability) {
				cc.getValues()[idx] = probability;
				idx++;
			}
		};
		table.iterateOverTable(pti, parentValues);

		return cc;
	}

	public Factor getFactorFor(final AssignmentProposition... evidence) {
		Set fofVars = new LinkedHashSet(
				table.getFor());
		for (AssignmentProposition ap : evidence) {
			fofVars.remove(ap.getTermVariable());
		}
		final ProbabilityTable fof = new ProbabilityTable(fofVars);
		// Otherwise need to iterate through the table for the
		// non evidence variables.
		final Object[] termValues = new Object[fofVars.size()];
		ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() {
			public void iterate(Map possibleWorld,
					double probability) {
				if (0 == termValues.length) {
					fof.getValues()[0] += probability;
				} else {
					int i = 0;
					for (RandomVariable rv : fof.getFor()) {
						termValues[i] = possibleWorld.get(rv);
						i++;
					}
					fof.getValues()[fof.getIndex(termValues)] += probability;
				}
			}
		};
		table.iterateOverTable(di, evidence);

		return fof;
	}

	// END-ConditionalProbabilityTable
	//

	//
	// PRIVATE METHODS
	//
	private void checkEachRowTotalsOne() {
		ProbabilityTable.Iterator di = new ProbabilityTable.Iterator() {
			private int rowSize = onDomain.size();
			private int iterateCnt = 0;
			private double rowProb = 0;

			public void iterate(Map possibleWorld,
					double probability) {
				iterateCnt++;
				rowProb += probability;
				if (iterateCnt % rowSize == 0) {
					if (Math.abs(1 - rowProb) > ProbabilityModel.DEFAULT_ROUNDING_THRESHOLD) {
						throw new IllegalArgumentException("Row "
								+ (iterateCnt / rowSize)
								+ " of CPT does not sum to 1.0.");
					}
					rowProb = 0;
				}
			}
		};

		table.iterateOverTable(di);
	}
}