All downloads are free. Search and download functionality uses the official Maven repository.

org.integratedmodelling.engine.modelling.learning.WEKALearningProcess Maven / Gradle / Ivy

The newest version!
package org.integratedmodelling.engine.modelling.learning;

import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import org.integratedmodelling.api.knowledge.IConcept;
import org.integratedmodelling.api.metadata.IMetadata;
import org.integratedmodelling.api.modelling.IActiveDirectObservation;
import org.integratedmodelling.api.modelling.IActiveProcess;
import org.integratedmodelling.api.modelling.IClassifyingObserver;
import org.integratedmodelling.api.modelling.IDirectObservation;
import org.integratedmodelling.api.modelling.IEvent;
import org.integratedmodelling.api.modelling.INumericObserver;
import org.integratedmodelling.api.modelling.IObservableSemantics;
import org.integratedmodelling.api.modelling.IObserver;
import org.integratedmodelling.api.modelling.IPresenceObserver;
import org.integratedmodelling.api.modelling.IProbabilityObserver;
import org.integratedmodelling.api.modelling.IState;
import org.integratedmodelling.api.modelling.IState.ChangeListener;
import org.integratedmodelling.api.modelling.ISubject;
import org.integratedmodelling.api.modelling.resolution.IResolutionScope;
import org.integratedmodelling.api.modelling.scheduling.ITransition;
import org.integratedmodelling.api.monitoring.IMonitor;
import org.integratedmodelling.api.monitoring.Messages;
import org.integratedmodelling.collections.Pair;
import org.integratedmodelling.collections.Path;
import org.integratedmodelling.common.states.States;
import org.integratedmodelling.common.vocabulary.NS;
import org.integratedmodelling.exceptions.KlabContextualizationException;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabIOException;
import org.integratedmodelling.exceptions.KlabInternalErrorException;
import org.integratedmodelling.exceptions.KlabValidationException;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.converters.ArffSaver;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;

/**
 * Wrapper for a WEKA classifier and instance set to produce, update and manage WEKA
 * classifiers and AIRFF datasets from context information.
 * 
 * @author Ferd
 */
public class WEKALearningProcess {

    class Var {
        IState               state;
        Attribute            attribute;
        Map legend;

        public double getAttributeValue(Object inst) {

            if (inst instanceof IConcept) {
                inst = inst.toString();
            } else if (inst instanceof Boolean) {
                inst = ((Boolean) inst) ? "true" : "false";
            }
            if (legend != null) {
                return legend.get(inst);
            }
            return inst instanceof Number ? ((Number) inst).doubleValue() : 0;
        }
    }

    Var                              predicted                      = null;
    List                        predictors                     = new ArrayList<>();
    Set          archetypes                     = new HashSet<>();
    IState                           distributedArchetype           = null;
    Instances                        instances                      = null;
    IMonitor                         monitor;
    IState                           outputState                    = null;
    boolean                          distributedArchetypeChanged    = false;
    protected Classifier             classifier;

    // options set from API
    List                     options                        = new ArrayList<>();

    /**
     * if false, numeric predictors are automatically discretized
     */
    boolean                          allowNumeric                   = true;
    boolean                          acceptNoData                   = false;
    boolean                          crossValidate                  = true;
    boolean                          skipTraining                   = false;

    /**
     * if false, numeric zero values are ignored in training
     */
    boolean                          ignoreZeroes                   = true;

    /**
     * if true and archetypes are direct observations, one value per archetype is
     * generated, aggregating the predictors over its scale.
     */
    boolean                          aggregateArchetype             = true;
    private IActiveProcess           learningProcess;
    private IActiveDirectObservation context;

    Map              states;
    private IResolutionScope         resolutionScope;
    protected int                    MIN_INSTANCES_FOR_TRAINING     = 5;
    private boolean                  forceStaticOutput              = false;
    private double                   PRESENCE_PROBABILITY_THRESHOLD = 0.9;
    private boolean                  optionsSet;
    protected Discretize             discretizer = null;

    public WEKALearningProcess(Classifier classifier, IMonitor monitor) {
        this.monitor = monitor;
        this.classifier = classifier;
    }

    public void addWekaOptions(String... options) {
        if (options != null) {
            for (String o : options) {
                this.options.add(o);
            }
        }
    }
    
    public Discretize getDiscretizer() {
        return discretizer;
    }
    
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * Call to force the output to be static.
     */
    public void forceStaticOutput() {
        forceStaticOutput = true;
    }

    /**
     * Set the cross-validation flag (default true).
     * 
     * @param b
     */
    public void setCrossValidation(boolean b) {
        crossValidate = b;
    }

    /**
     * Set whether the training phase should be skipped (default false, unless the trained
     * model was read from a file).
     * 
     * @param b
     */
    public void skipTraining(boolean b) {
        skipTraining = b;
    }

    /**
     * Set whether instances need to be discretized before training.
     * 
     * @param b
     */
    public void setNumericInputAllowed(boolean b) {
        this.allowNumeric = b;
    }

    /**
     * Set the minimum number of instances required for successful training. Default is 5.
     * 
     * @param n
     */
    public void setMinimumInstanceCount(int n) {
        this.MIN_INSTANCES_FOR_TRAINING = n;
    }

    /**
     * Set the minimum probability of the 'true' case for a presence to be accepted.
     * Default is 0.9.
     * 
     * @param p
     */
    public void setPresenceProbabilityThreshold(double p) {
        this.PRESENCE_PROBABILITY_THRESHOLD = p;
    }

    /**
     * Produce an instance set for training according to the learning roles in the inputs.
     * Archetypes are found either in the input data (presence) or in the subjects in the
     * context.
     * 
     * @param learningProcess
     *            the learning process being computed
     * @param context
     *            the context of the process
     * @param resolutionScope
     *            resolution scope (may be null)
     * @param inputs
     *            all input observables, either with role "explanatory variable" or
     *            "archetype". Must correspond to existing states in context.
     * @param outputs
     *            all input observables, in which the "learned variable" role will be
     *            looked up.
     * @param allowNumeric
     *            if false, any numeric input is automatically discretized
     * @param monitor
     *            monitor for communication
     * @return
     * @throws KlabException
     */
    public void initialize(IActiveProcess learningProcess, IActiveDirectObservation context, IResolutionScope resolutionScope, Map inputs, Map outputs)
            throws KlabException {

        this.states = States.matchStatesToInputs(context, inputs);

        this.learningProcess = learningProcess;
        this.context = context;
        this.resolutionScope = resolutionScope;

        for (String out : outputs.keySet()) {
            if (learningProcess.getRolesFor(outputs.get(out))
                    .contains(NS.LEARNED_QUALITY_ROLE)) {
                /*
                 * make state, set into outputState, set up observer
                 */
                if (this.outputState != null) {
                    throw new KlabValidationException("only one quality can be learned in a learning process");
                }
                this.outputState = forceStaticOutput
                        ? context.getStaticState(outputs.get(out))
                        : context.getState(outputs.get(out));
                this.outputState.getMetadata().put(IMetadata.DC_LABEL, out);
            }
        }

    }

    public IState getOutputState() {
        return outputState;
    }

    private ArrayList getAttributes() {
        ArrayList ret = new ArrayList<>();
        if (predicted != null) {
            ret.add(predicted.attribute);
        }
        for (Var var : predictors) {
            ret.add(var.attribute);
        }
        return ret;
    }

    /*
     * record the match between the passed value of the archetype and all the others.
     */
    private void recordInstanceValue(Object inst, int n) {

        Object instanceValue = null;

        if (predicted.state.getObserver() instanceof IPresenceObserver
                || predicted.state.getObserver() instanceof IProbabilityObserver) {
            if (inst instanceof Boolean && ((Boolean) inst)) {
                instanceValue = "true";
            }
        } else if (predicted.state.getObserver() instanceof INumericObserver) {
            instanceValue = ((Number) inst).doubleValue();
        } else if (predicted.state.getObserver() instanceof IClassifyingObserver) {
            instanceValue = ((IConcept) inst).toString();
        }

        if (instanceValue != null) {

            double[] values = new double[predictors.size() + 1];

            values[0] = predicted.getAttributeValue(inst);

            /*
             * add other predictors; if all values are not nodata or nodata are allowed
             * and at least one predictor is not nodata, add instance
             */
            int nodata = 0;
            for (int i = 0; i < predictors.size(); i++) {

                Object value = States.get(predictors.get(i).state, n);
                if (value == null
                        || (value instanceof Number
                                && Double.isNaN(((Number) value).doubleValue()))) {
                    nodata++;
                    values[i + 1] = Double.NaN;
                } else {
                    values[i + 1] = predictors.get(i).getAttributeValue(value);
                }
            }

            if (nodata == 0 || (acceptNoData && values.length > (nodata + 1))) {
                instances.add(new DenseInstance(1.0, values));
            }

        }

    }

    /*
     * record values of explained variables at the covered extent of the archetype
     * observation.
     */
    private void createInstance(IDirectObservation o) {

        if (this.aggregateArchetype) {

        } else {

        }

    }

    private Var getPredictor(IState o) throws KlabValidationException {

        Var ret = null;
        IObserver obsrv = o.getObserver();
        if (obsrv instanceof INumericObserver) {
            ret = new Var();
            ret.state = o;
            ret.attribute = new Attribute(o.getObserver().getObservable()
                    .getFormalName());
        } else if (obsrv instanceof IClassifyingObserver) {

            ret = new Var();
            ret.state = o;
            ArrayList nominalValues = new ArrayList<>();
            for (IConcept c : ((IClassifyingObserver) obsrv).getClassification()
                    .getConceptOrder()) {
                nominalValues.add(c.toString());
            }
            ret.attribute = new Attribute(o.getObserver().getObservable()
                    .getFormalName(), nominalValues);
            ret.legend = new HashMap<>();
            for (int i = 0; i < nominalValues.size(); i++) {
                ret.legend.put(nominalValues.get(i), i);
            }

        } else if (obsrv instanceof IPresenceObserver) {

            ret = new Var();
            ret.state = o;
            ArrayList nominalValues = new ArrayList<>();
            nominalValues.add("true");
            nominalValues.add("false");
            ret.attribute = new Attribute(o.getObserver().getObservable()
                    .getFormalName(), nominalValues);
            ret.legend = new HashMap<>();
            for (int i = 0; i < nominalValues.size(); i++) {
                ret.legend.put(nominalValues.get(i), i);
            }

        } else {
            throw new KlabValidationException("WEKA learning process: occurrence state "
                    + o
                    + " is not numeric, categorical or boolean");
        }
        return ret;
    }

    /**
     * Save the training set to a file. Call after {@link #train(ITransition)} obviously.
     * 
     * @param file
     * @throws KlabException
     */
    public void saveData(File file) throws KlabException {
        saveData(file, this.instances);
    }

    public void saveData(File file, Instances instances) throws KlabException {
        ArffSaver saver = new ArffSaver();
        saver.setInstances(instances);
        try {
            saver.setFile(file);
            // saver.setDestination(file);
            saver.writeBatch();
        } catch (Exception e) {
            throw new KlabIOException(e);
        }
    }
    
    /**
     * Build the instance set and train the model.
     * 
     * @throws KlabException
     */
    Instances train(ITransition transition) throws KlabException {

        /*
         * add any options to classifier if not done already
         */
        if (!optionsSet && this.classifier instanceof OptionHandler
                && options.size() > 0) {

            try {
                ((OptionHandler) this.classifier)
                        .setOptions(options.toArray(new String[options.size()]));
            } catch (Exception e) {
                throw new KlabInternalErrorException(e);
            }

            optionsSet = true;
        }

        /*
         * find strategy to establish archetypes. If an input is tagged as archetype, do
         * not look for subjects or events. Can be run multiple times; the distributed
         * states are only scanned once.
         */

        for (String inp : states.keySet()) {
            IState o = states.get(inp);
            if (learningProcess.getRolesFor(o).contains(NS.ARCHETYPE_ROLE)) {

                predicted = getPredictor(o);

                distributedArchetype = o;
                distributedArchetypeChanged = true;
                distributedArchetype.addChangeListener(new ChangeListener() {

                    @Override
                    public void transitionDone(ITransition transaction) {
                    }

                    @Override
                    public void changed(int offset, Object value) {
                        distributedArchetypeChanged = true;
                    }
                });
            } else if (learningProcess.getRolesFor(o)
                    .contains(NS.EXPLANATORY_QUALITY_ROLE)) {
                Var predictor = getPredictor(o);
                if (predictor != null) {
                    predictors.add(predictor);
                }
            }
        }

        Set newArchetypes = new HashSet<>();

        if (distributedArchetype == null) {
            /*
             * Find yet-unknown archetypes in subject TODO define the predicted attribute
             * before creating instances
             */
            for (ISubject s : ((ISubject) context).getSubjects()) {
                if (learningProcess.getRolesFor(s).contains(NS.ARCHETYPE_ROLE)
                        && !archetypes.contains(s)) {
                    archetypes.add(s);
                    newArchetypes.add(s);
                }
            }
            for (IEvent s : ((ISubject) context).getEvents()) {
                if (learningProcess.getRolesFor(s).contains(NS.ARCHETYPE_ROLE)
                        && !archetypes.contains(s)) {
                    archetypes.add(s);
                    newArchetypes.add(s);
                }
            }
        }

        int capacity = 0;

        this.instances = new Instances(learningProcess.getName()
                + "_instances", getAttributes(), capacity);

        /*
         * follow archetypes and build attribute set.
         */
        if (distributedArchetype != null) {

            /*
             * must run only if the distributed archetype is new or has changed TODO
             * should keep a boolean cache to understand changed values for updateable
             * classifiers.
             */
            if (distributedArchetypeChanged) {
                for (int n : context.getScale().getIndex(transition)) {

                    if (!context.getScale().isCovered(n)) {
                        continue;
                    }

                    Object inst = States.get(distributedArchetype, n);
                    if (ignoreZeroes && inst instanceof Number
                            && ((Number) inst).doubleValue() == 0.0) {
                        continue;
                    }
                    if (!(inst == null
                            || (inst instanceof Number
                                    && Double.isNaN(((Number) inst).doubleValue())))) {
                        recordInstanceValue(inst, n);
                    }
                }
            }

        } else if (newArchetypes.size() > 0) {
            for (IDirectObservation o : newArchetypes) {
                createInstance(o);
            }
        }

        if (this.instances.size() < MIN_INSTANCES_FOR_TRAINING) {
            throw new KlabValidationException("not enough instances for training ("
                    + this.instances.size()
                    + ")");
        }

        monitor.info("generated training set with " + this.instances.size()
                + " instances", Messages.INFOCLASS_MODEL);

        /*
         * train
         */
        monitor.info("training "
                + Path.getLast(classifier.getClass().getCanonicalName(), '.') + " rev "
                + classifier.getCapabilities().getRevision()
                + " ...", Messages.INFOCLASS_MODEL);

        Instances trainingSet = this.instances;
        if (!allowNumeric) {
            // TODO also check if any attributes are numeric
            // TODO report
            this.discretizer = new Discretize();
            try {
                discretizer.setInputFormat(this.instances);
                trainingSet = Filter.useFilter(this.instances, discretizer);
            } catch (Exception e) {
                throw new KlabContextualizationException("discretization failed: check that values have a numeric range");
            }
        }

        trainingSet.setClassIndex(0);

        if (crossValidate) {
            try {

                Evaluation eval = new Evaluation(trainingSet);
                eval.crossValidateModel(this.classifier, trainingSet, 10, new Random(1));

                /*
                 * TODO in report, not in sysout
                 */
                System.out.println(eval.toSummaryString());
                System.out.println(eval.toClassDetailsString());
                System.out.println(eval.toMatrixString());

            } catch (Exception e) {
                throw new KlabContextualizationException(e);
            }
        } else {

        /*
         * go
         */
        doTraining(trainingSet);
        }
        monitor.info("learning completed.", Messages.INFOCLASS_MODEL);

        /*
         * TODO report
         */

        return trainingSet;

    }

    protected void doTraining(Instances instances) throws KlabException {
        try {
            this.classifier.buildClassifier(instances);
        } catch (Exception e) {
            throw new KlabContextualizationException(e);
        }
    }

    public void runModel(ITransition transition, Instances instances)
            throws KlabException {

        Instance instance = new DenseInstance(predictors.size() + 1);
        instance.setDataset(instances);
        instance.setClassMissing();
        
        for (int n : context.getScale().getIndex(transition)) {

            if (!context.getScale().isCovered(n)) {
                States.set(outputState, null);
                continue;
            }
            
            /*
             * add other predictors; if all values are not nodata or nodata are allowed
             * and at least one predictor is not nodata, add instance
             */
            int nodata = 0;

            for (int i = 0; i < predictors.size(); i++) {

                Object value = States.get(predictors.get(i).state, n);
                if (value == null
                        || (value instanceof Number
                                && Double.isNaN(((Number) value).doubleValue()))) {
                    nodata++;
                    instance.setMissing(i + 1);
                } else {
                    instance.setValue(i + 1, predictors.get(i).getAttributeValue(value));
                }
            }
            
            if (discretizer != null) {
                /*
                 * discretize instance
                 */
                discretizer.input(instance);
                instance = discretizer.output();
                instance.dataset().setClassIndex(0);
            }

            if (nodata == 0 || (acceptNoData && predictors.size() > nodata)) {

                try {
                    if (predicted.legend != null) {

                        double pred = classifier.classifyInstance(instance);
                        States.set(outputState, pred, n);

                    } else {

                        
//                        double predictionIndex = classifier.classifyInstance(instance);
//                        String predictedLabel = instances.classAttribute().value((int)predictionIndex);
                        double[] dist = classifier.distributionForInstance(instance);
                        
                        if (outputState.getObserver() instanceof IPresenceObserver) {
                            States.set(outputState, dist[0] > PRESENCE_PROBABILITY_THRESHOLD
                                    ? Boolean.TRUE
                                    : Boolean.FALSE, n);
                        } else if (outputState
                                .getObserver() instanceof IProbabilityObserver) {
                            States.set(outputState, dist[0], n);
                        } else {
                            /*
                             * TODO probability of distribution of outcomes
                             */
                        }

                        // System.out.println(dist + "");
                        // int pred = Utils.maxIndex(dist);
                        // if (dist[(int) pred] <= 0) {
                        // pred = Instance.missingValue();
                        // }
                        // updateStatsForClassifier(dist, instance);
                    }
                } catch (Exception e) {
                    throw new KlabContextualizationException(e);
                }
            } else {
                States.set(outputState, null);
            }
        }
    }

    /*
     * default implementation does not save anything.
     */
    public Pair> saveModel() {
        return null;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy