/**
 * Copyright (c) 2011 Michael Kutschke. All rights reserved. This program and the accompanying
 * materials are made available under the terms of the Eclipse Public License v1.0 which accompanies
 * this distribution, and is available at http://www.eclipse.org/legal/epl-v10.html Contributors:
 * Michael Kutschke - initial API and implementation.
 */
package cc.kave.repackaged.jayes.inference.junctionTree;

import static cc.kave.repackaged.jayes.util.Pair.newPair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import cc.kave.repackaged.jayes.BayesNet;
import cc.kave.repackaged.jayes.BayesNode;
import cc.kave.repackaged.jayes.factor.AbstractFactor;
import cc.kave.repackaged.jayes.factor.arraywrapper.DoubleArrayWrapper;
import cc.kave.repackaged.jayes.factor.arraywrapper.IArrayWrapper;
import cc.kave.repackaged.jayes.inference.AbstractInferer;
import cc.kave.repackaged.jayes.internal.util.ArrayUtils;
import cc.kave.repackaged.jayes.util.Graph;
import cc.kave.repackaged.jayes.util.MathUtils;
import cc.kave.repackaged.jayes.util.NumericalInstabilityException;
import cc.kave.repackaged.jayes.util.Pair;
import cc.kave.repackaged.jayes.util.Graph.Edge;
import cc.kave.repackaged.jayes.util.sharing.CanonicalArrayWrapperManager;
import cc.kave.repackaged.jayes.util.sharing.CanonicalIntArrayManager;
import cc.kave.repackaged.jayes.util.triangulation.MinFillIn;

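/**
 * Exact inference on a {@link BayesNet} using the junction tree algorithm
 * (clique tree propagation with Hugin-style message passing).
 *
 * <p>A minimal usage sketch, assuming the evidence API inherited via
 * {@code AbstractInferer} (method names follow the Jayes tutorial; adapt if
 * your version differs):
 *
 * <pre>
 * BayesNet net = ...; // built elsewhere
 * JunctionTreeAlgorithm inferer = new JunctionTreeAlgorithm();
 * inferer.setNetwork(net);
 *
 * Map&lt;BayesNode, String&gt; evidence = new HashMap&lt;&gt;();
 * evidence.put(someNode, "someOutcome");
 * inferer.setEvidence(evidence);
 *
 * double[] beliefs = inferer.getBeliefs(queryNode); // one entry per outcome
 * </pre>
 */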
public class JunctionTreeAlgorithm extends AbstractInferer {

    protected Map<Edge, AbstractFactor> sepSets;
    protected Graph junctionTree;
    protected AbstractFactor[] nodePotentials;
    // need IdentityHashMap here because an Edge and
    // its backward Edge are considered equal
    // (which is also needed for simplicity)
    protected IdentityHashMap<Edge, int[]> preparedMultiplications;

    // mapping from variables to clusters that contain them
    protected int[][] concernedClusters;
    protected AbstractFactor[] queryFactors;
    protected int[][] preparedQueries;
    protected boolean[] isBeliefValid;
    protected List<Pair<AbstractFactor, IArrayWrapper>> initializations;

    protected int[][] queryFactorReverseMapping;

    // used for computing evidence collection skip
    protected Set<Integer> clustersHavingEvidence;
    protected boolean[] isObserved;

    protected double[] scratchpad;

    protected JunctionTreeBuilder junctionTreeBuilder = JunctionTreeBuilder.forHeuristic(new MinFillIn());

    public void setJunctionTreeBuilder(JunctionTreeBuilder bldr) {
        this.junctionTreeBuilder = bldr;
    }

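    /*
     * Beliefs are computed lazily: the first query after an evidence change
     * triggers a full propagation pass (updateBeliefs), after which each
     * node's marginal is materialized on demand. For observed nodes the
     * belief is a point distribution on the observed outcome.
     */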
    @Override
    public double[] getBeliefs(final BayesNode node) {
        if (!beliefsValid) {
            beliefsValid = true;
            updateBeliefs();
        }
        final int nodeId = node.getId();
        if (!isBeliefValid[nodeId]) {
            isBeliefValid[nodeId] = true;
            if (!evidence.containsKey(node)) {
                validateBelief(nodeId);
            } else {
                Arrays.fill(beliefs[nodeId], 0);
                beliefs[nodeId][node.getOutcomeIndex(evidence.get(node))] = 1;
            }
        }
        return super.getBeliefs(node);
    }

    private void validateBelief(final int nodeId) {
        final AbstractFactor f = queryFactors[nodeId];
        // TODO change beliefs to ArrayWrappers
        f.sumPrepared(new DoubleArrayWrapper(beliefs[nodeId]), preparedQueries[nodeId]);
        if (f.isLogScale()) {
            MathUtils.exp(beliefs[nodeId]);
        }
        try {
            beliefs[nodeId] = MathUtils.normalize(beliefs[nodeId]);
        } catch (final IllegalArgumentException exception) {
            throw new NumericalInstabilityException("Numerical instability detected for evidence: " + evidence
                    + " and node : " + nodeId
                    + ", consider using logarithmic scale computation (configurable in FactorFactory)", exception);
        }
    }

    @Override
    protected void updateBeliefs() {
        Arrays.fill(isBeliefValid, false);
        doUpdateBeliefs();
    }

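    /*
     * One full propagation pass: re-apply the current evidence selections,
     * restore the cluster potentials from their stored initial values, and
     * run the standard two-pass schedule: collect messages toward the
     * propagation root, then distribute them back out. The skip sets prune
     * subtrees whose messages cannot change any queried belief.
     */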
    private void doUpdateBeliefs() {

        incorporateAllEvidence();
        int propagationRoot = findPropagationRoot();

        replayFactorInitializations();
        collectEvidence(propagationRoot, skipCollection(propagationRoot));
        distributeEvidence(propagationRoot, skipDistribution(propagationRoot));
    }

    private void replayFactorInitializations() {
        for (final Pair<AbstractFactor, IArrayWrapper> init : initializations) {
            init.getFirst().copyValues(init.getSecond());
        }
    }

    private void incorporateAllEvidence() {
        for (Pair<AbstractFactor, IArrayWrapper> init : initializations) {
            init.getFirst().resetSelections();
        }

        clustersHavingEvidence.clear();
        Arrays.fill(isObserved, false);
        for (BayesNode n : evidence.keySet()) {
            incorporateEvidence(n);
        }

    }

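    /*
     * Evidence is incorporated by "selecting" the observed outcome in every
     * cluster factor that mentions the variable: select() pins that dimension
     * to the observed outcome index, so subsequent sum/multiply operations
     * ignore the other outcomes.
     */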
    private void incorporateEvidence(final BayesNode node) {
        int n = node.getId();
        isObserved[n] = true;
        // get evidence to all concerned factors (includes home cluster)
        for (final Integer concernedCluster : concernedClusters[n]) {
            nodePotentials[concernedCluster].select(n, node.getOutcomeIndex(evidence.get(node)));
            clustersHavingEvidence.add(concernedCluster);
        }
    }

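    /*
     * Any cluster containing an evidence variable works as the propagation
     * root; the loop simply ends up with the cluster of the last evidence
     * node iterated (or cluster 0 when there is no evidence).
     */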
    private int findPropagationRoot() {
        int propagationRoot = 0;
        for (BayesNode n : evidence.keySet()) {
            propagationRoot = concernedClusters[n.getId()][0];
        }
        return propagationRoot;
    }

    /**
     * Checks which nodes need not be processed during collectEvidence (because of preprocessing). These are the
     * nodes without evidence that are leaves or whose descendants all carry no evidence.
     *
     * @param root
     *            the node to start the check from
     * @return the set of nodes not needing a call to collectEvidence
     */
    private Set<Integer> skipCollection(final int root) {
        final Set<Integer> skipped = new HashSet<Integer>(nodePotentials.length);
        recursiveSkipCollection(root, new HashSet<Integer>(nodePotentials.length), skipped);
        return skipped;
    }

    private void recursiveSkipCollection(final int node, final Set<Integer> visited, final Set<Integer> skipped) {
        visited.add(node);
        boolean areAllDescendantsSkipped = true;
        for (final Edge e : junctionTree.getIncidentEdges(node)) {
            if (!visited.contains(e.getSecond())) {
                recursiveSkipCollection(e.getSecond(), visited, skipped);
                if (!skipped.contains(e.getSecond())) {
                    areAllDescendantsSkipped = false;
                }
            }
        }
        if (areAllDescendantsSkipped && !clustersHavingEvidence.contains(node)) {
            skipped.add(node);
        }

    }

    /**
     * Checks which nodes do not need to be visited during evidence distribution. These are exactly those nodes which
     * are
     * <ul>
     * <li>not the query factor of a non-evidence variable</li>
     * <li>AND have no descendants that cannot be skipped</li>
     * </ul>
     *
     * @param distNode
     *            the node to start the check from
     * @return the set of nodes not needing a call to distributeEvidence
     */
    private Set<Integer> skipDistribution(final int distNode) {
        final Set<Integer> skipped = new HashSet<Integer>(nodePotentials.length);
        recursiveSkipDistribution(distNode, new HashSet<Integer>(nodePotentials.length), skipped);
        return skipped;
    }

    private void recursiveSkipDistribution(final int node, final Set<Integer> visited, final Set<Integer> skipped) {
        visited.add(node);
        boolean areAllDescendantsSkipped = true;
        for (final Edge e : junctionTree.getIncidentEdges(node)) {
            if (!visited.contains(e.getSecond())) {
                recursiveSkipDistribution(e.getSecond(), visited, skipped);
                if (!skipped.contains(e.getSecond())) {
                    areAllDescendantsSkipped = false;
                }
            }
        }
        if (areAllDescendantsSkipped && !isQueryFactorOfUnobservedVariable(node)) {
            skipped.add(node);
        }
    }

    private boolean isQueryFactorOfUnobservedVariable(final int node) {
        for (int i : queryFactorReverseMapping[node]) {
            if (!isObserved[i]) {
                return true;
            }
        }
        return false;
    }

    private void collectEvidence(final int cluster, final Set<Integer> marked) {
        marked.add(cluster);
        for (final Edge e : junctionTree.getIncidentEdges(cluster)) {
            if (!marked.contains(e.getSecond())) {
                collectEvidence(e.getSecond(), marked);
                messagePass(e.getBackEdge());
            }
        }
    }

    private void distributeEvidence(final int cluster, final Set<Integer> marked) {
        marked.add(cluster);
        for (final Edge e : junctionTree.getIncidentEdges(cluster)) {
            if (!marked.contains(e.getSecond())) {
                messagePass(e);
                distributeEvidence(e.getSecond(), marked);
            }
        }
    }

    /*
     * Hugin-style update over the separator: save the old separator values,
     * marginalize the source cluster into the separator, divide new by old
     * (subtract in log scale), and multiply the quotient into the target
     * cluster.
     */
    private void messagePass(final Edge sepSetEdge) {
        final AbstractFactor sepSet = sepSets.get(sepSetEdge);
        if (!needMessagePass(sepSet)) {
            return;
        }

        final IArrayWrapper newSepValues = sepSet.getValues();
        System.arraycopy(newSepValues.toDoubleArray(), 0, scratchpad, 0, newSepValues.length());

        final int[] preparedOp = preparedMultiplications.get(sepSetEdge.getBackEdge());
        nodePotentials[sepSetEdge.getFirst()].sumPrepared(newSepValues, preparedOp);

        if (isOnlyFirstLogScale(sepSetEdge)) {
            MathUtils.exp(newSepValues);
        }
        if (areBothEndsLogScale(sepSetEdge)) {
            MathUtils.secureSubtract(newSepValues.toDoubleArray(), scratchpad, scratchpad);
        } else {
            MathUtils.secureDivide(newSepValues.toDoubleArray(), scratchpad, scratchpad);
        }

        if (isOnlySecondLogScale(sepSetEdge)) {
            MathUtils.log(scratchpad);
        }
        // TODO scratchpad -> ArrayWrapper
        nodePotentials[sepSetEdge.getSecond()].multiplyPrepared(new DoubleArrayWrapper(scratchpad),
                preparedMultiplications.get(sepSetEdge));
    }

    /*
     * we don't get additional information if all variables in the sepSet are
     * observed, so skip message pass
     */
    private boolean needMessagePass(final AbstractFactor sepSet) {
        for (final int var : sepSet.getDimensionIDs()) {
            if (!isObserved[var]) {
                return true;
            }
        }
        return false;
    }

    private boolean isOnlyFirstLogScale(final Edge edge) {
        return nodePotentials[edge.getFirst()].isLogScale() && !nodePotentials[edge.getSecond()].isLogScale();
    }

    private boolean isOnlySecondLogScale(final Edge edge) {
        return !nodePotentials[edge.getFirst()].isLogScale() && nodePotentials[edge.getSecond()].isLogScale();
    }

    @Override
    public void setNetwork(final BayesNet net) {
        super.setNetwork(net);
        initializeFields(net.getNodes().size());
        JunctionTree jtree = buildJunctionTree(net);
        int[] homeClusters = computeHomeClusters(net, jtree.getClusters());
        initializeClusterFactors(net, jtree.getClusters(), homeClusters);
        initializeSepsetFactors(jtree.getSepSets());
        determineConcernedClusters();
        setQueryFactors();
        initializePotentialValues();
        multiplyCPTsIntoPotentials(net, homeClusters);
        prepareMultiplications();
        prepareScratch();
        invokeInitialBeliefUpdate();
        storePotentialValues();
    }
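
    /*
     * Builds the variable-to-cluster index: for every variable, the indices
     * of all clusters whose potential mentions it. Used when incorporating
     * evidence and when picking query factors.
     */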
    @SuppressWarnings("unchecked")
    private void determineConcernedClusters() {
        concernedClusters = new int[queryFactors.length][];
        List<Integer>[] temp = new List[concernedClusters.length];
        for (int i = 0; i < temp.length; i++) {
            temp[i] = new ArrayList<Integer>();
        }

        for (int i = 0; i < nodePotentials.length; i++) {
            int[] dimensionIDs = nodePotentials[i].getDimensionIDs();
            for (final int var : dimensionIDs) {
                temp[var].add(i);
            }
        }

        for (int i = 0; i < temp.length; i++) {
            concernedClusters[i] = ArrayUtils.toIntArray(temp[i]);
        }
    }

    private void initializeFields(int numNodes) {
        isBeliefValid = new boolean[beliefs.length];
        Arrays.fill(isBeliefValid, false);
        queryFactors = new AbstractFactor[numNodes];
        preparedQueries = new int[numNodes][];
        sepSets = new HashMap<Edge, AbstractFactor>();
        preparedMultiplications = new IdentityHashMap<Edge, int[]>();
        initializations = new ArrayList<Pair<AbstractFactor, IArrayWrapper>>();
        clustersHavingEvidence = new HashSet<Integer>();
        isObserved = new boolean[numNodes];
    }

    private JunctionTree buildJunctionTree(BayesNet net) {
        final JunctionTree jtree = junctionTreeBuilder.buildJunctionTree(net);
        this.junctionTree = jtree.getGraph();
        return jtree;
    }

    private int[] computeHomeClusters(BayesNet net, final List<List<Integer>> clusters) {
        int[] homeClusters = new int[net.getNodes().size()];
        for (final BayesNode node : net.getNodes()) {
            final List<Integer> nodeAndParents = getNodeAndParentIds(node);
            for (final ListIterator<List<Integer>> clusterIt = clusters.listIterator(); clusterIt.hasNext();) {
                if (clusterIt.next().containsAll(nodeAndParents)) {
                    homeClusters[node.getId()] = clusterIt.nextIndex() - 1;
                    break;
                }
            }
        }
        return homeClusters;
    }

    private List<Integer> getNodeAndParentIds(final BayesNode n) {
        final List<Integer> nodeAndParents = new ArrayList<Integer>(n.getParents().size() + 1);
        nodeAndParents.add(n.getId());
        for (final BayesNode p : n.getParents()) {
            nodeAndParents.add(p.getId());
        }
        return nodeAndParents;
    }

    private void initializeClusterFactors(BayesNet net, final List<List<Integer>> clusters, int[] homeClusters) {
        nodePotentials = new AbstractFactor[clusters.size()];
        Map<Integer, List<AbstractFactor>> multiplicationPartners = findMultiplicationPartners(net, homeClusters);
        for (final ListIterator<List<Integer>> cliqueIt = clusters.listIterator(); cliqueIt.hasNext();) {
            final List<Integer> cluster = cliqueIt.next();
            int current = cliqueIt.nextIndex() - 1;
            List<AbstractFactor> multiplicationPartnerList = multiplicationPartners.get(current);
            final AbstractFactor cliqueFactor = factory.create(cluster,
                    multiplicationPartnerList == null ? Collections.<AbstractFactor>emptyList()
                            : multiplicationPartnerList);
            nodePotentials[current] = cliqueFactor;
        }
    }
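
    /*
     * Groups the node CPTs by home cluster, i.e. the first clique that
     * contains the node and all of its parents; these are the factors later
     * multiplied into that clique's potential.
     */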
    private Map<Integer, List<AbstractFactor>> findMultiplicationPartners(BayesNet net, int[] homeClusters) {
        Map<Integer, List<AbstractFactor>> potentialMap = new HashMap<Integer, List<AbstractFactor>>();
        for (final BayesNode node : net.getNodes()) {
            final Integer nodeHome = homeClusters[node.getId()];
            if (!potentialMap.containsKey(nodeHome)) {
                potentialMap.put(nodeHome, new ArrayList<AbstractFactor>());
            }
            potentialMap.get(nodeHome).add(node.getFactor());
        }
        return potentialMap;
    }

    private void initializeSepsetFactors(final List<Pair<Edge, List<Integer>>> sepSets) {
        for (final Pair<Edge, List<Integer>> sep : sepSets) {
            this.sepSets.put(sep.getFirst(), factory.create(sep.getSecond(), Collections.<AbstractFactor>emptyList()));
        }
    }

    private void setQueryFactors() {
        for (int i = 0; i < queryFactors.length; i++) {
            for (final Integer f : concernedClusters[i]) {
                final boolean isFirstOrSmallerTable = queryFactors[i] == null
                        || queryFactors[i].getValues().length() > nodePotentials[f].getValues().length();
                if (isFirstOrSmallerTable) {
                    queryFactors[i] = nodePotentials[f];
                }
            }
        }

        queryFactorReverseMapping = new int[nodePotentials.length][];
        for (int i = 0; i < nodePotentials.length; i++) {
            List<Integer> queryVars = new ArrayList<Integer>();
            for (int var : nodePotentials[i].getDimensionIDs()) {
                if (queryFactors[var] == nodePotentials[i]) {
                    queryVars.add(var);
                }
            }
            queryFactorReverseMapping[i] = ArrayUtils.toIntArray(queryVars);
        }
    }

    private void prepareMultiplications() {
        // compress by combining equal prepared statements, thus saving memory
        final CanonicalIntArrayManager flyWeight = new CanonicalIntArrayManager();
        prepareSepsetMultiplications(flyWeight);
        prepareQueries(flyWeight);
    }

    private void prepareSepsetMultiplications(final CanonicalIntArrayManager flyWeight) {
        for (int node = 0; node < nodePotentials.length; node++) {
            for (final Edge e : junctionTree.getIncidentEdges(node)) {
                final int[] preparedMultiplication = nodePotentials[e.getSecond()]
                        .prepareMultiplication(sepSets.get(e));
                preparedMultiplications.put(e, flyWeight.getInstance(preparedMultiplication));
            }
        }
    }

    private void prepareQueries(final CanonicalIntArrayManager flyWeight) {
        for (int i = 0; i < queryFactors.length; i++) {
            final AbstractFactor beliefFactor = factory.create(Arrays.asList(i),
                    Collections.<AbstractFactor>emptyList());
            final int[] preparedQuery = queryFactors[i].prepareMultiplication(beliefFactor);
            preparedQueries[i] = flyWeight.getInstance(preparedQuery);
        }
    }

    private void prepareScratch() {
        int maxSize = 0;
        for (AbstractFactor sepSet : sepSets.values()) {
            maxSize = Math.max(maxSize, sepSet.getValues().length());
        }
        scratchpad = new double[maxSize];
    }

    private void invokeInitialBeliefUpdate() {
        collectEvidence(0, new HashSet<Integer>());
        distributeEvidence(0, new HashSet<Integer>());
    }
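
    /*
     * Fills every potential with its multiplicative identity: 1.0 in linear
     * scale, 0.0 (= log 1) in log scale. A separator keeps log scale only
     * when both adjacent clusters are log scale; otherwise it is filled in
     * linear scale.
     */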
    private void initializePotentialValues() {
        final double ONE_LOG = 0.0;
        final double ONE = 1.0;

        for (final AbstractFactor f : nodePotentials) {
            f.fill(f.isLogScale() ? ONE_LOG : ONE);
        }

        for (final Entry<Edge, AbstractFactor> sepSet : sepSets.entrySet()) {
            if (!areBothEndsLogScale(sepSet.getKey())) {
                // if one part is log-scale, we transform to non-log-scale
                sepSet.getValue().fill(ONE);
            } else {
                sepSet.getValue().fill(ONE_LOG);
            }
        }
    }

    private void multiplyCPTsIntoPotentials(BayesNet net, int[] homeClusters) {
        for (final BayesNode node : net.getNodes()) {
            final AbstractFactor nodeHome = nodePotentials[homeClusters[node.getId()]];
            if (nodeHome.isLogScale()) {
                nodeHome.multiplyCompatibleToLog(node.getFactor());
            } else {
                nodeHome.multiplyCompatible(node.getFactor());
            }
        }
    }

    private boolean areBothEndsLogScale(final Edge edge) {
        return nodePotentials[edge.getFirst()].isLogScale() && nodePotentials[edge.getSecond()].isLogScale();
    }

    private void storePotentialValues() {
        CanonicalArrayWrapperManager flyweight = new CanonicalArrayWrapperManager();
        for (final AbstractFactor pot : nodePotentials) {
            initializations.add(newPair(pot, flyweight.getInstance(pot.getValues().clone())));
        }

        for (final AbstractFactor sep : sepSets.values()) {
            initializations.add(newPair(sep, flyweight.getInstance(sep.getValues().clone())));
        }
    }
}