All Downloads are FREE. Search and download functionalities are using the official Maven repository.

moa.classifiers.meta.RCD Maven / Gradle / Ivy

Go to download

Massive On-line Analysis is an environment for massive data mining. MOA provides a framework for data stream mining and includes tools for evaluation and a collection of machine learning algorithms. Related to the WEKA project, also written in Java, while scaling to more demanding problems.

There is a newer version: 2024.07.0
Show newest version
/*
 *    RCD.java
 *    Copyright (C) 2017 Instituto Federal de Pernambuco
 *    @author Paulo Gonçalves ([email protected])
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see .
 *    
 */
package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import moa.classifiers.Classifier;
import moa.classifiers.core.statisticaltests.StatisticalTest;
import moa.classifiers.drift.SingleClassifierDrift;
import moa.core.MiscUtils;
import moa.options.ClassOption;

/**
 * Creates a set of classifiers, each one representing a different context.
 * Reuses classifier associating to each one a sample of data and compares new
 * data to old ones using a multivariate non-parametric statistical test. Tests
 * are performed in parallel and classifiers are stored based on their accuracy
 * and stored time.
 *
 * 1) Parameterized number of classifiers to store. 2) Classifiers are stored
 * removing the older ones if the set is full. 3) Classifier with higher
 * significance value is selected.
 *
 * Based on: Gonçalves Jr, Paulo Mauricio, and Roberto Souto Maior De Barros.
 * "RCD: A recurring concept drift framework." Pattern Recognition Letters 34.9
 * (2013): 1018-1025.
 *
 * @author Paulo Goncalves (paulogoncalves at recife dot ifpe dot edu dot br)
 *
 */
public class RCD extends SingleClassifierDrift {

    private static final long serialVersionUID = 1L;

    private class ClassifierKS implements Serializable {

        private final Classifier classifier;
        private final List instances;

        public ClassifierKS(Classifier classifier, List instances) {
            this.classifier = classifier;
            this.instances = instances;
        }

        public Classifier getClassifier() {
            return classifier;
        }

        public List getInstances() {
            return instances;
        }
    }

    public ClassOption statisticalTestOption = new ClassOption("statisticalTest",
            'a', "Non-parametric multivariate statistical test to use.", StatisticalTest.class,
            "KNN");

    public FloatOption similarityBetweenDistributionsOption = new FloatOption(
            "similarityBetweenDistributions",
            's',
            "The minimum percentual similarity between distributions (p-value).",
            0.01, 0, 1);

    public IntOption bufferSizeOption = new IntOption("bufferSize", 'b',
            "The size of the buffer that represents the distributions.", 400,
            1, Integer.MAX_VALUE);

    public IntOption testFrequencyOption = new IntOption("testFrequency",
            't', "In the testing phase, test for best stored classifier after how many instances.",
            400, 1, Integer.MAX_VALUE);

    public IntOption classifiersSizeOption = new IntOption("classifiersSize",
            'c', "The maximum amount of classifiers to store. 0 means unlimited.", 15, 0,
            Integer.MAX_VALUE);

    public IntOption threadSizeOption = new IntOption("threadSize",
            'm', "The thread pool size, indicating how many simultaneous tests are allowed.", 4, 1,
            Integer.MAX_VALUE);

    public IntOption quantityClassifiersTestOption = new IntOption("quantityClassifiersTest",
            'q', "Quantity of identified classifiers to check.", 1, 1,
            Integer.MAX_VALUE);

    private List classifiers;

    protected List currentChunk;

    protected List currentChunk2;

    protected List testChunk;

    protected int bufferSize;

    protected int previousState;

    protected int index;

    @Override
    public void resetLearningImpl() {
        super.resetLearningImpl();
        this.classifiers = new ArrayList();
        this.bufferSize = bufferSizeOption.getValue();
        this.currentChunk = null;
        this.currentChunk2 = null;
        this.testChunk = null;
        this.previousState = Integer.MIN_VALUE;
        this.index = 0;
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        int trueClass = (int) inst.classValue();
        boolean prediction = MiscUtils.maxIndex(this.classifier
                .getVotesForInstance(inst)) == trueClass;
        this.driftDetectionMethod.input(prediction ? 0.0 : 1.0);
        this.ddmLevel = DDM_INCONTROL_LEVEL;
        if (this.driftDetectionMethod.getChange()) {
            this.ddmLevel = DDM_OUTCONTROL_LEVEL;
        }
        if (this.driftDetectionMethod.getWarningZone()) {
            this.ddmLevel = DDM_WARNING_LEVEL;
        }
        switch (this.ddmLevel) {
            case DDM_WARNING_LEVEL:
                this.warningDetected++;
                switch (this.previousState) {
                    case DDM_INCONTROL_LEVEL:
                        this.newclassifier.resetLearning();
                        this.currentChunk2 = new ArrayList();
                        break;
                }
                this.newclassifier.trainOnInstance(inst);
                this.addInstance(this.currentChunk2, inst);
                this.previousState = DDM_WARNING_LEVEL;
                break;
            case DDM_OUTCONTROL_LEVEL:
                this.changeDetected++;
                switch (this.previousState) {
                    case DDM_WARNING_LEVEL:
                        ClassifierKS cs = this.getPreviousClassifier(
                                this.classifier, this.currentChunk2);
                        if (cs == null) {
                            this.classifier = this.newclassifier;
                            this.newclassifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption))
                                    .copy();
                            this.classifiers
                                    .add(new ClassifierKS(
                                            this.classifier, this.currentChunk2));
                            this.currentChunk = this.currentChunk2;
                            int maxSize = this.classifiersSizeOption.getValue();
                            if (this.classifiers.size() > maxSize && maxSize > 0) {
                                this.classifiers.remove(0);
                            }
                        } else {
                            this.classifier = cs.getClassifier();
                            this.currentChunk = cs.getInstances();
                        }
                        this.currentChunk2 = null;
                        this.newclassifier.resetLearning();
                }
                this.previousState = DDM_OUTCONTROL_LEVEL;
                break;
            case DDM_INCONTROL_LEVEL:
                switch (this.previousState) {
                    case DDM_INCONTROL_LEVEL:
                    case DDM_OUTCONTROL_LEVEL:
                        break;
                    case DDM_WARNING_LEVEL:
                        this.currentChunk2 = null;
                        break;
                    default:
                        this.currentChunk = new ArrayList();
                        this.classifiers.add(new ClassifierKS(
                                this.classifier, this.currentChunk));
                        break;
                }
                this.addInstance(this.currentChunk, inst);
                this.previousState = DDM_INCONTROL_LEVEL;
                break;
        }
        this.classifier.trainOnInstance(inst);
    }

    private void addInstance(List instances, Instance instance) {
        if (instances.size() >= bufferSize) {
            instances.remove(0);
        }
        instances.add(instance);
    }

    @Override
    public double[] getVotesForInstance(Instance inst) {
        if (this.testChunk == null) {
            this.testChunk = new ArrayList();
        }
        this.addInstance(this.testChunk, inst);
        if (this.index++ == testFrequencyOption.getValue()) {
            this.index = 0;
            ClassifierKS cs = this.getPreviousClassifier(
                    this.classifier, this.testChunk);
            if (cs != null) {
                this.classifier = cs.getClassifier();
            }
        }
        return this.classifier.getVotesForInstance(inst);
    }

    /**
     * Searches for the classifier best suited for actual data. All statistical
     * tests are performed in parallel.
     *
     * @param classifier Classifier to be added
     * @param instances Instances used to build the classifier
     * @return
     */
    private ClassifierKS getPreviousClassifier(Classifier classifier,
            List instances) {
        ExecutorService threadPool = Executors.newFixedThreadPool(this.threadSizeOption.getValue());
        int SIZE = this.classifiers.size();
        Map> futures = new HashMap<>();
        for (int i = 0; i < SIZE; i++) {
            ClassifierKS cs = this.classifiers.get(i);
            if (cs != null) {
                if (cs.getClassifier() != classifier) {
                    StatisticalTest st = (StatisticalTest) getPreparedClassOption(this.statisticalTestOption);
                    StatisticalTest temp = (StatisticalTest) st.copy();
                    temp.set(instances, cs.getInstances());
                    futures.put(i, threadPool.submit(temp));
                }
            } else {
                break;
            }
        }
        ClassifierKS cks = null;
        int qtd = this.quantityClassifiersTestOption.getValue();
        double maxPValue = this.similarityBetweenDistributionsOption.getValue();
        try {
            for (int i = 0; i < SIZE && qtd > 0; i++) {
                Future f = futures.get(i);
                if (f != null) {
                    double p = f.get();
                    if (p < maxPValue) {
                        maxPValue = p;
                        cks = this.classifiers.get(i);
                        qtd--;
                    }
                }
            }
        } catch (InterruptedException e) {
            System.out.println("Processing interrupted.");
        } catch (ExecutionException e) {
            throw new RuntimeException("Error computing statistical test.", e);
        }
        threadPool.shutdownNow();
        return cks;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy