All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.integratedmodelling.engine.modelling.learning.algorithms.WekaBayesNet Maven / Gradle / Ivy

The newest version!
package org.integratedmodelling.engine.modelling.learning.algorithms;

import java.io.File;
import java.io.IOException;
import java.util.Map;

import org.integratedmodelling.api.knowledge.IObservation;
import org.integratedmodelling.api.modelling.IActiveDirectObservation;
import org.integratedmodelling.api.modelling.IActiveProcess;
import org.integratedmodelling.api.modelling.IObservableSemantics;
import org.integratedmodelling.api.modelling.resolution.IResolutionScope;
import org.integratedmodelling.api.monitoring.IMonitor;
import org.integratedmodelling.api.services.annotations.Prototype;
import org.integratedmodelling.common.beans.ModelArtifact;
import org.integratedmodelling.common.beans.responses.LocalExportResponse;
import org.integratedmodelling.common.knowledge.Observation;
import org.integratedmodelling.common.utils.FileUtils;
import org.integratedmodelling.common.vocabulary.NS;
import org.integratedmodelling.engine.modelling.learning.AbstractWEKAProcessContextualizer;
import org.integratedmodelling.engine.modelling.learning.WEKALearningProcess;
import org.integratedmodelling.engine.modelling.runtime.DirectObservation;
import org.integratedmodelling.engine.modelling.runtime.DirectObservation.ArtifactGenerator;
import org.integratedmodelling.exceptions.KlabContextualizationException;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabRuntimeException;

import weka.classifiers.Classifier;
import weka.classifiers.bayes.net.EditableBayesNet;
import weka.classifiers.bayes.net.estimate.BMAEstimator;
import weka.classifiers.bayes.net.search.global.K2;
import weka.core.Instances;

/**
 * Process contextualizer backed by a WEKA {@link EditableBayesNet} classifier.
 *
 * <p>During {@link #initialize} it registers three exportable artifacts on the process and announces
 * each one to the monitor:
 * <ul>
 * <li>the learned network, serialized as XML BIF 0.3;</li>
 * <li>the raw training dataset;</li>
 * <li>the discretized training dataset.</li>
 * </ul>
 *
 * <p>The learning process it creates does not accept numeric inputs.
 */
@Prototype(
        id = "weka.bayesnet",
        args = {
                "# import",
                Prototype.TEXT,
                "# method",
                Prototype.TEXT
        },
        returnTypes = { NS.PROCESS_CONTEXTUALIZER })
public class WekaBayesNet extends AbstractWEKAProcessContextualizer {

    /** Name of the raw training dataset artifact, used for registration and monitor messages. */
    private static final String RAW_DATASET_NAME = "Training dataset (raw)";

    /** Name of the discretized training dataset artifact, used for registration and monitor messages. */
    private static final String DISCRETIZED_DATASET_NAME = "Training dataset (discretized)";

    /*
     * Training instances after discretization, exported as the discretized dataset artifact.
     * NOTE(review): nothing in this class assigns this field. The doTraining override that set it
     * is commented out in createLearningProcess. As a result it is null whenever the discretized
     * dataset is exported. Confirm how WEKALearningProcess.saveData(File, Instances) handles null.
     */
    private Instances discretizedInstances;

    /**
     * Creates the classifier to be trained.
     *
     * @return a new {@link EditableBayesNet}. The model artifact generator casts the learned
     *         classifier back to this type to serialize it with {@code toXMLBIF03()}.
     */
    @Override
    protected Classifier getClassifier() {
        return new EditableBayesNet();
        // return new BayesNet();
    }

    /**
     * Initializes the contextualizer via the superclass and registers the exportable artifacts.
     *
     * <p>The artifacts are the learned network and the raw and discretized training datasets. Each
     * one is also announced to {@code monitor} as a {@link ModelArtifact}. The artifacts are
     * generated lazily, when export is requested, not here.
     *
     * @return the map returned by the superclass {@code initialize}
     * @throws KlabException if the superclass initialization fails
     */
    @Override
    public Map initialize(IActiveProcess process, IActiveDirectObservation contextSubject, IResolutionScope resolutionContext, Map expectedInputs, Map expectedOutputs, IMonitor monitor)
            throws KlabException {

        Map ret = super.initialize(process, contextSubject, resolutionContext, expectedInputs, expectedOutputs, monitor);

        /*
         * Add the learned model to the process. This assumes the model declares at least two
         * observables; the name comes from the second one.
         */
        String artifactId = "Learned "
                + NS.getDisplayName(process.getModel().getObservables().get(1).getType())
                + " contextualizer";
        DirectObservation observation = (DirectObservation) process;

        observation.addOutputModel(artifactId, new ArtifactGenerator() {

            @Override
            public LocalExportResponse generateArtifact() {
                return exportLearnedNetwork();
            }
        });
        monitor.send(new ModelArtifact(artifactId, "model", ((Observation) process)
                .getInternalId()));

        publishTrainingDataset(observation, process, RAW_DATASET_NAME, false, monitor);
        publishTrainingDataset(observation, process, DISCRETIZED_DATASET_NAME, true, monitor);

        return ret;
    }

    /**
     * Serializes the learned network to a temporary XML BIF 0.3 file.
     *
     * @return an export response flagged as a model. It lists the temp file and uses the relative
     *         export path {@code "bn"}.
     * @throws KlabRuntimeException wrapping any I/O error while creating or writing the file
     */
    private LocalExportResponse exportLearnedNetwork() {
        LocalExportResponse response = new LocalExportResponse();
        response.setModel(true);

        String bif = ((EditableBayesNet) learningProcess.getClassifier()).toXMLBIF03();
        try {
            // The suffix is appended verbatim, so it needs its own dot to form an extension.
            File outfile = File.createTempFile("bnet", ".bif");
            FileUtils.writeStringToFile(outfile, bif);
            response.getFiles().add(outfile.toString());
        } catch (IOException e) {
            throw new KlabRuntimeException(e);
        }

        // TODO: generate a real model statement; "Zop" is a placeholder. Intended sketch:
        // Models.generateObjectModelSource(type, "vector(file=\"bn/" + filename + "\")",
        //         ((VectorOutput) output).getAttributes(), "name");
        response.setModelStatement("Zop");
        response.setRelativeExportPath("bn");
        return response;
    }

    /**
     * Registers a training dataset as an exportable artifact and announces it to the monitor.
     *
     * @param observation the process, as a {@link DirectObservation}, that receives the artifact
     * @param process the same process; supplies the internal id for the monitor message
     * @param name artifact name, used both for registration and for the monitor message
     * @param discretized if true, export {@link #discretizedInstances}; otherwise export the
     *        learning process's own data
     * @param monitor receives the {@link ModelArtifact} notification
     * @throws KlabException declared defensively; the calls here ran in a KlabException context
     */
    private void publishTrainingDataset(DirectObservation observation, IActiveProcess process,
            String name, final boolean discretized, IMonitor monitor) throws KlabException {

        observation.addOutputDataset(name, new ArtifactGenerator() {

            @Override
            public LocalExportResponse generateArtifact() {
                LocalExportResponse response = new LocalExportResponse();
                response.setModel(false);
                try {
                    // NOTE(review): the "bif" suffix (no dot) is kept from the original. The
                    // format written by saveData is not visible here; choose a proper extension
                    // once it is confirmed.
                    File outfile = File.createTempFile("bnet", "bif");
                    if (discretized) {
                        // Read at export time, not at registration time.
                        learningProcess.saveData(outfile, discretizedInstances);
                    } else {
                        learningProcess.saveData(outfile);
                    }
                    response.getFiles().add(outfile.toString());
                } catch (Exception e) {
                    throw new KlabRuntimeException(e);
                }
                return response;
            }
        });

        monitor.send(new ModelArtifact(name, "dataset", ((Observation) process)
                .getInternalId()));
    }

    /**
     * Creates the WEKA learning process wrapping {@code classifier}.
     *
     * <p>The returned process disables numeric inputs before delegating to the standard
     * initialization.
     */
    @Override
    protected WEKALearningProcess createLearningProcess(Classifier classifier, IMonitor monitor) {

        return new WEKALearningProcess(classifier, monitor) {

            @Override
            public void initialize(IActiveProcess learningProcess, IActiveDirectObservation context, IResolutionScope resolutionScope, Map inputs, Map outputs)
                    throws KlabException {
                this.setNumericInputAllowed(false);
                super.initialize(learningProcess, context, resolutionScope, inputs, outputs);
            }

            /*
             * Disabled custom training: K2 structure search plus BMA CPT estimation. It also
             * captured the discretized instances into discretizedInstances. While it stays
             * disabled, that field is never assigned.
             */
//            @Override
//            protected void doTraining(Instances instances) throws KlabException {
//
//                discretizedInstances = instances;
//
//                /*
//                 * TODO configure methods and estimators
//                 */
//                K2 learner = new K2();
//
//                BMAEstimator estimator = new BMAEstimator();
//                estimator.setUseK2Prior(true);
//
//                try {
//
//                    ((EditableBayesNet) classifier).buildClassifier(instances);
//                    ((EditableBayesNet) classifier).initStructure();
//                    learner.buildStructure(((EditableBayesNet) classifier), instances);
//                    estimator.estimateCPTs(((EditableBayesNet) classifier));
//
//                } catch (Exception e) {
//                    throw new KlabContextualizationException(e);
//                }
//            }
        };
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy