All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.classifiers.KNNClassifier Maven / Gradle / Ivy

The newest version!
package ml.classifiers;

import datastructs.IVector;
import maths.functions.distances.DistanceCalculator;
import ml.classifiers.utils.ClassificationVoter;
import datastructs.I2DDataSet;
import utils.Pair;

import java.util.*;


/**
 * KNNClassifier performs classification using the KNN algorithm
 */
public class KNNClassifier< DataType, DataSetType extends I2DDataSet>,
                           DistanceType extends DistanceCalculator,
                           VoterType extends ClassificationVoter> {

    /**
     * Constructor
     *
     * @param k The cluster numbers
     * @param copyDataset Boolean for detecting the copied data set
     */
    public KNNClassifier(int k, boolean copyDataset){
        this.k = k;
        this.copyDataset = copyDataset;
    }

    /**
     * How many neighbors the algorithm is using
     * @return The number of neighbours
     */
    public int nNeighbors() {
        return this.k;
    }

    /**
     * Set the object that calculates the distance between instances in the dataset
     * @param distanceCalculator The chosen distance calculator
     */
    public void setDistanceCalculator(DistanceType distanceCalculator){
        this.distanceCalculator = distanceCalculator;
    }

    /**
     * Set the object that calculates the class
     * @param voter Class calculator
     */
    public void setMajorityVoter(VoterType voter){
        this.majorityVoter = voter;
    }

    /**
     * Train the model using the provided data set
     *
     * @param dataSet The given data set
     * @param labels The given labels
     */
    public void train(DataSetType dataSet, List labels){

        if(this.copyDataset){
            this.dataSet = (DataSetType) dataSet.copy();

        }
        else{
            this.dataSet = dataSet;
            this.labels = labels;
        }
    }


    /**
     * Predict the class of the given data point
     *
     * @param  A generic point type
     * @param point The given point
     * @return A point
     */
    public  Integer  predict(PointType point){

        if(this.majorityVoter == null){
            throw new IllegalStateException(" Majority voter has not been set");
        }

        if(this.distanceCalculator == null){
            throw new IllegalStateException("Distance calculator has not been set");
        }

        // loop over the items in the data set and compute distances
        for (int i = 0; i < this.dataSet.m(); i++) {
            this.majorityVoter.addItem(i, this.distanceCalculator.calculate(this.dataSet.getRow(i), point));
        }


        return this.getTopResult(); //maxEntry.getKey();
    }

    protected int getTopResult(){

        // get the top k results
        List> results = this.majorityVoter.getResult(this.k);
        this.majorityVoter.clear();

        Map idxMap = new HashMap<>();

        for(int i=0; i maxEntry = Collections.max(idxMap.entrySet(),
                (Map.Entry e1, Map.Entry e2) -> e1.getValue()
                        .compareTo(e2.getValue()));

        return maxEntry.getKey();
    }


    /**
     * Number of neighbors to consider
     */
    protected int k;

    /**
     * flag indicating whether the dataset should be fully copied
     */
    protected boolean copyDataset;

    /**
     * The dataset
     */
    protected DataSetType dataSet;

    /**
     * The labels
     */
    protected List labels;

    /**
     * The distance used
     */
    DistanceType distanceCalculator;

    /**
     * How to get the majority set
     */
    VoterType majorityVoter;
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy