All downloads are free. Search and download functionality uses the official Maven repository.

org.drools.beliefs.bayes.BayesLikelyhood Maven / Gradle / Ivy

There is a newer version: 9.44.0.Final
Show newest version
/*
 * Copyright 2015 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

package org.drools.beliefs.bayes;

import org.drools.beliefs.graph.Graph;
import org.drools.beliefs.graph.GraphNode;

import java.util.Arrays;

/**
 * Pairs a {@link BayesVariable} with a distribution array and precomputes the
 * index data that {@link PotentialMultiplier} needs in order to multiply that
 * distribution into a target potential array (see {@link #multiplyInto(double[])}).
 *
 * <p>Instances are not reassigned after construction, but the arrays they hold are
 * shared, not copied: the {@code distribution} passed to the constructor is stored
 * as-is, is returned as-is by {@link #getDistribution()}, and is referenced by every
 * row of the internal potential table. Mutating it from outside therefore changes
 * this object's state (including its {@link #hashCode()}).
 *
 * <p>Equality and hash code are based on the variable and the distribution contents only.
 */
public class BayesLikelyhood {
    private final BayesVariable variable;
    private final double[]      distribution;
    private final double[][]    varPotential;

    private final BayesVariable[] vars;
    private final int[]           multipliers;
    private final int[]           parentVarPos;
    private final int[]           parentIndexMultipliers;
    private final int             varPos;

    /**
     * Builds the likelihood for the variable held by {@code varNode} within the
     * variables of {@code jtNode}.
     *
     * @param graph        not used by this constructor; kept so existing callers compile unchanged
     * @param jtNode       clique whose values (variables) define the index space of the target potential
     * @param varNode      graph node whose content is the variable this likelihood applies to;
     *                     the out-nodes of its in-edges are treated as the variable's parents
     * @param distribution distribution to apply; stored by reference, not copied
     * @throws IllegalStateException if the variable of {@code varNode} is not (by identity)
     *                               one of the values of {@code jtNode}
     */
    public BayesLikelyhood(Graph graph, JunctionTreeClique jtNode, GraphNode varNode, double[] distribution) {
        // Materialised once and reused below; the array handed to PotentialMultiplier
        // later is the same one the multipliers are derived from.
        this.vars = jtNode.getValues().toArray( new BayesVariable[jtNode.getValues().size()] );

        // Explicit cast, consistent with how parent contents are read below.
        this.variable = (BayesVariable) varNode.getContent();
        this.distribution = distribution;

        BayesVariable[] parents = new BayesVariable[varNode.getInEdges().size()];
        for ( int i = 0; i < parents.length; i++ ) {
            parents[i] = (BayesVariable) varNode.getInEdges().get(i).getOutGraphNode().getContent();
        }

        // One row per row of the variable's probability table; every row aliases
        // the same distribution array (no copy is made).
        this.varPotential = new double[ this.variable.getProbabilityTable().length ][];
        Arrays.fill( this.varPotential, distribution );

        int numberOfStates = PotentialMultiplier.createNumberOfStates( this.vars );
        this.multipliers = PotentialMultiplier.createIndexMultipliers( this.vars, numberOfStates );

        this.parentVarPos = PotentialMultiplier.createSubsetVarPos( this.vars, parents );
        int parentsNumberOfStates = PotentialMultiplier.createNumberOfStates( parents );
        this.parentIndexMultipliers = PotentialMultiplier.createIndexMultipliers( parents, parentsNumberOfStates );

        int position = indexOfSame( this.vars, this.variable );
        if ( position == -1 ) {
            throw new IllegalStateException( "Unable to find Variable in set" );
        }
        this.varPos = position;
    }

    /**
     * Returns the index of the first element of {@code candidates} that is the very same
     * object as {@code target} (reference comparison, not {@code equals}), or {@code -1}
     * when there is none.
     */
    private static int indexOfSame(BayesVariable[] candidates, BayesVariable target) {
        for ( int i = 0; i < candidates.length; i++ ) {
            if ( candidates[i] == target ) {
                return i;
            }
        }
        return -1;
    }

    /**
     * @return the variable this likelihood applies to
     */
    public BayesVariable getVariable() {
        return variable;
    }

    /**
     * @return the distribution array supplied at construction (the same instance, not a copy)
     */
    public double[] getDistribution() {
        return distribution;
    }

    /**
     * Multiplies this likelihood into {@code trgPotential} by delegating to a
     * {@link PotentialMultiplier} configured with the data prepared in the constructor.
     *
     * @param trgPotential target potential array handed to the multiplier
     */
    public void multiplyInto(double[] trgPotential) {
        PotentialMultiplier m = new PotentialMultiplier(varPotential, varPos, parentVarPos, parentIndexMultipliers,
                                                        vars, multipliers, trgPotential);

        m.multiple();
    }

    /**
     * Two likelihoods are equal when they are of the same class, their distributions
     * have equal contents, and their variables are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) { return true; }
        if (o == null || getClass() != o.getClass()) { return false; }

        BayesLikelyhood that = (BayesLikelyhood) o;

        return Arrays.equals(distribution, that.distribution)
               && variable.equals(that.variable);
    }

    /**
     * Hash code derived from the variable and the distribution contents, consistent
     * with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int result = variable.hashCode();
        result = 31 * result + Arrays.hashCode(distribution);
        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy