org.drools.beliefs.bayes.BayesAbsorption Maven / Gradle / Ivy
/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.beliefs.bayes;
import java.util.Arrays;
public class BayesAbsorption {
int[] srcVarPos;
int[] srcVarPosMultipliers;
double[] srcPotentials;
double[] oldSrcPotentials;
int srcPotentialIndex;
BayesVariable[] trgVars;
double[] trgPotentials;
int trgPotentialIndex;
int[] path;
public BayesAbsorption(int[] srcVarPos, double[] oldSrcPotentials, double[] srcPotentials, int[] srcVarPosMultipliers, BayesVariable[] trgVars, double[] trgPotentials) {
this.srcVarPos = srcVarPos;
this.srcVarPosMultipliers = srcVarPosMultipliers;
this.srcPotentials = srcPotentials;
this.oldSrcPotentials = oldSrcPotentials;
this.trgVars = trgVars;
this.trgPotentials = trgPotentials;
this.path = new int[trgVars.length];
}
public void absorb() {
srcPotentials = dividePotentials(srcPotentials, oldSrcPotentials);
trgPotentialIndex = 0;
srcPotentialIndex = 0;
recurse(0, 0);
normalize(trgPotentials);
}
public static void normalize(double[] array) {
double sum = 0.0;
for ( int i = 0, length = array.length; i < length; i++ ) {
sum += array[i];
}
for ( int i = 0, length = array.length; i < length; i++ ) {
array[i] = array[i]/sum;
}
}
public static double[] dividePotentials(double[] potentials, double[] oldPotentials) {
double[] absorbedPotentials = Arrays.copyOf( potentials, potentials.length);
for ( int i = 0; i < potentials.length; i++ ) {
double oldP = oldPotentials[i];
if (oldP == 0.0) {
absorbedPotentials[i] = 0.0;
//} else if (Math.abs(Yvalue) < 1e-20) { // is this necessary?
} else {
absorbedPotentials[i] /= oldP;
}
}
return absorbedPotentials;
}
public void recurse(int currentVar, int targetVarPos) {
int numberOfOutcomes = trgVars[currentVar].getOutcomes().length;
boolean isSepVar = false;
int nextParentKeyPos = targetVarPos;
if (this.srcVarPos.length > 0 && targetVarPos < this.srcVarPos.length && this.srcVarPos[targetVarPos] == currentVar) {
nextParentKeyPos++;
isSepVar = true;
}
for (int j = 0; j < numberOfOutcomes; j++) {
path[currentVar] = j;
if (currentVar < trgVars.length - 1) {
recurse(currentVar + 1, nextParentKeyPos);
} else {
//sum += srcPotentials[srcPotentialIndex];
trgPotentials[trgPotentialIndex++] *= srcPotentials[srcPotentialIndex];
}
if ( isSepVar ) {
srcPotentialIndex += srcVarPosMultipliers[targetVarPos];
}
}
if ( isSepVar ) {
srcPotentialIndex -= (srcVarPosMultipliers[targetVarPos] * numberOfOutcomes );
}
}
}