org.integratedmodelling.engine.modelling.learning.algorithms.WekaBayesNet Maven / Gradle / Ivy
The newest version!
package org.integratedmodelling.engine.modelling.learning.algorithms;
import java.io.File;
import java.io.IOException;
import java.util.Map;
import org.integratedmodelling.api.knowledge.IObservation;
import org.integratedmodelling.api.modelling.IActiveDirectObservation;
import org.integratedmodelling.api.modelling.IActiveProcess;
import org.integratedmodelling.api.modelling.IObservableSemantics;
import org.integratedmodelling.api.modelling.resolution.IResolutionScope;
import org.integratedmodelling.api.monitoring.IMonitor;
import org.integratedmodelling.api.services.annotations.Prototype;
import org.integratedmodelling.common.beans.ModelArtifact;
import org.integratedmodelling.common.beans.responses.LocalExportResponse;
import org.integratedmodelling.common.knowledge.Observation;
import org.integratedmodelling.common.utils.FileUtils;
import org.integratedmodelling.common.vocabulary.NS;
import org.integratedmodelling.engine.modelling.learning.AbstractWEKAProcessContextualizer;
import org.integratedmodelling.engine.modelling.learning.WEKALearningProcess;
import org.integratedmodelling.engine.modelling.runtime.DirectObservation;
import org.integratedmodelling.engine.modelling.runtime.DirectObservation.ArtifactGenerator;
import org.integratedmodelling.exceptions.KlabContextualizationException;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabRuntimeException;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.net.EditableBayesNet;
import weka.classifiers.bayes.net.estimate.BMAEstimator;
import weka.classifiers.bayes.net.search.global.K2;
import weka.core.Instances;
/**
 * Process contextualizer, registered under the prototype id {@code weka.bayesnet}, that trains a
 * WEKA Bayesian network ({@link EditableBayesNet}) through a {@link WEKALearningProcess}.
 * <p>
 * During initialization it registers three exportable artifacts on the process and announces each
 * one through the monitor: the learned network (serialized as XML-BIF), the raw training dataset
 * and the discretized training dataset.
 */
@Prototype(
        id = "weka.bayesnet",
        args = {
                "# import",
                Prototype.TEXT,
                "# method",
                Prototype.TEXT
        },
        returnTypes = { NS.PROCESS_CONTEXTUALIZER })
public class WekaBayesNet extends AbstractWEKAProcessContextualizer {

    /** Prefix shared by every temporary file written by this contextualizer. */
    private static final String TEMP_FILE_PREFIX = "bnet";

    /**
     * Suffix of the exported network file. {@link File#createTempFile(String, String)} appends the
     * suffix verbatim, so it must carry the leading dot (a bare "bif" yields names such as
     * {@code bnet123456bif} with no extension).
     */
    private static final String BIF_SUFFIX = ".bif";

    /**
     * Suffix of the exported dataset files.
     * NOTE(review): kept as ".bif" for continuity; the format written by
     * WEKALearningProcess.saveData is not visible here - confirm and adjust (e.g. ".arff").
     */
    private static final String DATASET_SUFFIX = ".bif";

    /** Artifact name (and monitor label) of the raw training dataset export. */
    private static final String RAW_DATASET_NAME = "Training dataset (raw)";

    /** Artifact name (and monitor label) of the discretized training dataset export. */
    private static final String DISCRETIZED_DATASET_NAME = "Training dataset (discretized)";

    /**
     * Data exported as the discretized training dataset.
     * NOTE(review): nothing in this class assigns this field (the doTraining override that did is
     * commented out in createLearningProcess), so it is currently always null when exported.
     */
    private Instances discretizedInstances;

    /**
     * Creates the classifier to train: an {@link EditableBayesNet}.
     */
    @Override
    protected Classifier getClassifier() {
        // The model export casts the trained classifier to EditableBayesNet to call toXMLBIF03(),
        // so a plain BayesNet cannot be returned here.
        return new EditableBayesNet();
    }

    /**
     * Initializes the learning process through the superclass, then registers the learned-model
     * export and the raw/discretized training dataset exports on the process, announcing each one
     * through the monitor.
     *
     * @return the map returned by the superclass initialization, unchanged
     * @throws KlabException if the superclass initialization fails
     */
    @Override
    public Map initialize(IActiveProcess process, IActiveDirectObservation contextSubject, IResolutionScope resolutionContext, Map expectedInputs, Map expectedOutputs, IMonitor monitor)
            throws KlabException {

        Map ret = super.initialize(process, contextSubject, resolutionContext, expectedInputs, expectedOutputs, monitor);

        /*
         * add the learned model to the process.
         * NOTE(review): assumes the model declares at least two observables and that the one at
         * index 1 is the learned one - confirm; fewer observables throws IndexOutOfBoundsException.
         */
        String artifactId = "Learned "
                + NS.getDisplayName(process.getModel().getObservables().get(1).getType())
                + " contextualizer";
        ((DirectObservation) process).addOutputModel(artifactId, createModelGenerator());
        monitor.send(new ModelArtifact(artifactId, "model", ((Observation) process).getInternalId()));

        addDatasetOutput(process, RAW_DATASET_NAME, false, monitor);
        addDatasetOutput(process, DISCRETIZED_DATASET_NAME, true, monitor);

        return ret;
    }

    /**
     * Builds the generator that exports the trained network as an XML-BIF temporary file. The
     * classifier is read from {@code learningProcess} only when the artifact is generated.
     *
     * @return a generator producing a model export response
     */
    private ArtifactGenerator createModelGenerator() {
        return new ArtifactGenerator() {
            @Override
            public LocalExportResponse generateArtifact() {
                LocalExportResponse response = new LocalExportResponse();
                response.setModel(true);
                String bif = ((EditableBayesNet) learningProcess.getClassifier()).toXMLBIF03();
                try {
                    File outfile = File.createTempFile(TEMP_FILE_PREFIX, BIF_SUFFIX);
                    FileUtils.writeStringToFile(outfile, bif);
                    response.getFiles().add(outfile.toString());
                } catch (IOException e) {
                    throw new KlabRuntimeException(e);
                }
                // TODO generate a real model statement; "Zop" is a placeholder. Sketch:
                // Models.generateObjectModelSource(type, "vector(file=\"bn/" + filename + "\")",
                //         ((VectorOutput) output).getAttributes(), "name");
                String modelStatement = "Zop";
                response.setModelStatement(modelStatement);
                response.setRelativeExportPath("bn");
                return response;
            }
        };
    }

    /**
     * Registers a training-dataset export on the process and announces it through the monitor.
     *
     * @param process the process being initialized; it is cast to DirectObservation and Observation
     * @param name the artifact name, also used as the monitor label
     * @param discretized if true, export {@link #discretizedInstances}; otherwise export the
     *        learning process' own data via {@code saveData(File)}
     * @param monitor the monitor notified of the new artifact
     */
    private void addDatasetOutput(IActiveProcess process, final String name, final boolean discretized, IMonitor monitor) {
        ((DirectObservation) process).addOutputDataset(name, new ArtifactGenerator() {
            @Override
            public LocalExportResponse generateArtifact() {
                LocalExportResponse response = new LocalExportResponse();
                response.setModel(false);
                try {
                    File outfile = File.createTempFile(TEMP_FILE_PREFIX, DATASET_SUFFIX);
                    if (discretized) {
                        // read at generation time, not at registration time
                        learningProcess.saveData(outfile, discretizedInstances);
                    } else {
                        learningProcess.saveData(outfile);
                    }
                    response.getFiles().add(outfile.toString());
                } catch (Exception e) {
                    throw new KlabRuntimeException(e);
                }
                return response;
            }
        });
        monitor.send(new ModelArtifact(name, "dataset", ((Observation) process).getInternalId()));
    }

    /**
     * Creates the WEKA learning process used by this contextualizer. The returned process disables
     * numeric inputs before delegating to the default initialization.
     */
    @Override
    protected WEKALearningProcess createLearningProcess(Classifier classifier, IMonitor monitor) {
        return new WEKALearningProcess(classifier, monitor) {

            @Override
            public void initialize(IActiveProcess learningProcess, IActiveDirectObservation context, IResolutionScope resolutionScope, Map inputs, Map outputs)
                    throws KlabException {
                // set before delegating, presumably so that it applies while inputs are set up - confirm
                this.setNumericInputAllowed(false);
                super.initialize(learningProcess, context, resolutionScope, inputs, outputs);
            }

            /*
             * Disabled custom training: it recorded the discretized instances, learned the network
             * structure with K2 and estimated the CPTs with a BMAEstimator using the K2 prior.
             * While it stays disabled, discretizedInstances is never assigned.
             */
            // @Override
            // protected void doTraining(Instances instances) throws KlabException {
            //
            // discretizedInstances = instances;
            //
            // /*
            // * TODO configure methods and estimators
            // */
            // K2 learner = new K2();
            //
            // BMAEstimator estimator = new BMAEstimator();
            // estimator.setUseK2Prior(true);
            //
            // try {
            //
            // ((EditableBayesNet) classifier).buildClassifier(instances);
            // ((EditableBayesNet) classifier).initStructure();
            // learner.buildStructure(((EditableBayesNet) classifier), instances);
            // estimator.estimateCPTs(((EditableBayesNet) classifier));
            //
            // } catch (Exception e) {
            // throw new KlabContextualizationException(e);
            // }
            // }
        };
    }
}