All downloads are free. Search and download functionality uses the official Maven repository.

org.drools.beliefs.bayes.BayesAbsorption Maven / Gradle / Ivy

There is a newer version: 9.44.0.Final
Show newest version
/*
 * Copyright 2015 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

package org.drools.beliefs.bayes;

import java.util.Arrays;

/**
 * Updates a target potential table in place from a change in a source potential table.
 *
 * <p>{@link #absorb()} divides the current source potentials element-wise by the old source
 * potentials, multiplies every entry of the target table by the matching ratio, and finally
 * rescales the target table so that its entries sum to one.
 *
 * <p>The source table is indexed by a subset of the target variables (the code calls them
 * "sep" variables, presumably the separator of a junction tree; confirm against callers).
 * {@code srcVarPos} names those variables by their position in {@code trgVars}, and
 * {@code srcVarPosMultipliers} gives the step in the source table for each of them.
 *
 * <p>Instances hold mutable cursors and are modified by {@link #absorb()} and
 * {@link #recurse(int, int)}; nothing here synchronizes access.
 */
public class BayesAbsorption {
    // Positions, within trgVars, of the variables that index the source table. recurse()
    // only ever advances through this array, so it is assumed to be in ascending order
    // (TODO confirm against callers).
    int[]           srcVarPos;
    // srcVarPosMultipliers[k] is the amount srcPotentialIndex moves when the outcome of the
    // k-th source variable advances by one.
    int[]           srcVarPosMultipliers;
    // Current source potentials; replaced by the current/old ratio when absorb() runs.
    double[]        srcPotentials;
    // Source potentials before the update; the divisor used by absorb().
    double[]        oldSrcPotentials;
    // Cursor into srcPotentials for the outcome combination currently being visited.
    int             srcPotentialIndex;

    // Variables indexing the target table, in the order recurse() iterates them.
    BayesVariable[] trgVars;
    // Target potentials; multiplied and normalized in place.
    double[]        trgPotentials;
    // Cursor into trgPotentials; advanced by one for every visited entry.
    int             trgPotentialIndex;

    // Outcome index currently selected for each target variable. Written by recurse() but
    // not read anywhere in this class.
    int[]           path;


    /**
     * Stores the given arrays by reference (no copies are made) and allocates {@code path}
     * with one slot per target variable.
     *
     * <p>Note the parameter order: the old source potentials come before the current ones.
     *
     * @param srcVarPos            positions in {@code trgVars} of the variables indexing the source table
     * @param oldSrcPotentials     source potentials before the update (the divisor)
     * @param srcPotentials        current source potentials (the dividend)
     * @param srcVarPosMultipliers per-source-variable step in the source table
     * @param trgVars              variables indexing the target table
     * @param trgPotentials        target potentials, updated in place by {@link #absorb()}
     */
    public BayesAbsorption(int[] srcVarPos, double[] oldSrcPotentials, double[] srcPotentials, int[] srcVarPosMultipliers, BayesVariable[] trgVars, double[] trgPotentials) {
        this.srcVarPos = srcVarPos;
        this.srcVarPosMultipliers = srcVarPosMultipliers;
        this.srcPotentials = srcPotentials;
        this.oldSrcPotentials = oldSrcPotentials;

        this.trgVars = trgVars;
        this.trgPotentials = trgPotentials;
        this.path = new int[trgVars.length];
    }

    /**
     * Runs the whole update: replaces {@code srcPotentials} with the element-wise ratio
     * {@code srcPotentials / oldSrcPotentials}, multiplies each target entry by the matching
     * ratio, then normalizes the target table.
     *
     * <p>Not idempotent: every call divides {@code srcPotentials} by {@code oldSrcPotentials}
     * again and multiplies {@code trgPotentials} again.
     */
    public void absorb() {
        srcPotentials = dividePotentials(srcPotentials, oldSrcPotentials);

        // Both cursors start at the first entry; recurse() keeps them in step.
        trgPotentialIndex = 0;
        srcPotentialIndex = 0;
        recurse(0, 0);
        normalize(trgPotentials);
    }

    /**
     * Rescales {@code array} in place so that its elements sum to one.
     *
     * <p>When the elements sum to exactly zero (for instance an all-zero table) the array is
     * left untouched; dividing would otherwise turn every zero entry into {@code NaN}. This
     * matches {@link #dividePotentials(double[], double[])}, which also maps a division by a
     * zero potential to zero.
     *
     * @param array the values to normalize; modified in place
     */
    public static void normalize(double[] array) {
        double sum = 0.0;
        for (double value : array) {
            sum += value;
        }

        if (sum == 0.0) {
            // Nothing to scale by: keep the entries as they are rather than produce NaN.
            return;
        }

        for (int i = 0; i < array.length; i++) {
            array[i] /= sum;
        }
    }

    /**
     * Returns a new array holding {@code potentials[i] / oldPotentials[i]} for every index of
     * {@code potentials}, using zero wherever the old potential is exactly zero. Neither
     * argument is modified.
     *
     * @param potentials    the dividends; determines the length of the result
     * @param oldPotentials the divisors; must have at least as many elements as {@code potentials}
     * @return a freshly allocated array of ratios
     * @throws ArrayIndexOutOfBoundsException if {@code oldPotentials} is shorter than {@code potentials}
     */
    public static double[] dividePotentials(double[] potentials, double[] oldPotentials) {
        double[] absorbedPotentials = Arrays.copyOf(potentials, potentials.length);
        for (int i = 0; i < absorbedPotentials.length; i++) {
            double oldP = oldPotentials[i];
            if (oldP == 0.0) {
                // Division by a zero potential is defined as zero here.
                absorbedPotentials[i] = 0.0;
            } else {
                absorbedPotentials[i] /= oldP;
            }
        }
        return absorbedPotentials;
    }


    /**
     * Visits every outcome combination of the target variables from {@code currentVar}
     * onwards and multiplies the corresponding target entry by the source entry that
     * {@code srcPotentialIndex} points at.
     *
     * <p>Target entries are visited in index order, with the last target variable's outcome
     * changing fastest. {@code srcPotentialIndex} only moves when the variable being iterated
     * is one of the source variables, and it is restored to its entry value before returning.
     *
     * @param currentVar   position in {@code trgVars} of the variable to iterate at this level
     * @param targetVarPos position in {@code srcVarPos} of the next source variable not yet matched
     */
    public void recurse(int currentVar, int targetVarPos) {
        int numberOfOutcomes = trgVars[currentVar].getOutcomes().length;

        // The current target variable also indexes the source table when it is the next
        // unmatched entry of srcVarPos.
        boolean isSepVar = this.srcVarPos.length > 0
                           && targetVarPos < this.srcVarPos.length
                           && this.srcVarPos[targetVarPos] == currentVar;
        int nextParentKeyPos = isSepVar ? targetVarPos + 1 : targetVarPos;
        // Step applied to the source cursor per outcome; zero when this variable does not
        // index the source table, so the cursor then stays where it is.
        int srcStep = isSepVar ? srcVarPosMultipliers[targetVarPos] : 0;

        for (int j = 0; j < numberOfOutcomes; j++) {
            path[currentVar] = j;

            if (currentVar < trgVars.length - 1) {
                recurse(currentVar + 1, nextParentKeyPos);
            } else {
                // Innermost variable: consume one target entry.
                trgPotentials[trgPotentialIndex++] *= srcPotentials[srcPotentialIndex];
            }
            if (isSepVar) {
                srcPotentialIndex += srcStep;
            }
        }
        if (isSepVar) {
            // Rewind the source cursor so the caller's level resumes from where it was.
            srcPotentialIndex -= srcStep * numberOfOutcomes;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy