/*
 * ml.classifiers.KNNClassifier, as published in the jstat artifact
 * (available for Maven / Gradle / Ivy).
 * jstat: Java Library for Statistical Analysis.
 * Listing taken from the newest published version.
 */
package ml.classifiers;
import datastructs.IVector;
import maths.functions.distances.DistanceCalculator;
import ml.classifiers.utils.ClassificationVoter;
import datastructs.I2DDataSet;
import utils.Pair;
import java.util.*;
/**
* KNNClassifier performs classification using the KNN algorithm
*/
public class KNNClassifier< DataType, DataSetType extends I2DDataSet>,
DistanceType extends DistanceCalculator,
VoterType extends ClassificationVoter> {
/**
* Constructor
*
* @param k The cluster numbers
* @param copyDataset Boolean for detecting the copied data set
*/
public KNNClassifier(int k, boolean copyDataset){
this.k = k;
this.copyDataset = copyDataset;
}
/**
* How many neighbors the algorithm is using
* @return The number of neighbours
*/
public int nNeighbors() {
return this.k;
}
/**
* Set the object that calculates the distance between instances in the dataset
* @param distanceCalculator The chosen distance calculator
*/
public void setDistanceCalculator(DistanceType distanceCalculator){
this.distanceCalculator = distanceCalculator;
}
/**
* Set the object that calculates the class
* @param voter Class calculator
*/
public void setMajorityVoter(VoterType voter){
this.majorityVoter = voter;
}
/**
* Train the model using the provided data set
*
* @param dataSet The given data set
* @param labels The given labels
*/
public void train(DataSetType dataSet, List labels){
if(this.copyDataset){
this.dataSet = (DataSetType) dataSet.copy();
}
else{
this.dataSet = dataSet;
this.labels = labels;
}
}
/**
* Predict the class of the given data point
*
* @param A generic point type
* @param point The given point
* @return A point
*/
public Integer predict(PointType point){
if(this.majorityVoter == null){
throw new IllegalStateException(" Majority voter has not been set");
}
if(this.distanceCalculator == null){
throw new IllegalStateException("Distance calculator has not been set");
}
// loop over the items in the data set and compute distances
for (int i = 0; i < this.dataSet.m(); i++) {
this.majorityVoter.addItem(i, this.distanceCalculator.calculate(this.dataSet.getRow(i), point));
}
return this.getTopResult(); //maxEntry.getKey();
}
protected int getTopResult(){
// get the top k results
List> results = this.majorityVoter.getResult(this.k);
this.majorityVoter.clear();
Map idxMap = new HashMap<>();
for(int i=0; i maxEntry = Collections.max(idxMap.entrySet(),
(Map.Entry e1, Map.Entry e2) -> e1.getValue()
.compareTo(e2.getValue()));
return maxEntry.getKey();
}
/**
* Number of neighbors to consider
*/
protected int k;
/**
* flag indicating whether the dataset should be fully copied
*/
protected boolean copyDataset;
/**
* The dataset
*/
protected DataSetType dataSet;
/**
* The labels
*/
protected List labels;
/**
* The distance used
*/
DistanceType distanceCalculator;
/**
* How to get the majority set
*/
VoterType majorityVoter;
}