weka.classifiers.lazy.IBk Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* IBk.java
* Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.lazy;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.neighboursearch.LinearNNSearch;
import weka.core.neighboursearch.NearestNeighbourSearch;
/**
* K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.
*
* For more information, see
*
* D. Aha, D. Kibler (1991). Instance-based learning algorithms. Machine Learning. 6:37-66.
*
*
* BibTeX:
*
* @article{Aha1991,
* author = {D. Aha and D. Kibler},
* journal = {Machine Learning},
* pages = {37-66},
* title = {Instance-based learning algorithms},
* volume = {6},
* year = {1991}
* }
*
*
*
* Valid options are:
*
* -I
* Weight neighbours by the inverse of their distance
* (use when k > 1)
*
* -F
* Weight neighbours by 1 - their distance
* (use when k > 1)
*
* -K <number of neighbors>
* Number of nearest neighbours (k) used in classification.
* (Default = 1)
*
* -E
* Minimise mean squared error rather than mean absolute
* error when using -X option with numeric prediction.
*
* -W <window size>
* Maximum number of training instances maintained.
* Training instances are dropped FIFO. (Default = no window)
*
* -X
* Select the number of nearest neighbours between 1
* and the k value specified using hold-one-out evaluation
* on the training data (use when k > 1)
*
* -A
* The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).
*
*
*
* @author Stuart Inglis ([email protected])
* @author Len Trigg ([email protected])
* @author Eibe Frank ([email protected])
* @version $Revision: 10141 $
*/
public class IBk
extends AbstractClassifier
implements OptionHandler, UpdateableClassifier, WeightedInstancesHandler,
TechnicalInformationHandler, AdditionalMeasureProducer {
/**
 * Serialization version identifier. Changing it makes previously serialized
 * IBk objects incompatible with this class.
 */
static final long serialVersionUID = -3080186098777067172L;
/**
 * The training instances used for classification. buildClassifier() fills it
 * with a copy of the training data minus instances with a missing class,
 * keeping only the most recent m_WindowSize instances when a window is set.
 * It has no initializer, so it is null until a model has been built
 * (unless init() sets it -- NOTE(review): confirm).
 */
protected Instances m_Train;
/** The number of class values (or 1 if predicting numeric). */
protected int m_NumClasses;
/** The class attribute type, as reported by the class attribute's type(). */
protected int m_ClassType;
/**
 * The number of neighbours currently used for classification. It can differ
 * from the user-supplied value (m_kNNUpper) when k is selected by
 * cross-validation.
 */
protected int m_kNN;
/**
 * The value of kNN provided by the user. This may differ from
 * m_kNN if cross-validation is being used.
 */
protected int m_kNNUpper;
/**
 * Whether the value of k selected by cross validation has
 * been invalidated by a change in the training instances.
 * Cleared by setKNN(), buildClassifier(), updateClassifier() and by
 * window trimming in distributionForInstance().
 */
protected boolean m_kNNValid;
/**
 * The maximum number of training instances allowed. When
 * this limit is reached, old training instances are removed,
 * so the training data is "windowed". Set to 0 for unlimited
 * numbers of instances.
 */
protected int m_WindowSize;
/**
 * The distance weighting method in use: the ID of the selected tag in
 * TAGS_WEIGHTING, i.e. one of WEIGHT_NONE, WEIGHT_INVERSE or
 * WEIGHT_SIMILARITY.
 */
protected int m_DistanceWeighting;
/** Whether to select k by cross validation. */
protected boolean m_CrossValidate;
/**
 * Whether to minimise mean squared error rather than mean absolute
 * error when cross-validating on numeric prediction tasks.
 */
protected boolean m_MeanSquared;
/**
 * Default ZeroR model. It is built in buildClassifier() and used by
 * distributionForInstance() when there are no training instances.
 */
protected ZeroR m_defaultModel;
/** no weighting. */
public static final int WEIGHT_NONE = 1;
/** weight by 1/distance. */
public static final int WEIGHT_INVERSE = 2;
/** weight by 1-distance. */
public static final int WEIGHT_SIMILARITY = 4;
/**
 * Possible instance weighting methods. setDistanceWeighting() only accepts a
 * SelectedTag built on this exact array, compared by reference.
 * NOTE(review): this is a public, mutable array; callers can overwrite its
 * elements.
 */
public static final Tag [] TAGS_WEIGHTING = {
new Tag(WEIGHT_NONE, "No distance weighting"),
new Tag(WEIGHT_INVERSE, "Weight by 1/distance"),
new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance")
};
/** The nearest-neighbour search used for prediction (defaults to a LinearNNSearch). */
protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();
/**
 * The number of attributes that contribute to a prediction. buildClassifier()
 * counts them as the non-class attributes that are nominal or numeric.
 */
protected double m_NumAttributesUsed;
/**
 * Creates an instance-based learner that predicts the class of a test
 * instance from the classes of its {@code k} nearest training instances.
 * All other settings are the defaults applied by the no-argument constructor.
 *
 * @param k the number of nearest neighbours to use for prediction
 */
public IBk(int k) {
  this();
  setKNN(k);
}
/**
 * Creates an instance-based learner with the default settings applied by
 * {@code init()}. The option documentation gives k = 1 as the default, in
 * which case each test instance receives the class of its single nearest
 * training instance (IB1).
 */
public IBk() {
  this.init();
}
/**
 * Describes this classifier for the Explorer/Experimenter GUI: a short
 * summary followed by the literature reference.
 *
 * @return a human-readable description of the classifier
 */
public String globalInfo() {
  String reference = getTechnicalInformation().toString();
  String summary = "K-nearest neighbours classifier. Can select appropriate value of K "
    + "based on cross-validation. Can also do distance weighting.";
  return summary + "\n\n" + "For more information, see\n\n" + reference;
}
/**
 * Builds the bibliographic reference (Aha and Kibler, 1991) on which this
 * classifier is based.
 *
 * @return an ARTICLE-type TechnicalInformation describing the reference
 */
public TechnicalInformation getTechnicalInformation() {
  TechnicalInformation info = new TechnicalInformation(Type.ARTICLE);
  info.setValue(Field.AUTHOR, "D. Aha and D. Kibler");
  info.setValue(Field.YEAR, "1991");
  info.setValue(Field.TITLE, "Instance-based learning algorithms");
  info.setValue(Field.JOURNAL, "Machine Learning");
  info.setValue(Field.VOLUME, "6");
  info.setValue(Field.PAGES, "37-66");
  return info;
}
/**
 * Tool-tip for the KNN property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the KNN property
 */
public String KNNTipText() {
  final String tip = "The number of neighbours to use.";
  return tip;
}
/**
 * Sets the number of neighbours to use. The value also becomes the upper
 * bound searched by cross-validation, and any previously cross-validated
 * choice of k is marked stale.
 *
 * @param k the number of neighbours.
 */
public void setKNN(int k) {
  m_kNNValid = false;
  m_kNNUpper = k;
  m_kNN = k;
}
/**
 * Returns the number of neighbours currently in use. After cross-validation
 * this may be smaller than the value passed to setKNN().
 *
 * @return the number of neighbours.
 */
public int getKNN() {
  return this.m_kNN;
}
/**
 * Tool-tip for the windowSize property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the windowSize property
 */
public String windowSizeTipText() {
  return "Gets the maximum number of instances allowed in the training pool. "
    + "The addition of new instances above this value will result in old "
    + "instances being removed. A value of 0 signifies no limit to the "
    + "number of training instances.";
}
/**
 * Returns the maximum size of the training pool. When more instances are
 * added, the oldest are dropped first; 0 means the pool is unbounded.
 *
 * @return the current window size
 */
public int getWindowSize() {
  return this.m_WindowSize;
}
/**
 * Sets the maximum size of the training pool. When more instances are added,
 * the oldest are dropped first; 0 means the pool is unbounded. The new limit
 * is only stored here and is enforced on later training or prediction calls.
 *
 * @param newWindowSize the new window size
 */
public void setWindowSize(int newWindowSize) {
  this.m_WindowSize = newWindowSize;
}
/**
 * Tool-tip for the distanceWeighting property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the distanceWeighting property
 */
public String distanceWeightingTipText() {
  final String tip = "Gets the distance weighting method used.";
  return tip;
}
/**
 * Returns the distance weighting method as a tag selected from
 * {@code TAGS_WEIGHTING}: WEIGHT_NONE, WEIGHT_INVERSE or WEIGHT_SIMILARITY.
 *
 * @return a freshly created SelectedTag for the current weighting method
 */
public SelectedTag getDistanceWeighting() {
  SelectedTag current = new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
  return current;
}
/**
 * Selects the distance weighting method. The tag is accepted only if it was
 * built on {@code TAGS_WEIGHTING} itself (compared by reference). A tag from
 * any other tag array is silently ignored.
 *
 * @param newMethod the distance weighting method to use
 */
public void setDistanceWeighting(SelectedTag newMethod) {
  if (newMethod.getTags() != TAGS_WEIGHTING) {
    return;
  }
  m_DistanceWeighting = newMethod.getSelectedTag().getID();
}
/**
 * Tool-tip for the meanSquared property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the meanSquared property
 */
public String meanSquaredTipText() {
  return "Whether the mean squared error is used rather than mean absolute error "
    + "when doing cross-validation for regression problems.";
}
/**
 * Tells whether cross-validation minimises mean squared error instead of
 * mean absolute error.
 *
 * @return true if mean squared error is used
 */
public boolean getMeanSquared() {
  return this.m_MeanSquared;
}
/**
 * Chooses whether cross-validation minimises mean squared error instead of
 * mean absolute error.
 *
 * @param newMeanSquared true to use mean squared error
 */
public void setMeanSquared(boolean newMeanSquared) {
  this.m_MeanSquared = newMeanSquared;
}
/**
 * Tool-tip for the crossValidate property, shown in the Explorer/Experimenter GUI.
 *
 * @return the tip text for the crossValidate property
 */
public String crossValidateTipText() {
  return "Whether hold-one-out cross-validation will be used to select the best "
    + "k value between 1 and the value specified as the KNN parameter.";
}
/**
 * Tells whether hold-one-out cross-validation is used to pick k.
 *
 * @return true if k is selected by cross-validation
 */
public boolean getCrossValidate() {
  return this.m_CrossValidate;
}
/**
 * Enables or disables hold-one-out cross-validation for choosing k.
 *
 * @param newCrossValidate true if k should be selected by cross-validation
 */
public void setCrossValidate(boolean newCrossValidate) {
  this.m_CrossValidate = newCrossValidate;
}
/**
 * Tool-tip for the nearestNeighbourSearchAlgorithm property, shown in the
 * Explorer/Experimenter GUI.
 *
 * @return the tip text for the nearestNeighbourSearchAlgorithm property
 */
public String nearestNeighbourSearchAlgorithmTipText() {
  return "The nearest neighbour search algorithm to use (Default: "
    + "weka.core.neighboursearch.LinearNNSearch).";
}
/**
 * Returns the search object used to find nearest neighbours. This is the
 * live instance, not a copy.
 *
 * @return the NearestNeighbourSearch currently in use
 */
public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
  return this.m_NNSearch;
}
/**
 * Replaces the search object used to find nearest neighbours. The object is
 * stored as given; it receives the training data on the next
 * buildClassifier() call.
 *
 * @param nearestNeighbourSearchAlgorithm the NearestNeighbourSearch to use
 */
public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) {
  this.m_NNSearch = nearestNeighbourSearchAlgorithm;
}
/**
 * Get the number of training instances the classifier is currently using.
 *
 * @return the number of training instances the classifier is currently
 *         using, or 0 if no training data has been set yet
 */
public int getNumTraining() {
  // m_Train has no field initializer and is assigned in buildClassifier(),
  // so it may still be null here; report an empty pool instead of throwing
  // a NullPointerException.
  return (m_Train == null) ? 0 : m_Train.numInstances();
}
/**
 * Describes the data this classifier accepts: nominal, numeric and date
 * attributes and classes, missing attribute and class values, and training
 * sets of any size, including empty ones.
 *
 * @return the capabilities of this classifier
 */
public Capabilities getCapabilities() {
  Capabilities result = super.getCapabilities();
  result.disableAll();

  Capability[] supported = {
    // attributes
    Capability.NOMINAL_ATTRIBUTES,
    Capability.NUMERIC_ATTRIBUTES,
    Capability.DATE_ATTRIBUTES,
    Capability.MISSING_VALUES,
    // class
    Capability.NOMINAL_CLASS,
    Capability.NUMERIC_CLASS,
    Capability.DATE_CLASS,
    Capability.MISSING_CLASS_VALUES
  };
  for (Capability capability : supported) {
    result.enable(capability);
  }

  // instances: an empty training set is acceptable
  result.setMinimumNumberInstances(0);
  return result;
}
/**
 * Builds the classifier from the given training data.
 * <p>
 * The data is first checked against getCapabilities(). It is then copied, and
 * instances with a missing class value are removed from the copy. If a window
 * size is set, only the most recent m_WindowSize instances form the training
 * pool. The nearest-neighbour search is reinitialised on that pool, any
 * cross-validated choice of k is invalidated, and a ZeroR fallback model is
 * trained on the full cleaned data.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  // Fail fast on data this learner cannot handle.
  getCapabilities().testWithFail(instances);

  // Work on a copy so the caller's data is left untouched.
  Instances cleaned = new Instances(instances);
  cleaned.deleteWithMissingClass();

  m_NumClasses = cleaned.numClasses();
  m_ClassType = cleaned.classAttribute().type();

  // The training pool: everything, or only the newest m_WindowSize instances.
  int available = cleaned.numInstances();
  m_Train = new Instances(cleaned, 0, available);
  if (m_WindowSize > 0 && available > m_WindowSize) {
    m_Train = new Instances(m_Train, available - m_WindowSize, m_WindowSize);
  }

  // Count the non-class attributes that can contribute to a prediction.
  int usable = 0;
  int classIndex = m_Train.classIndex();
  for (int a = 0; a < m_Train.numAttributes(); a++) {
    if (a == classIndex) {
      continue;
    }
    Attribute att = m_Train.attribute(a);
    if (att.isNominal() || att.isNumeric()) {
      usable++;
    }
  }
  m_NumAttributesUsed = usable;

  m_NNSearch.setInstances(m_Train);

  // The training data changed, so any k chosen by cross-validation is stale.
  m_kNNValid = false;

  // Fallback used for prediction when the training pool is empty.
  m_defaultModel = new ZeroR();
  m_defaultModel.buildClassifier(cleaned);
}
/**
 * Adds the supplied instance to the training pool incrementally.
 * <p>
 * Instances with a missing class value are ignored. The instance is appended
 * to the pool and passed to the nearest-neighbour search, and any
 * cross-validated choice of k is marked stale. If the pool now exceeds the
 * window size, the oldest instances are dropped and the search structure is
 * rebuilt.
 *
 * @param instance the instance to add
 * @throws Exception if the instance's header does not match the training
 *         data, or the instance could not be incorporated successfully
 */
public void updateClassifier(Instance instance) throws Exception {
  Instances header = instance.dataset();
  if (!m_Train.equalHeaders(header)) {
    throw new Exception("Incompatible instance types\n" + m_Train.equalHeadersMsg(header));
  }
  if (instance.classIsMissing()) {
    return;
  }

  m_Train.add(instance);
  m_NNSearch.update(instance);
  m_kNNValid = false;

  if (m_WindowSize <= 0 || m_Train.numInstances() <= m_WindowSize) {
    return;
  }
  // Over the limit: evict from the front (oldest first) until within the window.
  do {
    m_Train.delete(0);
  } while (m_Train.numInstances() > m_WindowSize);
  // KDTree currently can't delete, so rebuild the search structure from scratch.
  m_NNSearch.setInstances(m_Train);
}
/**
 * Calculates the class membership probabilities for the given test instance.
 * <p>
 * If the training pool is empty, the ZeroR fallback model's prediction is
 * returned. Otherwise the pool is first trimmed to the window size. It can
 * exceed the window, e.g. when setWindowSize() was given a smaller value
 * after training. Next, k is re-selected via crossValidate() when
 * cross-validation is enabled and the previous choice is stale. Finally the
 * distribution is computed from the k nearest neighbours and their distances.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if an error occurred during the prediction
 */
public double [] distributionForInstance(Instance instance) throws Exception {
  if (m_Train.numInstances() == 0) {
    // No training data: fall back to the ZeroR model from buildClassifier().
    return m_defaultModel.distributionForInstance(instance);
  }

  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    m_kNNValid = false;
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
    }
    // The loop above always removes at least one instance, so the search
    // structure must be rebuilt unconditionally (KDTree currently can't
    // delete); otherwise neighbours would be searched in stale data.
    m_NNSearch.setInstances(m_Train);
  }

  // Select k by cross validation
  if (!m_kNNValid && m_CrossValidate && (m_kNNUpper >= 1)) {
    crossValidate();
  }

  m_NNSearch.addInstanceInfo(instance);
  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
  double [] distances = m_NNSearch.getDistances();
  return makeDistribution(neighbours, distances);
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration