All downloads are free. The search and download functionality uses the official Maven repository.

org.integratedmodelling.engine.modelling.bayes.BayesianActuator Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (C) 2007, 2015:
 * 
 * - Ferdinando Villa  - integratedmodelling.org - any
 * other authors listed in @author annotations
 *
 * All rights reserved. This file is part of the k.LAB software suite, meant to enable
 * modular, collaborative, integrated development of interoperable data and model
 * components. For details, see http://integratedmodelling.org.
 * 
 * This program is free software; you can redistribute it and/or modify it under the terms
 * of the Affero General Public License Version 3 or any later version.
 *
 * This program is distributed in the hope that it will be useful, but without any
 * warranty; without even the implied warranty of merchantability or fitness for a
 * particular purpose. See the Affero General Public License for more details.
 * 
 * You should have received a copy of the Affero General Public License along with this
 * program; if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite
 * 330, Boston, MA 02111-1307, USA. The license is also available at:
 * https://www.gnu.org/licenses/agpl.html
 *******************************************************************************/
package org.integratedmodelling.engine.modelling.bayes;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.integratedmodelling.api.data.IProbabilityDistribution;
import org.integratedmodelling.api.knowledge.IConcept;
import org.integratedmodelling.api.knowledge.IKnowledge;
import org.integratedmodelling.api.knowledge.IObservation;
import org.integratedmodelling.api.modelling.IActiveDirectObservation;
import org.integratedmodelling.api.modelling.IClassification;
import org.integratedmodelling.api.modelling.IClassifyingObserver;
import org.integratedmodelling.api.modelling.IConditionalObserver;
import org.integratedmodelling.api.modelling.IMediatingObserver;
import org.integratedmodelling.api.modelling.IModel;
import org.integratedmodelling.api.modelling.INumericObserver;
import org.integratedmodelling.api.modelling.IObservableSemantics;
import org.integratedmodelling.api.modelling.IObserver;
import org.integratedmodelling.api.modelling.IPresenceObserver;
import org.integratedmodelling.api.modelling.IProbabilityObserver;
import org.integratedmodelling.api.modelling.IUncertaintyObserver;
import org.integratedmodelling.api.modelling.IValueResolver;
import org.integratedmodelling.api.modelling.resolution.IResolutionScope;
import org.integratedmodelling.api.modelling.scheduling.ITransition;
import org.integratedmodelling.api.monitoring.IMonitor;
import org.integratedmodelling.api.project.IProject;
import org.integratedmodelling.api.services.annotations.Prototype;
import org.integratedmodelling.collections.Pair;
import org.integratedmodelling.collections.Triple;
import org.integratedmodelling.common.data.IndexedCategoricalDistribution;
import org.integratedmodelling.common.model.runtime.AbstractStateContextualizer;
import org.integratedmodelling.common.utils.CamelCase;
import org.integratedmodelling.common.vocabulary.NS;
import org.integratedmodelling.common.vocabulary.ObservableSemantics;
import org.integratedmodelling.common.vocabulary.Observables;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabRuntimeException;
import org.integratedmodelling.exceptions.KlabValidationException;

import com.google.common.collect.Sets;

@Prototype(
        id = "bayesian",
        args = { "import", Prototype.TEXT, "# method", Prototype.TEXT },
        returnTypes = {
                NS.STATE_CONTEXTUALIZER })
public class BayesianActuator extends AbstractStateContextualizer
        implements IValueResolver {

    IBayesianNetwork                                network;
    String                                          importFile;
    File                                            workspace;
    IBayesianInference                              inference;

    HashSet                                 nodeIds                 = new HashSet();
    HashMap                         key2node                = new HashMap();
    HashMap                     concept2node            = new HashMap();
    HashMap>                   outputKeys              = new HashMap>();
    HashMap                outputClassifications   = new HashMap();
    // keep keys of presence/absence observers with the node ID and the outcome
    // IDs
    // corresponding to true and
    // false.
    HashMap> presenceKeys            = new HashMap>();
    Set                                     warningKeys             = new HashSet();
    Set                                     typewarnKeys            = new HashSet();

    /*
     * if the observer is a probability linked to a specific outcome, these are set.
     */
    String                                          probabilityNode         = null;
    String                                          probabilityOutcome      = null;
    int                                             probabilityOutcomeIndex = -1;

    /**
     * Bookkeeping for one uncertainty output: the observable of the uncertainty, the
     * knowledge it refers to, and the key of the network node that knowledge maps to.
     */
    class UncertaintyDesc {

        // key of the corresponding BN node; stays null until it is resolved
        String               nodeKey;
        // observable of the uncertainty itself
        IObservableSemantics observable;
        // the knowledge whose uncertainty is being measured
        IKnowledge           observed;

        public UncertaintyDesc(IObservableSemantics observable, IKnowledge inherentType) {
            this.observable = observable;
            this.observed = inherentType;
        }
    }

    // matches the key we use for the uncertainty state to the descriptor of the
    // node we're measuring the uncertainty for.
    HashMap<String, UncertaintyDesc> _uncertainties = new HashMap<>();

    // set by resolveUncertaintyRefs(); run() only resolves while this is false
    boolean                          resolved       = false;

    /**
     * No-argument constructor; passes a null monitor to the superclass. The import
     * file and workspace are left unset here and are assigned by setContext().
     */
    public BayesianActuator() {
        super(null);
    }

    /**
     * Creates an actuator for a given network file.
     *
     * @param importFile name of the network file, resolved against the workspace
     * @param workspace directory containing the network file
     * @param monitor monitor handed to the superclass
     * @throws KlabException declared for callers; not thrown by this body
     */
    public BayesianActuator(String importFile, File workspace, IMonitor monitor)
            throws KlabException {
        super(monitor);
        this.workspace = workspace;
        this.importFile = importFile;
    }

    /**
     * Reads the configuration of the actuator: the network file name comes from the
     * "import" parameter and the workspace from the load path of the project.
     *
     * @param parameters call parameters; must contain "import"
     * @param model not used here
     * @param project the project whose load path becomes the workspace
     * @throws KlabRuntimeException if the "import" parameter is missing
     */
    @Override
    public void setContext(Map<String, Object> parameters, IModel model, IProject project) {
        Object importSpec = parameters.get("import");
        if (importSpec == null) {
            // fail with a readable message instead of a bare NullPointerException
            throw new KlabRuntimeException("bayesian: the 'import' parameter is required");
        }
        this.importFile = importSpec.toString();
        this.workspace = project.getLoadPath();
    }

    /**
     * Loads the network from the import file in the workspace, caches its node IDs and
     * its inference engine, lets the superclass define the states, and then registers
     * every expected input and output against the network nodes.
     *
     * @return the states returned by the superclass
     * @throws KlabException from network creation, the superclass, or input/output
     *         registration
     */
    @Override
    public Map<String, IObservation> define(String name, IObserver observer,
            IActiveDirectObservation contextSubject, IResolutionScope context,
            Map<String, IObservableSemantics> expectedInputs,
            Map<String, IObservableSemantics> expectedOutputs,
            boolean isLastInChain, IMonitor monitor)
            throws KlabException {

        this.network = BayesianFactory.get()
                .createBayesianNetwork(workspace + File.separator + importFile);
        this.nodeIds = Sets.newHashSet(this.network.getAllNodeIds());
        this.inference = this.network.getInference();

        Map<String, IObservation> states = super
                .define(name, observer, contextSubject, context, expectedInputs, expectedOutputs, isLastInChain, monitor);

        for (Map.Entry<String, IObservableSemantics> input : expectedInputs.entrySet()) {
            notifyInput(input.getValue(), input.getValue().getObserver(), input.getKey());
        }

        for (Map.Entry<String, IObservableSemantics> output : expectedOutputs.entrySet()) {
            // the output whose key equals the name passed to define() is flagged as main
            notifyOutput(output.getValue(), output.getValue().getObserver(), output
                    .getKey(), output.getKey().equals(name));
        }

        return states;
    }

    /**
     * Registers one input of the model against the network.
     *
     * Presence observers are matched to a node and recorded in presenceKeys together
     * with the outcomes whose IDs end in "Present" and "Absent". Any other observer
     * must provide a classification; its node is looked up (which records the match)
     * and every outcome of the node is checked against the local names of the
     * classification's concepts. Problems are reported through the monitor.
     *
     * @param observable the observable of the input
     * @param observer the observer of the input
     * @param key the key the input is known by
     */
    public void notifyInput(IObservableSemantics observable, IObserver observer, String key)
            throws KlabException {

        if (observer instanceof IPresenceObserver) {

            // presence/absence gets special treatment: we only need the two outcomes
            String presenceNode = findMatchingNodeID(observable, key);
            if (presenceNode == null) {
                return;
            }

            String presentId = null;
            String absentId = null;
            for (String outcome : this.network.getOutcomeIds(presenceNode)) {
                if (outcome.endsWith("Present")) {
                    presentId = outcome;
                }
                if (outcome.endsWith("Absent")) {
                    absentId = outcome;
                }
            }

            if (presentId != null && absentId != null) {
                presenceKeys
                        .put(key, new Triple<String, String, String>(presenceNode, presentId, absentId));
            } else {
                monitor.error("cannot establish outcomes for presence/absence of "
                        + presenceNode);
            }
            return;
        }

        // anything else must be discretized in some way
        IClassification classif = getClassification(observer);
        if (classif == null) {
            monitor.error("cannot obtain discretized values from observation of "
                    + observable);
            return;
        }

        // the lookup also records the match; TODO if not found, ignore for now -
        // will want to warn later
        String nodeId = findMatchingNodeID(observable, key);
        if (nodeId == null) {
            return;
        }

        // every outcome ID of the node must be the local name of a concept
        StringBuilder unmatched = new StringBuilder();
        for (String outcome : this.network.getOutcomeIds(nodeId)) {
            boolean found = false;
            for (IConcept concept : classif.getConceptOrder()) {
                if (concept.getLocalName().equals(outcome)) {
                    found = true;
                    break;
                }
            }
            if (found) {
                continue;
            }
            if (unmatched.length() > 0) {
                unmatched.append(", ");
            }
            unmatched.append(outcome);
        }

        if (unmatched.length() > 0) {
            monitor.error("cannot match subclasses of " + observable.getType()
                    + " to outcomes of bayesian node "
                    + nodeId + ": " + unmatched);
        }
    }

    /**
     * Registers one output of the model against the network.
     *
     * Uncertainty observers are only recorded; they are matched to nodes later, when
     * all nodes are known. Probability observers select the node named after the event
     * type and its "present" outcome. Any other observer must provide a
     * classification, whose concepts (by local name) must all be outcomes of the
     * matching node.
     *
     * @param observable the observable of the output
     * @param observer the observer of the output
     * @param key the key the output is known by
     * @param isMain not used by this implementation
     * @throws KlabValidationException if a probability observer has no event type
     */
    public void notifyOutput(IObservableSemantics observable, IObserver observer, String key, boolean isMain)
            throws KlabException {

        if (observer instanceof IUncertaintyObserver) {
            /*
             * prepare to handle uncertainty; we match concepts to nodes when we have all
             * nodes.
             */
            _uncertainties
                    .put(key, new UncertaintyDesc(observable, ((IUncertaintyObserver) observer)
                            .getOriginalConcept()));
            return;
        }

        if (observer instanceof IProbabilityObserver) {

            /*
             * must find the node and the outcome. The node is the one named after the
             * event (as is, or lowercase); the outcome is the one named "present" or
             * event name + "present", ignoring case. Looking up other outcomes from
             * the concept description (e.g. HighTemperature) is not implemented.
             */
            IConcept event = ((IProbabilityObserver) observer).getEventType();
            if (event == null) {
                throw new KlabValidationException("invalid probability semantics: cannot establish event type");
            }

            String eventName = event.getLocalName();
            String pnode = null;
            if (nodeIds.contains(eventName)) {
                pnode = eventName;
            } else if (nodeIds.contains(eventName.toLowerCase())) {
                pnode = eventName.toLowerCase();
            }

            if (pnode == null) {
                // previously silent: without a node, run() produces nothing for this key
                monitor.warn("bayesian: cannot find a node to match probability event "
                        + eventName + " (" + key
                        + "): output will not be computed by Bayesian model");
                return;
            }

            String[] ids = network.getOutcomeIds(pnode);
            for (int i = 0; i < ids.length; i++) {
                // equalsIgnoreCase does not depend on the default locale
                if (ids[i].equalsIgnoreCase("present")
                        || ids[i].equalsIgnoreCase(eventName + "present")) {
                    this.probabilityNode = pnode;
                    this.probabilityOutcome = ids[i];
                    this.probabilityOutcomeIndex = i;
                    return;
                }
            }

            // previously silent as well
            monitor.warn("bayesian: node " + pnode
                    + " has no 'present' outcome for probability event " + eventName
                    + " (" + key + "): output will not be computed by Bayesian model");
            return;
        }

        /*
         * find node matching observable. TODO handle classifications by trait
         * properly.
         */
        IClassification cls = getClassification(observer);
        if (cls == null) {
            // observed values are not discretized: skipped without a message
            return;
        }

        String nodeId = findMatchingNodeID(observable, key);
        if (nodeId == null) {
            monitor.warn("bayesian: cannot find a node to match output "
                    + cls.getConceptSpace() + " (" + key
                    + "): output will not be computed by Bayesian model");
            return;
        }

        /*
         * store classification as key to generate distributions later
         */
        outputClassifications.put(nodeId, cls);

        /*
         * build ID order to interpret classification into outcomes
         */
        List<String> outcomeOrder = new ArrayList<>();
        StringBuilder missing = new StringBuilder();
        Set<String> outcomes = Sets.newHashSet(this.network.getOutcomeIds(nodeId));
        for (IConcept c : cls.getConceptOrder()) {
            String id = c.getLocalName();
            if (!outcomes.contains(id)) {
                if (missing.length() > 0) {
                    missing.append(", ");
                }
                missing.append(id);
            }
            outcomeOrder.add(id);
        }

        if (missing.length() > 0) {
            monitor.error("bayesian: cannot find outcome(s): " + missing + " in node "
                    + nodeId
                    + " to match observable " + observable.getLocalName() + " (" + key
                    + ")");
            return;
        }

        /*
         * record key for evidence matching
         */
        outputKeys.put(nodeId, outcomeOrder);
    }

    /**
     * Finds the network node that corresponds to an observable and records the match
     * in key2node and concept2node. Candidates are tried in this order: the key
     * itself, the key converted to upper camel case, the formal name of the
     * observable, the identifiers of its local type, the identifiers of its type.
     *
     * @param observable the observable to match
     * @param key the key of the input or output; may be null
     * @return the ID of the matching node, or null if there is none
     */
    private String findMatchingNodeID(IObservableSemantics observable, String key) {

        String ret = null;

        if (key != null) {
            if (nodeIds.contains(key)) {
                ret = key;
            } else {
                String humpKey = CamelCase.toUpperCamelCase(key, '-');
                if (nodeIds.contains(humpKey)) {
                    ret = humpKey;
                }
            }
        }

        if (ret == null) {
            // the camel-cased key was already tried above, so only the formal name
            // is left to check here; a null key is never converted
            String formalName = observable.getFormalName();
            if (formalName != null && nodeIds.contains(formalName)) {
                ret = formalName;
            }
        }

        // getLocalType() is only reachable through the implementation class
        if (ret == null && observable instanceof ObservableSemantics
                && ((ObservableSemantics) observable).getLocalType() != null) {
            ret = firstKnownNodeId(Observables
                    .getIdentifiersFor(((ObservableSemantics) observable).getLocalType()));
        }

        if (ret == null) {
            ret = firstKnownNodeId(Observables.getIdentifiersFor(observable.getType()));
        }

        if (ret != null) {
            key2node.put(key, ret);
            concept2node.put(observable.getType(), ret);
        }

        return ret;
    }

    /*
     * returns the first of the candidate IDs that is the ID of a node, or null.
     */
    private String firstKnownNodeId(Iterable<String> candidateIds) {
        if (candidateIds == null) {
            return null;
        }
        for (String id : candidateIds) {
            if (nodeIds.contains(id)) {
                return id;
            }
        }
        return null;
    }

    /**
     * Extracts the classification that discretizes the values of an observer: the
     * classification of a classifying observer, the discretization of a numeric one,
     * or the classification of the first model's observer for a conditional one. When
     * none of these yields a result and the observer mediates another one, the
     * mediated observer is tried.
     *
     * @return the classification, or null if none can be found
     */
    private IClassification getClassification(IObserver observer) {

        IClassification found;

        if (observer instanceof IClassifyingObserver) {
            found = ((IClassifyingObserver) observer).getClassification();
        } else if (observer instanceof INumericObserver) {
            found = ((INumericObserver) observer).getDiscretization();
        } else if (observer instanceof IConditionalObserver) {
            /*
             * classifications must be the same if present, so just get the first
             * FIXME/CHECK: should use getRepresentativeObserver although I don't think it
             * makes a difference now. No test case so postponing.
             */
            found = getClassification(((IConditionalObserver) observer).getModels().get(0)
                    .getFirst().getObserver());
        } else {
            found = null;
        }

        if (found != null) {
            return found;
        }

        if (observer instanceof IMediatingObserver) {
            return getClassification(((IMediatingObserver) observer).getMediatedObserver());
        }

        return null;
    }

    /*
     * called by run() while the resolved flag is false. Gives each uncertainty
     * descriptor the node of the first known concept that its observed knowledge
     * "is"; throws if there is no such concept.
     */
    private void resolveUncertaintyRefs() throws KlabValidationException {

        // the flag is raised before resolving, so a failure is not retried
        resolved = true;

        for (UncertaintyDesc desc : _uncertainties.values()) {

            for (Map.Entry<IKnowledge, String> mapping : concept2node.entrySet()) {
                if (desc.observed.is(mapping.getKey())) {
                    desc.nodeKey = mapping.getValue();
                    break;
                }
            }

            if (desc.nodeKey == null) {
                throw new KlabValidationException("cannot find concept " + desc.observed
                        + " for uncertainty computation");
            }
        }
    }

    public Map run(Map inputs, ITransition transition)
            throws KlabException {

        ArrayList> evidence = new ArrayList>();

        if (!resolved) {
            resolveUncertaintyRefs();
        }

        this.inference.clearEvidence();

        Map ret = new HashMap<>();

        for (String inputKey : getInputKeys()) {

            Object value = inputs.get(inputKey);

            if (presenceKeys.containsKey(inputKey)) {

                if (!(value instanceof Boolean) && value != null) {
                    throw new KlabRuntimeException("internal: presence value not a boolean for "
                            + inputKey);
                }

                Triple ik = presenceKeys.get(inputKey);
                if (value != null) {
                    evidence.add(new Pair(ik.getFirst(), (Boolean) value
                            ? ik.getSecond() : ik.getThird()));
                }

                continue;
            }

            String nodeId = key2node.get(inputKey);

            if (nodeId == null && !warningKeys.contains(inputKey)) {
                monitor.warn("model dependency " + inputKey
                        + " cannot be matched to any Bayesian node");
                warningKeys.add(inputKey);
            }

            if (value != null && !(value instanceof IConcept)) {

                if (value instanceof IProbabilityDistribution) {
                    IObserver observer = getInputObservers().get(inputKey);
                    IClassification classification = getClassification(observer);
                    IProbabilityDistribution distribution = (IProbabilityDistribution) value;
                    if (classification.getConceptOrder()
                            .size() == distribution.getData().length) {
                        if (!typewarnKeys.contains(inputKey)) {
                            monitor.warn("input " + inputKey
                                    + " is probabilistic: using most likely category in input as evidence");
                            typewarnKeys.add(inputKey);
                        }
                        IConcept c = classification.getConceptOrder().get(0);
                        double max = distribution.getData()[0];
                        for (int i = 0; i < distribution.getData().length; i++) {
                            if (distribution.getData()[i] > max) {
                                c = classification.getConceptOrder().get(i);
                            }
                        }
                        value = c;
                    } else {
                        if (!typewarnKeys.contains(inputKey)) {
                            monitor.error("incompatible probabilistic input for node "
                                    + inputKey);
                            typewarnKeys.add(inputKey);
                        }
                    }
                }

                if (!typewarnKeys.contains(inputKey)) {
                    monitor.warn("ignoring non-categorical value for " + inputKey
                            + " as Bayesian evidence");
                    typewarnKeys.add(inputKey);
                }
            }

            /*
             * silent if value is not a concept for any reason, but behaves nicely on
             * nodata.
             */
            if (nodeId != null && value instanceof IConcept) {
                evidence.add(new Pair(nodeId, ((IConcept) value)
                        .getLocalName()));
            }
        }

        /*
         * submit evidence
         */
        for (Pair zio : evidence) {
            this.inference.setEvidence(zio.getFirst(), zio.getSecond());
        }

        /*
         * run inference
         */
        this.inference.run();

        /*
         * put values back
         */
        for (String outputKey : getOutputKeys()) {

            String nodeId = null;
            boolean isUncertainty = false;

            if (_uncertainties.containsKey(outputKey)) {
                nodeId = _uncertainties.get(outputKey).nodeKey;
                isUncertainty = true;
            } else {
                nodeId = probabilityNode == null ? this.key2node.get(outputKey) : probabilityNode;
            }

            if (probabilityNode != null) {
                
                ret.put(outputKey, this.inference.getMarginal(probabilityNode, probabilityOutcome));
                
            } else {
            
            List okey = this.outputKeys.get(nodeId);

            /*
             * happens after initialization gave errors - which should interrupt the
             * process, but if not (e.g. parallelizing), this avoids a NPE.
             */
            if (nodeId == null || okey == null) {
                continue;
            }

            double[] data = new double[okey.size()];
            int i = 0;
            for (String outcome : okey) {
                data[i++] = this.inference.getMarginal(nodeId, outcome);
            }

            Object value = new IndexedCategoricalDistribution(data, this.outputClassifications
                    .get(nodeId).getDistributionBreakpoints());

            if (isUncertainty) {
                value = ((IndexedCategoricalDistribution) ret).getUncertainty();
            }

            ret.put(outputKey, value);
            }
        }

        return ret;
    }

    /**
     * Describes the actuator as "Bayesian network " followed by the name of the
     * network, or by nothing when no network has been created yet.
     */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder("Bayesian network ");
        if (this.network != null) {
            description.append(this.network.getName());
        }
        return description.toString();
    }

    /**
     * @return false once a probability outcome has been selected (which only
     *         notifyOutput() does, for probability observers); true otherwise
     */
    @Override
    public boolean isProbabilistic() {
        boolean outcomeSelected = probabilityOutcome != null;
        return !outcomeSelected;
    }

    /**
     * Computes the outputs at initialization by running the inference with the
     * initialization transition.
     *
     * @param index not used by this implementation
     * @param inputs the input values by key
     * @return the output values by key, as returned by run()
     */
    @Override
    public Map<String, Object> initialize(int index, Map<String, Object> inputs)
            throws KlabException {
        return run(inputs, ITransition.INITIALIZATION);
    }

    /**
     * Computes the outputs for a transition by running the inference.
     *
     * @param index not used by this implementation
     * @param transition the transition, passed on to run()
     * @param inputs the input values by key
     * @return the output values by key, as returned by run()
     */
    @Override
    public Map<String, Object> compute(int index, ITransition transition, Map<String, Object> inputs)
            throws KlabException {
        return run(inputs, transition);
    }

    /**
     * @return the fixed label of this contextualizer, "Bayesian inference"
     */
    @Override
    public String getLabel() {
        return "Bayesian inference";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy