cc.kave.repackaged.jayes.inference.junctionTree.JunctionTreeBuilder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cc.kave.repackaged.jayes Show documentation
Show all versions of cc.kave.repackaged.jayes Show documentation
Repackaging of Jayes (Eclipse Code Recommenders) to make it available in Maven.
The newest version!
/**
* Copyright (c) 2011 Michael Kutschke.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Michael Kutschke - initial API and implementation.
*/
package cc.kave.repackaged.jayes.inference.junctionTree;
import static cc.kave.repackaged.jayes.util.Pair.newPair;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import cc.kave.repackaged.jayes.BayesNet;
import cc.kave.repackaged.jayes.BayesNode;
import cc.kave.repackaged.jayes.internal.util.UnionFind;
import cc.kave.repackaged.jayes.util.Graph;
import cc.kave.repackaged.jayes.util.Pair;
import cc.kave.repackaged.jayes.util.Graph.Edge;
import cc.kave.repackaged.jayes.util.triangulation.GraphElimination;
import cc.kave.repackaged.jayes.util.triangulation.IEliminationHeuristic;
public class JunctionTreeBuilder {
private IEliminationHeuristic heuristic;
public static JunctionTreeBuilder forHeuristic(IEliminationHeuristic heuristic) {
return new JunctionTreeBuilder(heuristic);
}
protected JunctionTreeBuilder(IEliminationHeuristic heuristic) {
this.heuristic = heuristic;
}
public JunctionTree buildJunctionTree(BayesNet net) {
JunctionTree junctionTree = new JunctionTree(new Graph());
junctionTree.setClusters(triangulateGraphAndFindCliques(buildMoralGraph(net), weightNodesByOutcomes(net),
heuristic));
junctionTree.setSepSets(computeSepsets(junctionTree, net));
return junctionTree;
}
private Graph buildMoralGraph(BayesNet net) {
Graph moral = new Graph();
moral.initialize(net.getNodes().size());
for (final BayesNode node : net.getNodes()) {
addMoralEdges(moral, node);
}
return moral;
}
private void addMoralEdges(Graph moral, final BayesNode node) {
final ListIterator it = node.getParents().listIterator();
while (it.hasNext()) {
final BayesNode parent = it.next();
final ListIterator remainingParentsIt = node.getParents().listIterator(it.nextIndex());
while (remainingParentsIt.hasNext()) { // connect parents
final BayesNode otherParent = remainingParentsIt.next();
moral.addEdge(parent.getId(), otherParent.getId());
}
moral.addEdge(node.getId(), parent.getId());
}
}
private List> triangulateGraphAndFindCliques(Graph graph, double[] weights,
IEliminationHeuristic eliminationHeuristic) {
GraphElimination triangulate = new GraphElimination(graph, weights, eliminationHeuristic);
final List> cliques = new ArrayList>();
for (List nextClique : triangulate) {
if (!containsSuperset(cliques, nextClique)) {
cliques.add(nextClique);
}
}
return cliques;
}
private double[] weightNodesByOutcomes(BayesNet net) {
double[] weights = new double[net.getNodes().size()];
for (BayesNode node : net.getNodes()) {
weights[node.getId()] = Math.log(node.getOutcomeCount());
// using these weights is the same as minimizing the resulting cluster factor size
// which is given by the product of the variable outcome counts.
}
return weights;
}
private boolean containsSuperset(final Collection extends Collection> sets, final Collection set) {
boolean isSubsetOfOther = false;
for (final Collection superset : sets) {
if (superset.containsAll(set)) {
isSubsetOfOther = true;
break;
}
}
return isSubsetOfOther;
}
private List>> computeSepsets(JunctionTree junctionTree, BayesNet net) {
final List>> candidates = enumerateCandidateSepSets(junctionTree.getClusters());
Collections.sort(candidates, new SepsetComparator(net));
return computeMaxSpanningTree(junctionTree.getGraph(), candidates);
}
private List>> enumerateCandidateSepSets(List> clusters) {
final List>> sepSets = new ArrayList>>();
final ListIterator> it = clusters.listIterator();
while (it.hasNext()) {
final List clique1 = it.next();
final ListIterator> remainingIt = clusters.listIterator(it.nextIndex());
while (remainingIt.hasNext()) { // generate sepSets
final List clique2 = new ArrayList(remainingIt.next());
clique2.retainAll(clique1);
sepSets.add(newPair(new Edge(it.nextIndex() - 1, remainingIt.nextIndex() - 1), clique2));
}
}
return sepSets;
}
private List>> computeMaxSpanningTree(Graph graph,
final List>> sortedCandidateSepSets) {
final ArrayDeque>> pq = new ArrayDeque>>(
sortedCandidateSepSets);
final int vertexCount = graph.getAdjacency().size();
final UnionFind[] sets = UnionFind.createArray(vertexCount);
final List>> leftSepSets = new ArrayList>>();
while (leftSepSets.size() < (vertexCount - 1)) {
final Pair> sep = pq.poll();
final boolean bothEndsInSameTree = sets[sep.getFirst().getFirst()].find() == sets[sep.getFirst()
.getSecond()].find();
if (!bothEndsInSameTree) {
sets[sep.getFirst().getFirst()].merge(sets[sep.getFirst().getSecond()]);
leftSepSets.add(sep);
graph.addEdge(sep.getFirst().getFirst(), sep.getFirst().getSecond());
}
}
return leftSepSets;
}
private final class SepsetComparator implements Comparator>> {
private final BayesNet net;
public SepsetComparator(BayesNet net) {
this.net = net;
}
// heuristic: choose sepSet with most variables first,
// if equal, choose the on with least table size
@Override
public int compare(final Pair> sepSet1, final Pair> sepSet2) {
final int compareNumberOfVariables = compare(sepSet1.getSecond().size(), sepSet2.getSecond().size());
if (compareNumberOfVariables != 0) {
return -compareNumberOfVariables;
}
final int tableSize1 = getTableSize(sepSet1.getSecond());
final int tableSize2 = getTableSize(sepSet2.getSecond());
return compare(tableSize1, tableSize2);
}
private int getTableSize(final List cluster) {
int tableSize = 1;
for (final int id : cluster) {
tableSize *= net.getNode(id).getOutcomeCount();
}
return tableSize;
}
private int compare(final int i1, final int i2) {
return i1 - i2;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy