org.drools.beliefs.bayes.SeparatorSet Maven / Gradle / Ivy
/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.beliefs.bayes;
import org.drools.beliefs.graph.Graph;
import org.drools.beliefs.graph.GraphNode;
import org.drools.core.util.bitmask.OpenBitSet;
public class SeparatorSet implements Comparable {
private int id1;
private OpenBitSet clique1;
private int id2;
private OpenBitSet clique2;
private OpenBitSet intersection;
private int mass; // number of vertices that are in the intersects of clique1 and clique2
private int cost; // product of the weight of the vertices that are in the intersects of clique1 and clique2
public SeparatorSet(OpenBitSet clique1, int id1, OpenBitSet clique2, int id2, Graph graph) {
this.id1 = id1;
this.clique1 = clique1;
this.clique2 = clique2;
this.id2 = id2;
intersection = (OpenBitSet) clique1.clone();
intersection.and(clique2);
mass = (int) intersection.cardinality();
cost = 1;
if (mass > 0) {
for (int i = intersection.nextSetBit(0); i >= 0; i = intersection.nextSetBit(i + 1)) {
GraphNode v = graph.getNode(i);
cost *= Math.abs(v.getContent().getOutcomes().length);
}
}
}
public int getId1() {
return id1;
}
public OpenBitSet getClique1() {
return clique1;
}
public OpenBitSet getClique2() {
return clique2;
}
public int getId2() {
return id2;
}
public OpenBitSet getIntersection() {
return intersection;
}
public int getMass() {
return mass;
}
public int getCost() {
return cost;
}
@Override
public int compareTo(SeparatorSet o) {
if (this == o) { return 0; }
if (mass != o.mass) {
return o.mass - mass;
} else if (cost != o.cost) {
return o.cost - cost;
} else {
// int j = o.intersection.nextSetBit(0);
// for (int i = intersection.nextSetBit(0); i >= 0; i = intersection.nextSetBit(i + 1)) {
// if (i != j) {
// return i - j;
// }
// j = o.intersection.nextSetBit(j + 1);
//
// }
// while they are the same, we want the arbitrary result to be deterministic
// in this case iterate the pairs and return the one that has the fist bit set and the other doesn't.
int j = o.clique1.nextSetBit(0);
for (int i = clique1.nextSetBit(0); i >= 0; i = clique1.nextSetBit(i + 1)) {
if (i != j) {
return i - j;
}
j = o.clique1.nextSetBit(j + 1);
}
j = o.clique2.nextSetBit(0);
for (int i = clique2.nextSetBit(0); i >= 0; i = clique2.nextSetBit(i + 1)) {
if (i != j) {
return i - j;
}
j = o.clique2.nextSetBit(j + 1);
}
}
return 0; // the two pairs of cliques are the same
}
@Override
public boolean equals(Object o) {
if (this == o) { return true; }
if (o == null || getClass() != o.getClass()) { return false; }
SeparatorSet separatorSet = (SeparatorSet) o;
if (cost != separatorSet.cost) { return false; }
if (mass != separatorSet.mass) { return false; }
if (!clique1.equals(separatorSet.clique1)) { return false; }
if (!clique2.equals(separatorSet.clique2)) { return false; }
if (!intersection.equals(separatorSet.intersection)) { return false; }
return true;
}
@Override
public int hashCode() {
int result = clique1.hashCode();
result = 31 * result + clique2.hashCode();
result = 31 * result + intersection.hashCode();
result = 31 * result + mass;
result = 31 * result + cost;
return result;
}
@Override
public String toString() {
return "SepSet{" +
"clique1=" + clique1 +
", clique2=" + clique2 +
", intersection=" + intersection +
", mass=" + mass +
", cost=" + cost +
'}';
}
}