org.integratedmodelling.engine.modelling.bayes.BayesianActuator Maven / Gradle / Ivy
The newest version!
/*******************************************************************************
* Copyright (C) 2007, 2015:
*
* - Ferdinando Villa - integratedmodelling.org - any
* other authors listed in @author annotations
*
* All rights reserved. This file is part of the k.LAB software suite, meant to enable
* modular, collaborative, integrated development of interoperable data and model
* components. For details, see http://integratedmodelling.org.
*
* This program is free software; you can redistribute it and/or modify it under the terms
* of the Affero General Public License Version 3 or any later version.
*
* This program is distributed in the hope that it will be useful, but without any
* warranty; without even the implied warranty of merchantability or fitness for a
* particular purpose. See the Affero General Public License for more details.
*
* You should have received a copy of the Affero General Public License along with this
* program; if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite
* 330, Boston, MA 02111-1307, USA. The license is also available at:
* https://www.gnu.org/licenses/agpl.html
*******************************************************************************/
package org.integratedmodelling.engine.modelling.bayes;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.integratedmodelling.api.data.IProbabilityDistribution;
import org.integratedmodelling.api.knowledge.IConcept;
import org.integratedmodelling.api.knowledge.IKnowledge;
import org.integratedmodelling.api.knowledge.IObservation;
import org.integratedmodelling.api.modelling.IActiveDirectObservation;
import org.integratedmodelling.api.modelling.IClassification;
import org.integratedmodelling.api.modelling.IClassifyingObserver;
import org.integratedmodelling.api.modelling.IConditionalObserver;
import org.integratedmodelling.api.modelling.IMediatingObserver;
import org.integratedmodelling.api.modelling.IModel;
import org.integratedmodelling.api.modelling.INumericObserver;
import org.integratedmodelling.api.modelling.IObservableSemantics;
import org.integratedmodelling.api.modelling.IObserver;
import org.integratedmodelling.api.modelling.IPresenceObserver;
import org.integratedmodelling.api.modelling.IProbabilityObserver;
import org.integratedmodelling.api.modelling.IUncertaintyObserver;
import org.integratedmodelling.api.modelling.IValueResolver;
import org.integratedmodelling.api.modelling.resolution.IResolutionScope;
import org.integratedmodelling.api.modelling.scheduling.ITransition;
import org.integratedmodelling.api.monitoring.IMonitor;
import org.integratedmodelling.api.project.IProject;
import org.integratedmodelling.api.services.annotations.Prototype;
import org.integratedmodelling.collections.Pair;
import org.integratedmodelling.collections.Triple;
import org.integratedmodelling.common.data.IndexedCategoricalDistribution;
import org.integratedmodelling.common.model.runtime.AbstractStateContextualizer;
import org.integratedmodelling.common.utils.CamelCase;
import org.integratedmodelling.common.vocabulary.NS;
import org.integratedmodelling.common.vocabulary.ObservableSemantics;
import org.integratedmodelling.common.vocabulary.Observables;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabRuntimeException;
import org.integratedmodelling.exceptions.KlabValidationException;
import com.google.common.collect.Sets;
@Prototype(
id = "bayesian",
args = { "import", Prototype.TEXT, "# method", Prototype.TEXT },
returnTypes = {
NS.STATE_CONTEXTUALIZER })
public class BayesianActuator extends AbstractStateContextualizer
implements IValueResolver {
IBayesianNetwork network;
String importFile;
File workspace;
IBayesianInference inference;
HashSet nodeIds = new HashSet();
HashMap key2node = new HashMap();
HashMap concept2node = new HashMap();
HashMap> outputKeys = new HashMap>();
HashMap outputClassifications = new HashMap();
// keep keys of presence/absence observers with the node ID and the outcome
// IDs
// corresponding to true and
// false.
HashMap> presenceKeys = new HashMap>();
Set warningKeys = new HashSet();
Set typewarnKeys = new HashSet();
/*
* if the observer is a probability linked to a specific outcome, these are set.
*/
String probabilityNode = null;
String probabilityOutcome = null;
int probabilityOutcomeIndex = -1;
class UncertaintyDesc {
public UncertaintyDesc(IObservableSemantics observable, IKnowledge inherentType) {
this.observed = inherentType;
this.observable = observable;
}
// the key for the correspondent BN node - starts null
String nodeKey;
// the observable for the uncertainty
IObservableSemantics observable;
// the concept we're looking at
IKnowledge observed;
}
// matches the key we use for the uncertainty state to the node we're
// measuring the
// uncertainty for.
HashMap _uncertainties = new HashMap();
boolean resolved = false;
public BayesianActuator() {
super(null);
}
public BayesianActuator(String importFile, File workspace, IMonitor monitor)
throws KlabException {
super(monitor);
this.importFile = importFile;
this.workspace = workspace;
}
@Override
public void setContext(Map parameters, IModel model, IProject project) {
this.importFile = parameters.get("import").toString();
this.workspace = project.getLoadPath();
}
@Override
public Map define(String name, IObserver observer, IActiveDirectObservation contextSubject, IResolutionScope context, Map expectedInputs, Map expectedOutputs, boolean isLastInChain, IMonitor monitor)
throws KlabException {
this.network = BayesianFactory.get()
.createBayesianNetwork(workspace + File.separator + importFile);
nodeIds = Sets.newHashSet(this.network.getAllNodeIds());
this.inference = this.network.getInference();
Map states = super.define(name, observer, contextSubject, context, expectedInputs, expectedOutputs, isLastInChain, monitor);
for (String ikey : expectedInputs.keySet()) {
notifyInput(expectedInputs.get(ikey), expectedInputs.get(ikey)
.getObserver(), ikey);
}
for (String okey : expectedOutputs.keySet()) {
notifyOutput(expectedOutputs.get(okey), expectedOutputs.get(okey)
.getObserver(), okey, okey.equals(name));
}
return states;
}
public void notifyInput(IObservableSemantics observable, IObserver observer, String key)
throws KlabException {
/*
* check if it's a presence/absence; special treatment if so.
*/
if (observer instanceof IPresenceObserver) {
String nodeId = findMatchingNodeID(observable, key);
if (nodeId != null) {
String present = null, absent = null;
for (String s : this.network.getOutcomeIds(nodeId)) {
if (s.endsWith("Present")) {
present = s;
}
if (s.endsWith("Absent")) {
absent = s;
}
}
if (present == null || absent == null) {
monitor.error("cannot establish outcomes for presence/absence of "
+ nodeId);
} else {
presenceKeys
.put(key, new Triple(nodeId, present, absent));
}
}
return;
}
/*
* if the observer is not discretized in some way, raise a ruckus and leave.
*/
// if (!((Observer)observer).isDiscrete()) {
IClassification classif = getClassification(observer);
if (classif == null) {
monitor.error("cannot obtain discretized values from observation of "
+ observable);
return;
}
/*
* find a matching node and set it in key2node dictionary. TODO if not found,
* ignore for now - will want to warn later
*/
String nodeId = findMatchingNodeID(observable, key);
if (nodeId == null)
return;
/*
* validate state IDs against concepts
*/
String notFound = "";
for (String s : this.network.getOutcomeIds(nodeId)) {
boolean match = false;
for (IConcept c : classif.getConceptOrder()) {
if ((match = c.getLocalName().equals(s))) {
break;
}
}
if (!match) {
notFound += (notFound.isEmpty() ? "" : ", ") + s;
}
}
if (!notFound.isEmpty()) {
monitor.error("cannot match subclasses of " + observable.getType()
+ " to outcomes of bayesian node "
+ nodeId + ": " + notFound);
}
}
// @Override
public void notifyOutput(IObservableSemantics observable, IObserver observer, String key, boolean isMain)
throws KlabException {
if (observer instanceof IUncertaintyObserver) {
/*
* prepare to handle uncertainty; we match concepts to nodes when we have all
* nodes.
*/
_uncertainties
.put(key, new UncertaintyDesc(observable, ((IUncertaintyObserver) observer)
.getOriginalConcept()));
} else if (observer instanceof IProbabilityObserver) {
/*
* must find the node and the outcome. For the outcome ID, get the event and
* if the node has "present/absent" as ID, use present; else, lookup
* from concept description - e.g. HighTemperature - from
* .
*/
IProbabilityObserver pobs = (IProbabilityObserver) observer;
IConcept event = pobs.getEventType();
if (event == null) {
throw new KlabValidationException("invalid probability semantics: cannot establish event type");
}
if (nodeIds.contains(event.getLocalName())
|| nodeIds.contains(event.getLocalName().toLowerCase())) {
String pnode = nodeIds.contains(event.getLocalName()) ? event.getLocalName()
: event.getLocalName().toLowerCase();
String poutcome = null;
int poutcomeIdx = -1;
String[] ids = network.getOutcomeIds(pnode);
for (int i = 0; i < ids.length; i++) {
if (ids[i].toLowerCase().equals("present") || ids[i].toLowerCase().equals((event.getLocalName() + "present").toLowerCase())) {
poutcome = ids[i];
poutcomeIdx = i;
break;
}
}
if (poutcome != null) {
this.probabilityNode = pnode;
this.probabilityOutcome = poutcome;
this.probabilityOutcomeIndex = poutcomeIdx;
}
}
} else {
/*
* find node matching observable. TODO handle classifications by trait
* properly.
*/
IClassification cls = getClassification(observer);
String nodeId = null;
if (cls != null) {
nodeId = findMatchingNodeID(observable, key);
if (nodeId == null) {
monitor.warn("bayesian: cannot find a node to match output "
+ cls.getConceptSpace() + " (" + key
+ "): output will not be computed by Bayesian model");
return;
}
} else {
// monitor.error("bayesian: observed values of " + key + " are
// not
// discretized");
return;
}
/*
* store classification as key to generate distributions later
*/
outputClassifications.put(nodeId, cls);
/*
* build ID order to interpret classification into outcomes
*/
List outcomeOrder = new ArrayList();
String missing = "";
HashSet outcomes = Sets
.newHashSet(this.network.getOutcomeIds(nodeId));
for (IConcept c : cls.getConceptOrder()) {
String id = c.getLocalName();
if (!outcomes.contains(id)) {
missing += (missing.isEmpty() ? "" : ", ") + id;
}
outcomeOrder.add(id);
}
if (!missing.isEmpty()) {
monitor.error("bayesian: cannot find outcome(s): " + missing + " in node "
+ nodeId
+ " to match observable " + observable.getLocalName() + " (" + key
+ ")");
return;
}
/*
* record key for evidence matching
*/
outputKeys.put(nodeId, outcomeOrder);
}
}
private String findMatchingNodeID(IObservableSemantics observable, String key) {
String ret = null;
if (key != null) {
String humpKey = CamelCase.toUpperCamelCase(key, '-');
if (nodeIds.contains(key)) {
ret = key;
} else if (nodeIds.contains(humpKey)) {
ret = humpKey;
}
}
if (ret == null) {
if (observable.getFormalName() != null) {
String humpKey = CamelCase.toUpperCamelCase(key, '-');
for (String nid : nodeIds) {
if (nid.equals(observable.getFormalName()) || nid.equals(humpKey)) {
ret = nid;
break;
}
}
}
}
if (ret == null && ((ObservableSemantics) observable).getLocalType() != null) {
List candidateIds = Observables
.getIdentifiersFor(((ObservableSemantics) observable).getLocalType());
for (String id : candidateIds) {
for (String nid : nodeIds) {
if (nid.equals(id)) {
ret = nid;
break;
}
}
if (ret != null) {
break;
}
}
}
if (ret == null) {
List candidateIds = Observables
.getIdentifiersFor(observable.getType());
for (String id : candidateIds) {
for (String nid : nodeIds) {
if (nid.equals(id)) {
ret = nid;
break;
}
}
if (ret != null) {
break;
}
}
}
if (ret != null) {
key2node.put(key, ret);
concept2node.put(observable.getType(), ret);
}
return ret;
}
private IClassification getClassification(IObserver observer) {
IClassification ret = null;
if (observer instanceof IClassifyingObserver) {
ret = ((IClassifyingObserver) observer).getClassification();
} else if (observer instanceof INumericObserver) {
ret = ((INumericObserver) observer).getDiscretization();
} else if (observer instanceof IConditionalObserver) {
/*
* classifications must be the same if present, so just get the first
* FIXME/CHECK: should use getRepresentativeObserver although I don't think it
* makes a difference now. No test case so postponing.
*/
ret = getClassification(((IConditionalObserver) observer).getModels().get(0)
.getFirst().getObserver());
}
if (ret == null && observer instanceof IMediatingObserver) {
ret = getClassification(((IMediatingObserver) observer)
.getMediatedObserver());
}
return ret;
}
/*
* runs once before the first process(). Resolves all uncertainty keys to the
* correspondent concept.
*/
private void resolveUncertaintyRefs() throws KlabValidationException {
resolved = true;
for (UncertaintyDesc u : _uncertainties.values()) {
for (IKnowledge c : concept2node.keySet()) {
if (u.observed.is(c)) {
u.nodeKey = concept2node.get(c);
break;
}
}
// u.nodeKey = _concept2node.get(u.observed);
if (u.nodeKey == null) {
throw new KlabValidationException("cannot find concept " + u.observed
+ " for uncertainty computation");
}
}
}
public Map run(Map inputs, ITransition transition)
throws KlabException {
ArrayList> evidence = new ArrayList>();
if (!resolved) {
resolveUncertaintyRefs();
}
this.inference.clearEvidence();
Map ret = new HashMap<>();
for (String inputKey : getInputKeys()) {
Object value = inputs.get(inputKey);
if (presenceKeys.containsKey(inputKey)) {
if (!(value instanceof Boolean) && value != null) {
throw new KlabRuntimeException("internal: presence value not a boolean for "
+ inputKey);
}
Triple ik = presenceKeys.get(inputKey);
if (value != null) {
evidence.add(new Pair(ik.getFirst(), (Boolean) value
? ik.getSecond() : ik.getThird()));
}
continue;
}
String nodeId = key2node.get(inputKey);
if (nodeId == null && !warningKeys.contains(inputKey)) {
monitor.warn("model dependency " + inputKey
+ " cannot be matched to any Bayesian node");
warningKeys.add(inputKey);
}
if (value != null && !(value instanceof IConcept)) {
if (value instanceof IProbabilityDistribution) {
IObserver observer = getInputObservers().get(inputKey);
IClassification classification = getClassification(observer);
IProbabilityDistribution distribution = (IProbabilityDistribution) value;
if (classification.getConceptOrder()
.size() == distribution.getData().length) {
if (!typewarnKeys.contains(inputKey)) {
monitor.warn("input " + inputKey
+ " is probabilistic: using most likely category in input as evidence");
typewarnKeys.add(inputKey);
}
IConcept c = classification.getConceptOrder().get(0);
double max = distribution.getData()[0];
for (int i = 0; i < distribution.getData().length; i++) {
if (distribution.getData()[i] > max) {
c = classification.getConceptOrder().get(i);
}
}
value = c;
} else {
if (!typewarnKeys.contains(inputKey)) {
monitor.error("incompatible probabilistic input for node "
+ inputKey);
typewarnKeys.add(inputKey);
}
}
}
if (!typewarnKeys.contains(inputKey)) {
monitor.warn("ignoring non-categorical value for " + inputKey
+ " as Bayesian evidence");
typewarnKeys.add(inputKey);
}
}
/*
* silent if value is not a concept for any reason, but behaves nicely on
* nodata.
*/
if (nodeId != null && value instanceof IConcept) {
evidence.add(new Pair(nodeId, ((IConcept) value)
.getLocalName()));
}
}
/*
* submit evidence
*/
for (Pair zio : evidence) {
this.inference.setEvidence(zio.getFirst(), zio.getSecond());
}
/*
* run inference
*/
this.inference.run();
/*
* put values back
*/
for (String outputKey : getOutputKeys()) {
String nodeId = null;
boolean isUncertainty = false;
if (_uncertainties.containsKey(outputKey)) {
nodeId = _uncertainties.get(outputKey).nodeKey;
isUncertainty = true;
} else {
nodeId = probabilityNode == null ? this.key2node.get(outputKey) : probabilityNode;
}
if (probabilityNode != null) {
ret.put(outputKey, this.inference.getMarginal(probabilityNode, probabilityOutcome));
} else {
List okey = this.outputKeys.get(nodeId);
/*
* happens after initialization gave errors - which should interrupt the
* process, but if not (e.g. parallelizing), this avoids a NPE.
*/
if (nodeId == null || okey == null) {
continue;
}
double[] data = new double[okey.size()];
int i = 0;
for (String outcome : okey) {
data[i++] = this.inference.getMarginal(nodeId, outcome);
}
Object value = new IndexedCategoricalDistribution(data, this.outputClassifications
.get(nodeId).getDistributionBreakpoints());
if (isUncertainty) {
value = ((IndexedCategoricalDistribution) ret).getUncertainty();
}
ret.put(outputKey, value);
}
}
return ret;
}
@Override
public String toString() {
return "Bayesian network " + (this.network == null ? "" : this.network.getName());
}
@Override
public boolean isProbabilistic() {
return probabilityOutcome == null;
}
@Override
public Map initialize(int index, Map inputs)
throws KlabException {
return run(inputs, ITransition.INITIALIZATION);
}
@Override
public Map compute(int index, ITransition transition, Map inputs)
throws KlabException {
return run(inputs, transition);
}
@Override
public String getLabel() {
return "Bayesian inference";
}
}