// org.integratedmodelling.engine.modelling.learning.WEKALearningProcess Maven / Gradle / Ivy
// The newest version!
package org.integratedmodelling.engine.modelling.learning;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.integratedmodelling.api.knowledge.IConcept;
import org.integratedmodelling.api.metadata.IMetadata;
import org.integratedmodelling.api.modelling.IActiveDirectObservation;
import org.integratedmodelling.api.modelling.IActiveProcess;
import org.integratedmodelling.api.modelling.IClassifyingObserver;
import org.integratedmodelling.api.modelling.IDirectObservation;
import org.integratedmodelling.api.modelling.IEvent;
import org.integratedmodelling.api.modelling.INumericObserver;
import org.integratedmodelling.api.modelling.IObservableSemantics;
import org.integratedmodelling.api.modelling.IObserver;
import org.integratedmodelling.api.modelling.IPresenceObserver;
import org.integratedmodelling.api.modelling.IProbabilityObserver;
import org.integratedmodelling.api.modelling.IState;
import org.integratedmodelling.api.modelling.IState.ChangeListener;
import org.integratedmodelling.api.modelling.ISubject;
import org.integratedmodelling.api.modelling.resolution.IResolutionScope;
import org.integratedmodelling.api.modelling.scheduling.ITransition;
import org.integratedmodelling.api.monitoring.IMonitor;
import org.integratedmodelling.api.monitoring.Messages;
import org.integratedmodelling.collections.Pair;
import org.integratedmodelling.collections.Path;
import org.integratedmodelling.common.states.States;
import org.integratedmodelling.common.vocabulary.NS;
import org.integratedmodelling.exceptions.KlabContextualizationException;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabIOException;
import org.integratedmodelling.exceptions.KlabInternalErrorException;
import org.integratedmodelling.exceptions.KlabValidationException;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.converters.ArffSaver;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;
/**
* Wrapper for a WEKA classifier and instance set to produce, update and manage WEKA
* classifiers and ARFF datasets from context information.
*
* @author Ferd
*/
public class WEKALearningProcess {
/**
 * A learning variable: the state it is read from, the WEKA attribute that describes it
 * and, for nominal attributes, the mapping from label to nominal index.
 */
class Var {
    /** State providing the values of this variable. */
    IState state;
    /** WEKA attribute built for the state's observable. */
    Attribute attribute;
    /** Label to nominal index; left null for numeric attributes. */
    Map<String, Integer> legend;

    /**
     * Translate a state value into the double WEKA stores for this attribute.
     * Concepts are matched by their string form and booleans as "true"/"false".
     *
     * @param inst the value read from the state
     * @return the nominal index when a legend exists (NaN, i.e. a WEKA missing value,
     *         for a label that is not in the legend); otherwise the numeric value, or
     *         0 when the value is not a number
     */
    public double getAttributeValue(Object inst) {
        Object key = inst;
        if (inst instanceof IConcept) {
            key = inst.toString();
        } else if (inst instanceof Boolean) {
            key = ((Boolean) inst) ? "true" : "false";
        }
        if (legend != null) {
            Integer index = legend.get(key);
            // an unknown label used to fail with a NullPointerException on unboxing;
            // report it as missing instead
            return index == null ? Double.NaN : index.doubleValue();
        }
        return key instanceof Number ? ((Number) key).doubleValue() : 0;
    }
}
/** Variable being learned; assigned in train() from the state tagged as archetype. */
Var predicted = null;
/** Explanatory variables, collected in train() from states with the explanatory role. */
List<Var> predictors = new ArrayList<>();
/** Archetype subjects/events already seen in the context. */
Set<IDirectObservation> archetypes = new HashSet<>();
/** State tagged as archetype, if any; when set, subjects and events are not searched. */
IState distributedArchetype = null;
/** Training set built by the last call to train(). */
Instances instances = null;
IMonitor monitor;
/** State receiving the learned quality; created in initialize(). */
IState outputState = null;
/** Raised by the change listener installed on the distributed archetype. */
boolean distributedArchetypeChanged = false;
protected Classifier classifier;
// options set from API
List<String> options = new ArrayList<>();
/**
 * if false, the training set is passed through a WEKA Discretize filter before training
 */
boolean allowNumeric = true;
/**
 * if true, instances where some (but not all) predictors are no-data are still used
 */
boolean acceptNoData = false;
/**
 * if true, train() runs a cross-validation on the training set
 */
boolean crossValidate = true;
/**
 * set through skipTraining(boolean); not read anywhere in this class
 */
boolean skipTraining = false;
/**
 * if true, numeric zero values of the distributed archetype are ignored in training
 */
boolean ignoreZeroes = true;
/**
 * if true and archetypes are direct observations, one value per archetype is
 * generated, aggregating the predictors over its scale.
 */
boolean aggregateArchetype = true;
private IActiveProcess learningProcess;
private IActiveDirectObservation context;
/** Input states keyed by input name, as matched in initialize(). */
Map<String, IState> states;
private IResolutionScope resolutionScope;
/** Minimum training set size; train() fails below it. Not a constant despite the name. */
protected int MIN_INSTANCES_FOR_TRAINING = 5;
private boolean forceStaticOutput = false;
/** Probability above which a presence output is set to true in runModel(). */
private double PRESENCE_PROBABILITY_THRESHOLD = 0.9;
/** True once the API-provided options have been passed to the classifier. */
private boolean optionsSet;
/** Filter created by train() when numeric input is not allowed; null otherwise. */
protected Discretize discretizer = null;
/**
 * Wrap the passed classifier.
 *
 * @param classifier the WEKA classifier to configure, train and run
 * @param monitor monitor used for info messages
 */
public WEKALearningProcess(Classifier classifier, IMonitor monitor) {
    this.classifier = classifier;
    this.monitor = monitor;
}
/**
 * Queue WEKA command-line style options; they are handed to the classifier the first
 * time train() runs, if the classifier is an OptionHandler.
 *
 * @param options the option tokens; a null array is ignored
 */
public void addWekaOptions(String... options) {
    if (options == null) {
        return;
    }
    for (int i = 0; i < options.length; i++) {
        this.options.add(options[i]);
    }
}
/**
 * @return the discretization filter created during training, or null if none was used
 */
public Discretize getDiscretizer() {
    return this.discretizer;
}
/**
 * @return the wrapped WEKA classifier
 */
public Classifier getClassifier() {
    return this.classifier;
}
/**
 * Request a static output state: initialize() will then obtain the output through
 * getStaticState() instead of getState(). Call before initialize().
 */
public void forceStaticOutput() {
    this.forceStaticOutput = true;
}
/**
 * Enable or disable cross-validation during training (enabled by default).
 *
 * @param b true to cross-validate
 */
public void setCrossValidation(boolean b) {
    this.crossValidate = b;
}
/**
 * Record whether the training phase should be skipped (default false, unless the
 * trained model was read from a file). The flag is only stored here.
 *
 * @param b true to skip training
 */
public void skipTraining(boolean b) {
    this.skipTraining = b;
}
/**
 * Set whether numeric input is passed to the classifier as is (default true). When
 * false, the training set is discretized before training.
 *
 * @param b true to allow numeric input, false to discretize
 */
public void setNumericInputAllowed(boolean b) {
    allowNumeric = b;
}
/**
 * Set the smallest training set that train() will accept. Default is 5.
 *
 * @param n minimum number of instances
 */
public void setMinimumInstanceCount(int n) {
    MIN_INSTANCES_FOR_TRAINING = n;
}
/**
 * Set the probability of the 'true' case that must be exceeded for a presence output
 * to be set to true. Default is 0.9.
 *
 * @param p probability threshold
 */
public void setPresenceProbabilityThreshold(double p) {
    PRESENCE_PROBABILITY_THRESHOLD = p;
}
/**
 * Bind this process to its context: match the inputs to existing states and create
 * the output state for the observable that carries the "learned quality" role.
 * Archetypes and explanatory variables are looked up later, when training.
 *
 * @param learningProcess
 *            the learning process being computed; provides the roles of each observable
 * @param context
 *            the context of the process
 * @param resolutionScope
 *            resolution scope (may be null)
 * @param inputs
 *            all input observables by name, either with role "explanatory variable"
 *            or "archetype". Must correspond to existing states in context.
 * @param outputs
 *            all output observables by name, in which the "learned variable" role
 *            will be looked up.
 * @throws KlabException
 *             if more than one output has the learned quality role
 */
public void initialize(IActiveProcess learningProcess, IActiveDirectObservation context, IResolutionScope resolutionScope, Map<String, IObservableSemantics> inputs, Map<String, IObservableSemantics> outputs)
        throws KlabException {

    this.states = States.matchStatesToInputs(context, inputs);
    this.learningProcess = learningProcess;
    this.context = context;
    this.resolutionScope = resolutionScope;

    for (Map.Entry<String, IObservableSemantics> output : outputs.entrySet()) {

        IObservableSemantics observable = output.getValue();
        if (!learningProcess.getRolesFor(observable).contains(NS.LEARNED_QUALITY_ROLE)) {
            continue;
        }

        if (this.outputState != null) {
            throw new KlabValidationException("only one quality can be learned in a learning process");
        }

        // create the output state and label it with the output name
        this.outputState = forceStaticOutput
                ? context.getStaticState(observable)
                : context.getState(observable);
        this.outputState.getMetadata().put(IMetadata.DC_LABEL, output.getKey());
    }
}
/**
 * @return the state receiving the learned quality, or null if initialize() found none
 */
public IState getOutputState() {
    return this.outputState;
}
/**
 * List the WEKA attributes for the instance set: the predicted variable first (when
 * known), then each predictor in order.
 */
private ArrayList<Attribute> getAttributes() {
    ArrayList<Attribute> ret = new ArrayList<>(predictors.size() + 1);
    if (predicted != null) {
        ret.add(predicted.attribute);
    }
    for (Var predictor : predictors) {
        ret.add(predictor.attribute);
    }
    return ret;
}
/*
 * Add one training instance pairing the archetype value with the predictor values at
 * offset n. For presence/probability archetypes only true values are recorded. The
 * instance is dropped if any predictor is no-data, unless no-data is accepted and at
 * least one predictor has data.
 */
private void recordInstanceValue(Object inst, int n) {

    IObserver observer = predicted.state.getObserver();

    // only used to decide whether the archetype value is recordable
    Object classValue = null;
    if (observer instanceof IPresenceObserver || observer instanceof IProbabilityObserver) {
        if (Boolean.TRUE.equals(inst)) {
            classValue = "true";
        }
    } else if (observer instanceof INumericObserver) {
        classValue = ((Number) inst).doubleValue();
    } else if (observer instanceof IClassifyingObserver) {
        classValue = ((IConcept) inst).toString();
    }

    if (classValue == null) {
        return;
    }

    int count = predictors.size();
    double[] row = new double[count + 1];
    row[0] = predicted.getAttributeValue(inst);

    int missing = 0;
    for (int i = 0; i < count; i++) {
        Object value = States.get(predictors.get(i).state, n);
        boolean isNoData = value == null
                || (value instanceof Number && Double.isNaN(((Number) value).doubleValue()));
        if (isNoData) {
            missing++;
            row[i + 1] = Double.NaN;
        } else {
            row[i + 1] = predictors.get(i).getAttributeValue(value);
        }
    }

    boolean complete = missing == 0;
    boolean partialButAccepted = acceptNoData && row.length > (missing + 1);
    if (complete || partialButAccepted) {
        instances.add(new DenseInstance(1.0, row));
    }
}
/*
 * Meant to record values of explained variables at the covered extent of the archetype
 * observation. Currently a stub: neither branch adds anything to the instance set.
 */
private void createInstance(IDirectObservation o) {
    if (!this.aggregateArchetype) {
        // TODO not implemented: non-aggregated archetype
        return;
    }
    // TODO not implemented: one aggregated value per archetype
}
/**
 * Build the learning variable for a state according to its observer: a numeric
 * attribute for numeric observers, a nominal attribute with the classification's
 * concepts for classifying observers, and a true/false nominal attribute for presence
 * observers.
 *
 * @param o the state to describe
 * @return the variable; never null
 * @throws KlabValidationException if the observer is none of the supported kinds
 */
private Var getPredictor(IState o) throws KlabValidationException {

    IObserver observer = o.getObserver();

    if (observer instanceof INumericObserver) {
        Var numeric = new Var();
        numeric.state = o;
        numeric.attribute = new Attribute(observer.getObservable().getFormalName());
        return numeric;
    }

    if (observer instanceof IClassifyingObserver) {
        ArrayList<String> labels = new ArrayList<>();
        for (IConcept concept : ((IClassifyingObserver) observer).getClassification()
                .getConceptOrder()) {
            labels.add(concept.toString());
        }
        return newNominalVar(o, labels);
    }

    if (observer instanceof IPresenceObserver) {
        ArrayList<String> labels = new ArrayList<>();
        labels.add("true");
        labels.add("false");
        return newNominalVar(o, labels);
    }

    throw new KlabValidationException("WEKA learning process: occurrence state "
            + o
            + " is not numeric, categorical or boolean");
}

/*
 * Create a nominal variable for the state; the legend maps each label to its position.
 */
private Var newNominalVar(IState state, ArrayList<String> labels) {
    Var nominal = new Var();
    nominal.state = state;
    nominal.attribute = new Attribute(state.getObserver().getObservable()
            .getFormalName(), labels);
    nominal.legend = new HashMap<>();
    int index = 0;
    for (String label : labels) {
        nominal.legend.put(label, index++);
    }
    return nominal;
}
/**
 * Write the current training set to a file in ARFF format. Only meaningful after
 * {@link #train(ITransition)} has built the set.
 *
 * @param file destination file
 * @throws KlabException if writing fails
 */
public void saveData(File file) throws KlabException {
    Instances current = this.instances;
    saveData(file, current);
}
/**
 * Write the passed instance set to a file using WEKA's ArffSaver.
 *
 * @param file destination file
 * @param instances the set to write
 * @throws KlabException a KlabIOException wrapping any failure to set the file or write
 */
public void saveData(File file, Instances instances) throws KlabException {
    final ArffSaver writer = new ArffSaver();
    writer.setInstances(instances);
    try {
        writer.setFile(file);
        writer.writeBatch();
    } catch (Exception failure) {
        throw new KlabIOException(failure);
    }
}
/**
 * Build the instance set and train the model.
 * <p>
 * Applies pending classifier options, locates the archetype and explanatory states
 * (first call only), rebuilds the training set from the context, optionally
 * discretizes it, optionally cross-validates, and finally builds the classifier on the
 * whole training set.
 *
 * @param transition
 *            transition passed to the scale to obtain the offsets to scan
 * @return the set the classifier was trained on (the discretized copy when numeric
 *         input is not allowed)
 * @throws KlabException
 *             if there are too few instances, or discretization, evaluation or
 *             training fail
 */
Instances train(ITransition transition) throws KlabException {

    applyClassifierOptions();
    scanLearningStates();
    Set<IDirectObservation> newArchetypes = findNewArchetypes();

    this.instances = new Instances(learningProcess.getName()
            + "_instances", getAttributes(), 0);

    /*
     * follow archetypes and build attribute set.
     */
    if (distributedArchetype != null) {
        /*
         * TODO should keep a boolean cache to understand changed values for
         * updateable classifiers. The flag is never reset, so the set is rebuilt at
         * every call.
         */
        if (distributedArchetypeChanged) {
            for (int n : context.getScale().getIndex(transition)) {
                if (!context.getScale().isCovered(n)) {
                    continue;
                }
                Object inst = States.get(distributedArchetype, n);
                if (ignoreZeroes && inst instanceof Number
                        && ((Number) inst).doubleValue() == 0.0) {
                    continue;
                }
                boolean noData = inst == null
                        || (inst instanceof Number
                                && Double.isNaN(((Number) inst).doubleValue()));
                if (!noData) {
                    recordInstanceValue(inst, n);
                }
            }
        }
    } else {
        for (IDirectObservation o : newArchetypes) {
            createInstance(o);
        }
    }

    if (this.instances.size() < MIN_INSTANCES_FOR_TRAINING) {
        throw new KlabValidationException("not enough instances for training ("
                + this.instances.size()
                + ")");
    }

    monitor.info("generated training set with " + this.instances.size()
            + " instances", Messages.INFOCLASS_MODEL);

    monitor.info("training "
            + Path.getLast(classifier.getClass().getCanonicalName(), '.') + " rev "
            + classifier.getCapabilities().getRevision()
            + " ...", Messages.INFOCLASS_MODEL);

    Instances trainingSet = this.instances;
    if (!allowNumeric) {
        // TODO also check if any attributes are numeric
        // TODO report
        this.discretizer = new Discretize();
        try {
            discretizer.setInputFormat(this.instances);
            trainingSet = Filter.useFilter(this.instances, discretizer);
        } catch (Exception e) {
            // keep the underlying reason visible in the message
            throw new KlabContextualizationException("discretization failed: check that values have a numeric range ("
                    + e.getMessage() + ")");
        }
    }

    trainingSet.setClassIndex(0);

    if (crossValidate) {
        runCrossValidation(trainingSet);
    }

    /*
     * Cross-validation evaluates copies of the classifier, so the classifier itself
     * must always be built on the full training set.
     */
    doTraining(trainingSet);

    monitor.info("learning completed.", Messages.INFOCLASS_MODEL);

    /*
     * TODO report
     */
    return trainingSet;
}

/*
 * Pass the API-provided options to the classifier, once, if it accepts options.
 */
private void applyClassifierOptions() throws KlabException {
    if (optionsSet || !(this.classifier instanceof OptionHandler) || options.isEmpty()) {
        return;
    }
    try {
        ((OptionHandler) this.classifier)
                .setOptions(options.toArray(new String[options.size()]));
    } catch (Exception e) {
        throw new KlabInternalErrorException(e);
    }
    optionsSet = true;
}

/*
 * Find the archetype state and the explanatory states among the inputs. Runs only
 * until something has been found, so that repeated training does not duplicate
 * predictors or change listeners.
 */
private void scanLearningStates() throws KlabException {
    if (predicted != null || !predictors.isEmpty()) {
        return;
    }
    for (IState o : states.values()) {
        if (learningProcess.getRolesFor(o).contains(NS.ARCHETYPE_ROLE)) {
            predicted = getPredictor(o);
            distributedArchetype = o;
            distributedArchetypeChanged = true;
            distributedArchetype.addChangeListener(new ChangeListener() {
                @Override
                public void transitionDone(ITransition transaction) {
                }

                @Override
                public void changed(int offset, Object value) {
                    distributedArchetypeChanged = true;
                }
            });
        } else if (learningProcess.getRolesFor(o)
                .contains(NS.EXPLANATORY_QUALITY_ROLE)) {
            predictors.add(getPredictor(o));
        }
    }
}

/*
 * When no input state is tagged as archetype, collect the subjects and events of the
 * context that have the archetype role and were not seen before. TODO define the
 * predicted attribute before creating instances.
 */
private Set<IDirectObservation> findNewArchetypes() throws KlabException {
    Set<IDirectObservation> fresh = new HashSet<>();
    if (distributedArchetype != null) {
        return fresh;
    }
    for (ISubject s : ((ISubject) context).getSubjects()) {
        if (learningProcess.getRolesFor(s).contains(NS.ARCHETYPE_ROLE)
                && !archetypes.contains(s)) {
            archetypes.add(s);
            fresh.add(s);
        }
    }
    for (IEvent s : ((ISubject) context).getEvents()) {
        if (learningProcess.getRolesFor(s).contains(NS.ARCHETYPE_ROLE)
                && !archetypes.contains(s)) {
            archetypes.add(s);
            fresh.add(s);
        }
    }
    return fresh;
}

/*
 * Cross-validate the classifier on the training set with up to 10 folds and print the
 * results. Skipped when there are fewer than two instances.
 */
private void runCrossValidation(Instances trainingSet) throws KlabException {
    // WEKA rejects more folds than instances
    int folds = Math.min(10, trainingSet.numInstances());
    if (folds < 2) {
        return;
    }
    try {
        Evaluation eval = new Evaluation(trainingSet);
        eval.crossValidateModel(this.classifier, trainingSet, folds, new Random(1));
        /*
         * TODO in report, not in sysout
         */
        System.out.println(eval.toSummaryString());
        if (trainingSet.classAttribute().isNominal()) {
            // per-class statistics and confusion matrix only exist for nominal classes
            System.out.println(eval.toClassDetailsString());
            System.out.println(eval.toMatrixString());
        }
    } catch (Exception e) {
        throw new KlabContextualizationException(e);
    }
}
/**
 * Build the classifier on the passed instances. Subclasses may override.
 *
 * @param instances the training set
 * @throws KlabException a KlabContextualizationException wrapping any WEKA failure
 */
protected void doTraining(Instances instances) throws KlabException {
    try {
        classifier.buildClassifier(instances);
    } catch (Exception failure) {
        throw new KlabContextualizationException(failure);
    }
}
/**
 * Run the trained classifier over the context and write the result to the output
 * state, one offset at a time. Offsets that are not covered, or whose predictors are
 * no-data (all of them, or any of them unless no-data is accepted), are set to null.
 *
 * @param transition
 *            transition passed to the scale to obtain the offsets to compute
 * @param instances
 *            dataset used as header for the instances being classified; its class
 *            index must be set
 * @throws KlabException
 *             a KlabContextualizationException wrapping any classification failure
 */
public void runModel(ITransition transition, Instances instances)
        throws KlabException {

    // raw (undiscretized) instance, reused across offsets; the class is what we predict
    Instance raw = new DenseInstance(predictors.size() + 1);
    raw.setDataset(instances);
    raw.setClassMissing();

    for (int n : context.getScale().getIndex(transition)) {

        if (!context.getScale().isCovered(n)) {
            // write the no-data at this offset, like every other write in this method
            States.set(outputState, null, n);
            continue;
        }

        int nodata = 0;
        for (int i = 0; i < predictors.size(); i++) {
            Object value = States.get(predictors.get(i).state, n);
            if (value == null
                    || (value instanceof Number
                            && Double.isNaN(((Number) value).doubleValue()))) {
                nodata++;
                raw.setMissing(i + 1);
            } else {
                raw.setValue(i + 1, predictors.get(i).getAttributeValue(value));
            }
        }

        boolean usable = nodata == 0 || (acceptNoData && predictors.size() > nodata);
        if (!usable) {
            States.set(outputState, null, n);
            continue;
        }

        try {

            /*
             * Discretize into a separate instance, so that the filter output is
             * never fed back as the input for the next offset.
             */
            Instance instance = raw;
            if (discretizer != null) {
                discretizer.input(raw);
                instance = discretizer.output();
                instance.dataset().setClassIndex(0);
            }

            // NOTE(review): the nominal case stores the class index and the numeric
            // case reads a distribution; confirm this is not inverted.
            if (predicted.legend != null) {
                double pred = classifier.classifyInstance(instance);
                States.set(outputState, pred, n);
            } else {
                double[] dist = classifier.distributionForInstance(instance);
                if (outputState.getObserver() instanceof IPresenceObserver) {
                    States.set(outputState, dist[0] > PRESENCE_PROBABILITY_THRESHOLD
                            ? Boolean.TRUE
                            : Boolean.FALSE, n);
                } else if (outputState
                        .getObserver() instanceof IProbabilityObserver) {
                    States.set(outputState, dist[0], n);
                } else {
                    /*
                     * TODO probability of distribution of outcomes
                     */
                }
            }
        } catch (Exception e) {
            throw new KlabContextualizationException(e);
        }
    }
}
/*
 * Hook for saving the trained model. This default implementation does not save
 * anything and returns null.
 *
 * NOTE(review): the return type below reads "Pair>" - its generic type arguments appear
 * to have been lost and must be restored from the original source before compiling.
 */
public Pair> saveModel() {
return null;
}
}