All downloads are free. The search and download functionality uses the official Maven repository.

moa.classifiers.meta.TemporallyAugmentedClassifier Maven / Gradle / Ivy

Go to download

Massive On-line Analysis (MOA) is an environment for massive data mining. MOA provides a framework for data stream mining and includes tools for evaluation and a collection of machine learning algorithms. It is related to the WEKA project and is also written in Java, but scales to more demanding problems.

There is a newer version: 2024.07.0
Show newest version
/*
 *    TemporallyAugmentedClassifier.java
 *    Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
 *    @author Bernhard Pfahringer ([email protected])
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *    
 */
package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import com.yahoo.labs.samoa.instances.*;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/**
 * Include labels of previous instances into the training data
 *
 * 

This enables a classifier to exploit potentially present auto-correlation *

* *

Parameters:

  • -l : Classifier to train
  • -n : The number * of old labels to include
* * @author Bernhard Pfahringer ([email protected]) * @version $Revision: 1 $ */ public class TemporallyAugmentedClassifier extends AbstractClassifier implements MultiClassClassifier { @Override public String getPurposeString() { return "Add some old labels to every instance"; } private static final long serialVersionUID = 1L; public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree"); public IntOption numOldLabelsOption = new IntOption("numOldLabels", 'n', "The number of old labels to add to each example.", 1, 0, Integer.MAX_VALUE); protected Classifier baseLearner; protected double[] oldLabels; protected Instances header; public FlagOption labelDelayOption = new FlagOption("labelDelay", 'd', "Labels arrive with Delay. Use predictions instead of true Labels."); @Override public void resetLearningImpl() { this.baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); this.oldLabels = new double[this.numOldLabelsOption.getValue()]; this.header = null; baseLearner.resetLearning(); } @Override public void trainOnInstanceImpl(Instance instance) { this.baseLearner.trainOnInstance(extendWithOldLabels(instance)); if (this.labelDelayOption.isSet() == false) { // Use true old Labels to add attributes to instances addOldLabel(instance.classValue()); } } public void addOldLabel(double newPrediction) { int numLabels = this.oldLabels.length; if (numLabels > 0) { for (int i = 1; i < numLabels; i++) { this.oldLabels[i - 1] = this.oldLabels[i]; } this.oldLabels[ numLabels - 1] = newPrediction; } } public void initHeader(Instances dataset) { int numLabels = this.numOldLabelsOption.getValue(); Attribute target = dataset.classAttribute(); List possibleValues = new ArrayList(); int n = target.numValues(); for (int i = 0; i < n; i++) { possibleValues.add(target.value(i)); } ArrayList attrs = new ArrayList(numLabels + dataset.numAttributes()); for (int i = 0; i < numLabels; i++) { 
attrs.add(new Attribute(target.name() + "_" + i, possibleValues)); } for (int i = 0; i < dataset.numAttributes(); i++) { Attribute attr = dataset.attribute(i); Attribute newAttribute = null; if (attr.isNominal() == true) { newAttribute = new Attribute(attr.name(), attr.getAttributeValues()); } if (attr.isNumeric() == true) { newAttribute = new Attribute(attr.name()); } if (newAttribute != null) { attrs.add(newAttribute); } } this.header = new Instances("extended_" + dataset.getRelationName(), attrs, 0); this.header.setClassIndex(numLabels + dataset.classIndex()); } public Instance extendWithOldLabels(Instance instance) { if (this.header == null) { initHeader(instance.dataset()); this.baseLearner.setModelContext(new InstancesHeader(this.header)); } int numLabels = this.oldLabels.length; if (numLabels == 0) { return instance; } double[] x = instance.toDoubleArray(); double[] x2 = Arrays.copyOfRange(this.oldLabels, 0, numLabels + x.length); System.arraycopy(x, 0, x2, numLabels, x.length); Instance extendedInstance = new DenseInstance(instance.weight(), x2); extendedInstance.setDataset(this.header); //System.out.println( extendedInstance); return extendedInstance; } @Override public double[] getVotesForInstance(Instance instance) { double[] prediction = this.baseLearner.getVotesForInstance(extendWithOldLabels(instance)); if (this.labelDelayOption.isSet() == true) { // Use predicted Labels to add attributes to instances addOldLabel(Utils.maxIndex(prediction)); } return prediction; } @Override public boolean isRandomizable() { return false; // ??? 
this.baseLearner.isRandomizable; } @Override protected Measurement[] getModelMeasurementsImpl() { List measurementList = new LinkedList(); Measurement[] modelMeasurements = ((AbstractClassifier) this.baseLearner).getModelMeasurements(); if (modelMeasurements != null) { for (Measurement measurement : modelMeasurements) { measurementList.add(measurement); } } return measurementList.toArray(new Measurement[measurementList.size()]); } @Override public void getModelDescription(StringBuilder out, int indent) { // TODO Auto-generated method stub } public String toString() { return "TemporallyAugmentedClassifier using " + this.numOldLabelsOption.getValue() + " labels\n" + this.baseLearner; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy