moa.classifiers.meta.TemporallyAugmentedClassifier Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of moa Show documentation
Show all versions of moa Show documentation
Massive On-line Analysis (MOA) is an environment for massive data mining. MOA
provides a framework for data stream mining and includes tools for evaluation
and a collection of machine learning algorithms. It is related to the WEKA project
and is also written in Java, while scaling to more demanding problems.
/*
* TemporallyAugmentedClassifier.java
* Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
* @author Bernhard Pfahringer ([email protected])
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/
package moa.classifiers.meta;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import com.yahoo.labs.samoa.instances.*;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;
/**
 * Meta classifier that includes the labels of previous instances in the training data.
 *
 * <p>This enables a classifier to exploit potentially present auto-correlation. Every
 * instance handed to the base learner is prefixed with {@code numOldLabels} extra nominal
 * attributes (named {@code <classAttributeName>_0 ... _(n-1)}, oldest first). These hold the
 * most recent true class labels or, when {@code -d} is set, the most recent predictions of
 * the base learner. Until enough history exists, the window is padded with class index 0.</p>
 *
 * <p>Parameters:</p>
 * <ul>
 * <li>-l : Classifier to train</li>
 * <li>-n : The number of old labels to include</li>
 * <li>-d : Labels arrive with delay; use predictions instead of true labels</li>
 * </ul>
 *
 * @author Bernhard Pfahringer ([email protected])
 * @version $Revision: 1 $
 */
public class TemporallyAugmentedClassifier extends AbstractClassifier implements MultiClassClassifier {

    @Override
    public String getPurposeString() {
        return "Add some old labels to every instance";
    }

    private static final long serialVersionUID = 1L;

    /** Option selecting the wrapped classifier (default: trees.HoeffdingTree). */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    /** Number of previous labels prepended to each instance; 0 passes instances through unchanged. */
    public IntOption numOldLabelsOption = new IntOption("numOldLabels", 'n',
            "The number of old labels to add to each example.", 1, 0, Integer.MAX_VALUE);

    /** The wrapped classifier, trained and queried on the extended instances. */
    protected Classifier baseLearner;

    /** Window of the most recent labels (or predictions) as class-value indices; index 0 is the oldest. */
    protected double[] oldLabels;

    /** Header describing the extended instances; built lazily from the first instance seen. */
    protected Instances header;

    /** When set, the window is fed with the base learner's predictions instead of the true labels. */
    public FlagOption labelDelayOption = new FlagOption("labelDelay", 'd',
            "Labels arrive with Delay. Use predictions instead of true Labels.");

    /**
     * Creates a fresh base learner from {@link #baseLearnerOption}. Also zero-fills the label
     * window and discards the extended header, so it is rebuilt from the next instance.
     */
    @Override
    public void resetLearningImpl() {
        this.baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        this.oldLabels = new double[this.numOldLabelsOption.getValue()];
        this.header = null;
        this.baseLearner.resetLearning();
    }

    /**
     * Trains the base learner on the instance extended with the current label window.
     * Afterwards, unless labels are delayed, pushes the instance's true class value into the window.
     *
     * @param instance the training instance (in the original, non-extended format)
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        this.baseLearner.trainOnInstance(extendWithOldLabels(instance));
        if (!this.labelDelayOption.isSet()) {
            // Use true old labels to add attributes to instances.
            addOldLabel(instance.classValue());
        }
    }

    /**
     * Appends a label to the window, discarding the oldest entry. Does nothing when the
     * window has length 0.
     *
     * @param newPrediction the label (or predicted class index) to append as the newest entry
     */
    public void addOldLabel(double newPrediction) {
        int numLabels = this.oldLabels.length;
        if (numLabels > 0) {
            // Shift everything one slot towards index 0; arraycopy handles the overlapping ranges.
            System.arraycopy(this.oldLabels, 1, this.oldLabels, 0, numLabels - 1);
            this.oldLabels[numLabels - 1] = newPrediction;
        }
    }

    /**
     * Builds {@link #header} for the extended instances. The header consists of
     * {@code numOldLabels} nominal copies of the class attribute, followed by one attribute
     * for every attribute of {@code dataset}. The class index is shifted by the number of
     * prepended label attributes.
     *
     * @param dataset header of the original (non-extended) instances
     */
    public void initHeader(Instances dataset) {
        int numLabels = this.numOldLabelsOption.getValue();
        Attribute target = dataset.classAttribute();

        List<String> possibleValues = new ArrayList<String>();
        int n = target.numValues();
        for (int i = 0; i < n; i++) {
            possibleValues.add(target.value(i));
        }

        ArrayList<Attribute> attrs = new ArrayList<Attribute>(numLabels + dataset.numAttributes());
        for (int i = 0; i < numLabels; i++) {
            attrs.add(new Attribute(target.name() + "_" + i, possibleValues));
        }
        for (int i = 0; i < dataset.numAttributes(); i++) {
            Attribute attr = dataset.attribute(i);
            // Every source attribute must yield exactly one header attribute, because
            // extendWithOldLabels copies all values of toDoubleArray() positionally and the
            // class index below is a plain offset. Non-nominal attributes (numeric or
            // otherwise) are therefore represented as numeric instead of being dropped.
            if (attr.isNominal()) {
                attrs.add(new Attribute(attr.name(), attr.getAttributeValues()));
            } else {
                attrs.add(new Attribute(attr.name()));
            }
        }
        this.header = new Instances("extended_" + dataset.getRelationName(), attrs, 0);
        this.header.setClassIndex(numLabels + dataset.classIndex());
    }

    /**
     * Returns a copy of {@code instance} whose leading attributes are the current label window
     * (oldest first), followed by the instance's own values. On the first call, this also
     * builds {@link #header} and installs it as the base learner's model context. When the
     * window has length 0, the instance is returned unchanged.
     *
     * @param instance an instance in the original format; its dataset must be set
     * @return the extended instance, attached to {@link #header}
     */
    public Instance extendWithOldLabels(Instance instance) {
        if (this.header == null) {
            initHeader(instance.dataset());
            this.baseLearner.setModelContext(new InstancesHeader(this.header));
        }
        int numLabels = this.oldLabels.length;
        if (numLabels == 0) {
            return instance;
        }
        double[] values = instance.toDoubleArray();
        // The old labels fill the first numLabels slots; the original values follow them.
        double[] extendedValues = Arrays.copyOf(this.oldLabels, numLabels + values.length);
        System.arraycopy(values, 0, extendedValues, numLabels, values.length);
        Instance extendedInstance = new DenseInstance(instance.weight(), extendedValues);
        extendedInstance.setDataset(this.header);
        return extendedInstance;
    }

    /**
     * Returns the base learner's votes for the extended instance. When labels are delayed,
     * the index of the highest vote is pushed into the label window.
     *
     * @param instance an instance in the original format
     * @return the base learner's class votes
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        double[] prediction = this.baseLearner.getVotesForInstance(extendWithOldLabels(instance));
        if (this.labelDelayOption.isSet()) {
            // Use predicted labels to add attributes to instances.
            // NOTE(review): if a caller predicts and then trains on the same instance, the
            // training copy already sees this instance's own prediction in the newest slot,
            // while the prediction copy did not -- confirm this asymmetry is intended.
            addOldLabel(Utils.maxIndex(prediction));
        }
        return prediction;
    }

    /**
     * Reports that this meta learner is not randomizable.
     * NOTE(review): whether the base learner is randomizable is not forwarded.
     */
    @Override
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Returns a copy of the base learner's model measurements, or an empty array when it
     * reports none.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        // Query through the Classifier interface rather than casting to AbstractClassifier,
        // so base learners with other implementations also work.
        Measurement[] modelMeasurements = this.baseLearner.getModelMeasurements();
        if (modelMeasurements == null) {
            return new Measurement[0];
        }
        return modelMeasurements.clone();
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // Not implemented: no model-specific description is produced.
    }

    @Override
    public String toString() {
        return "TemporallyAugmentedClassifier using " + this.numOldLabelsOption.getValue() + " labels\n" + this.baseLearner;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy