All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.lazy.IBk Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    IBk.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.lazy;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.neighboursearch.LinearNNSearch;
import weka.core.neighboursearch.NearestNeighbourSearch;

/**
 
 * K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.
 * <p>
 * For more information, see
 * <p>
 * D. Aha, D. Kibler (1991). Instance-based learning algorithms. Machine Learning. 6:37-66.
 * <p>
 * BibTeX:
 * <pre>
 * &#64;article{Aha1991,
 *    author = {D. Aha and D. Kibler},
 *    journal = {Machine Learning},
 *    pages = {37-66},
 *    title = {Instance-based learning algorithms},
 *    volume = {6},
 *    year = {1991}
 * }
 * </pre>
 * <p>
 * Valid options are:
 *
 * <pre> -I
 *  Weight neighbours by the inverse of their distance
 *  (use when k &gt; 1)</pre>
 *
 * <pre> -F
 *  Weight neighbours by 1 - their distance
 *  (use when k &gt; 1)</pre>
 *
 * <pre> -K &lt;number of neighbors&gt;
 *  Number of nearest neighbours (k) used in classification.
 *  (Default = 1)</pre>
 *
 * <pre> -E
 *  Minimise mean squared error rather than mean absolute
 *  error when using -X option with numeric prediction.</pre>
 *
 * <pre> -W &lt;window size&gt;
 *  Maximum number of training instances maintained.
 *  Training instances are dropped FIFO. (Default = no window)</pre>
 *
 * <pre> -X
 *  Select the number of nearest neighbours between 1
 *  and the k value specified using hold-one-out evaluation
 *  on the training data (use when k &gt; 1)</pre>
 *
 * <pre> -A
 *  The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).
 * </pre>
* * * @author Stuart Inglis ([email protected]) * @author Len Trigg ([email protected]) * @author Eibe Frank ([email protected]) * @version $Revision: 14826 $ */ public class IBk extends AbstractClassifier implements OptionHandler, UpdateableClassifier, WeightedInstancesHandler, TechnicalInformationHandler, AdditionalMeasureProducer { /** for serialization. */ static final long serialVersionUID = -3080186098777067172L; /** The training instances used for classification. */ protected Instances m_Train; /** The number of class values (or 1 if predicting numeric). */ protected int m_NumClasses; /** The class attribute type. */ protected int m_ClassType; /** The number of neighbours to use for classification (currently). */ protected int m_kNN; /** * The value of kNN provided by the user. This may differ from * m_kNN if cross-validation is being used. */ protected int m_kNNUpper; /** * Whether the value of k selected by cross validation has * been invalidated by a change in the training instances. */ protected boolean m_kNNValid; /** * The maximum number of training instances allowed. When * this limit is reached, old training instances are removed, * so the training data is "windowed". Set to 0 for unlimited * numbers of instances. */ protected int m_WindowSize; /** Whether the neighbours should be distance-weighted. */ protected int m_DistanceWeighting; /** Whether to select k by cross validation. */ protected boolean m_CrossValidate; /** * Whether to minimise mean squared error rather than mean absolute * error when cross-validating on numeric prediction tasks. */ protected boolean m_MeanSquared; /** Default ZeroR model to use when there are no training instances */ protected ZeroR m_defaultModel; /** no weighting. */ public static final int WEIGHT_NONE = 1; /** weight by 1/distance. */ public static final int WEIGHT_INVERSE = 2; /** weight by 1-distance. */ public static final int WEIGHT_SIMILARITY = 4; /** possible instance weighting methods. 
*/ public static final Tag [] TAGS_WEIGHTING = { new Tag(WEIGHT_NONE, "No distance weighting"), new Tag(WEIGHT_INVERSE, "Weight by 1/distance"), new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance") }; /** for nearest-neighbor search. */ protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch(); /** The number of attributes the contribute to a prediction. */ protected double m_NumAttributesUsed; /** * IBk classifier. Simple instance-based learner that uses the class * of the nearest k training instances for the class of the test * instances. * * @param k the number of nearest neighbors to use for prediction */ public IBk(int k) { init(); setKNN(k); } /** * IB1 classifer. Instance-based learner. Predicts the class of the * single nearest training instance for each test instance. */ public IBk() { init(); } /** * Returns a string describing classifier. * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "K-nearest neighbours classifier. Can " + "select appropriate value of K based on cross-validation. Can also do " + "distance weighting.\n\n" + "For more information, see\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "D. Aha and D. Kibler"); result.setValue(Field.YEAR, "1991"); result.setValue(Field.TITLE, "Instance-based learning algorithms"); result.setValue(Field.JOURNAL, "Machine Learning"); result.setValue(Field.VOLUME, "6"); result.setValue(Field.PAGES, "37-66"); return result; } /** * Returns the tip text for this property. 
* @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String KNNTipText() { return "The number of neighbours to use."; } /** * Set the number of neighbours the learner is to use. * * @param k the number of neighbours. */ public void setKNN(int k) { m_kNN = k; m_kNNUpper = k; m_kNNValid = false; } /** * Gets the number of neighbours the learner will use. * * @return the number of neighbours. */ public int getKNN() { return m_kNN; } /** * Returns the tip text for this property. * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String windowSizeTipText() { return "Gets the maximum number of instances allowed in the training " + "pool. The addition of new instances above this value will result " + "in old instances being removed. A value of 0 signifies no limit " + "to the number of training instances."; } /** * Gets the maximum number of instances allowed in the training * pool. The addition of new instances above this value will result * in old instances being removed. A value of 0 signifies no limit * to the number of training instances. * * @return Value of WindowSize. */ public int getWindowSize() { return m_WindowSize; } /** * Sets the maximum number of instances allowed in the training * pool. The addition of new instances above this value will result * in old instances being removed. A value of 0 signifies no limit * to the number of training instances. * * @param newWindowSize Value to assign to WindowSize. */ public void setWindowSize(int newWindowSize) { m_WindowSize = newWindowSize; } /** * Returns the tip text for this property. * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String distanceWeightingTipText() { return "Gets the distance weighting method used."; } /** * Gets the distance weighting method used. 
Will be one of * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY * * @return the distance weighting method used. */ public SelectedTag getDistanceWeighting() { return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING); } /** * Sets the distance weighting method used. Values other than * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored. * * @param newMethod the distance weighting method to use */ public void setDistanceWeighting(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_WEIGHTING) { m_DistanceWeighting = newMethod.getSelectedTag().getID(); } } /** * Returns the tip text for this property. * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String meanSquaredTipText() { return "Whether the mean squared error is used rather than mean " + "absolute error when doing cross-validation for regression problems."; } /** * Gets whether the mean squared error is used rather than mean * absolute error when doing cross-validation. * * @return true if so. */ public boolean getMeanSquared() { return m_MeanSquared; } /** * Sets whether the mean squared error is used rather than mean * absolute error when doing cross-validation. * * @param newMeanSquared true if so. */ public void setMeanSquared(boolean newMeanSquared) { m_MeanSquared = newMeanSquared; } /** * Returns the tip text for this property. * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String crossValidateTipText() { return "Whether hold-one-out cross-validation will be used to " + "select the best k value between 1 and the value specified as " + "the KNN parameter."; } /** * Gets whether hold-one-out cross-validation will be used * to select the best k value. * * @return true if cross-validation will be used. */ public boolean getCrossValidate() { return m_CrossValidate; } /** * Sets whether hold-one-out cross-validation will be used * to select the best k value. 
* * @param newCrossValidate true if cross-validation should be used. */ public void setCrossValidate(boolean newCrossValidate) { m_CrossValidate = newCrossValidate; } /** * Returns the tip text for this property. * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String nearestNeighbourSearchAlgorithmTipText() { return "The nearest neighbour search algorithm to use " + "(Default: weka.core.neighboursearch.LinearNNSearch)."; } /** * Returns the current nearestNeighbourSearch algorithm in use. * @return the NearestNeighbourSearch algorithm currently in use. */ public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() { return m_NNSearch; } /** * Sets the nearestNeighbourSearch algorithm to be used for finding nearest * neighbour(s). * @param nearestNeighbourSearchAlgorithm - The NearestNeighbourSearch class. */ public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) { m_NNSearch = nearestNeighbourSearchAlgorithm; } /** * Get the number of training instances the classifier is currently using. * * @return the number of training instances the classifier is currently using */ public int getNumTraining() { return m_Train.numInstances(); } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.NUMERIC_CLASS); result.enable(Capability.DATE_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * Generates the classifier. 
* * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_NumClasses = instances.numClasses(); m_ClassType = instances.classAttribute().type(); m_Train = new Instances(instances, 0, instances.numInstances()); // Throw away initial instances until within the specified window size if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) { m_Train = new Instances(m_Train, m_Train.numInstances()-m_WindowSize, m_WindowSize); } m_NumAttributesUsed = 0.0; for (int i = 0; i < m_Train.numAttributes(); i++) { if ((i != m_Train.classIndex()) && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) { m_NumAttributesUsed += 1.0; } } m_NNSearch.setInstances(m_Train); // Invalidate any currently cross-validation selected k m_kNNValid = false; m_defaultModel = new ZeroR(); m_defaultModel.buildClassifier(instances); } /** * Adds the supplied instance to the training set. 
* * @param instance the instance to add * @throws Exception if instance could not be incorporated * successfully */ public void updateClassifier(Instance instance) throws Exception { if (m_Train.equalHeaders(instance.dataset()) == false) { throw new Exception("Incompatible instance types\n" + m_Train.equalHeadersMsg(instance.dataset())); } if (instance.classIsMissing()) { return; } m_Train.add(instance); m_NNSearch.update(instance); m_kNNValid = false; if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) { boolean deletedInstance=false; while (m_Train.numInstances() > m_WindowSize) { m_Train.delete(0); deletedInstance=true; } //rebuild datastructure KDTree currently can't delete if(deletedInstance==true) m_NNSearch.setInstances(m_Train); } } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if an error occurred during the prediction */ public double [] distributionForInstance(Instance instance) throws Exception { if (m_Train.numInstances() == 0) { //throw new Exception("No training instances!"); return m_defaultModel.distributionForInstance(instance); } if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) { m_kNNValid = false; boolean deletedInstance=false; while (m_Train.numInstances() > m_WindowSize) { m_Train.delete(0); } //rebuild datastructure KDTree currently can't delete if(deletedInstance==true) m_NNSearch.setInstances(m_Train); } // Select k by cross validation if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) { crossValidate(); } m_NNSearch.addInstanceInfo(instance); Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN); double [] distances = m_NNSearch.getDistances(); double [] distribution = makeDistribution( neighbours, distances ); return distribution; } /** * Returns an enumeration describing the available options. 
* * @return an enumeration of all the available options. */ public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy