de.edux.ml.knn.KnnClassifier Maven / Gradle / Ivy

package de.edux.ml.knn;

import de.edux.api.Classifier;
import java.util.Arrays;
import java.util.PriorityQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * The {@code KnnClassifier} class provides an implementation of the k-Nearest Neighbors algorithm
 * for classification tasks. It stores the training dataset and predicts the label for new data
 * points based on the majority label of its k-nearest neighbors in the feature space. Distance
 * between data points is computed using the Euclidean distance metric. Optionally, predictions can
 * be weighted by the inverse of the distance to give closer neighbors higher influence.
 *
 * <p>Example usage:
 *
 * <pre>{@code
 * int k = 3;  // Specify the number of neighbors to consider
 * KnnClassifier knn = new KnnClassifier(k);
 * knn.train(trainingFeatures, trainingLabels);
 *
 * double[] prediction = knn.predict(inputFeatures);
 * double accuracy = knn.evaluate(testFeatures, testLabels);
 * }</pre>
 *
 * <p>Note: The label arrays should be in one-hot encoding format.
 */
public class KnnClassifier implements Classifier {

  private static final double EPSILON = 1e-10;

  Logger LOG = LoggerFactory.getLogger(KnnClassifier.class);

  private double[][] trainFeatures;
  private double[][] trainLabels;
  private int k;

  /**
   * Initializes a new instance of {@code KnnClassifier} with specified k.
   *
   * @param k an integer value representing the number of neighbors to consider during
   *     classification
   * @throws IllegalArgumentException if k is not a positive integer
   */
  public KnnClassifier(int k) {
    if (k <= 0) {
      throw new IllegalArgumentException("k must be a positive integer");
    }
    this.k = k;
  }

  @Override
  public boolean train(double[][] features, double[][] labels) {
    if (features.length == 0 || features.length != labels.length) {
      return false;
    }
    this.trainFeatures = features;
    this.trainLabels = labels;
    return true;
  }

  @Override
  public double evaluate(double[][] testInputs, double[][] testTargets) {
    LOG.info("Evaluating...");
    int correct = 0;
    for (int i = 0; i < testInputs.length; i++) {
      double[] prediction = predict(testInputs[i]);
      // A prediction counts as correct only if it exactly matches the one-hot target.
      if (Arrays.equals(prediction, testTargets[i])) {
        correct++;
      }
    }
    double accuracy = (double) correct / testInputs.length;
    LOG.info("KNN - Accuracy: " + accuracy * 100 + "%");
    return accuracy;
  }

  @Override
  public double[] predict(double[] feature) {
    // Max-heap ordered by distance: polling removes the farthest candidate,
    // so only the k nearest neighbors remain in the queue.
    PriorityQueue<Neighbor> pq =
        new PriorityQueue<>((a, b) -> Double.compare(b.distance, a.distance));
    for (int i = 0; i < trainFeatures.length; i++) {
      double distance = calculateDistance(trainFeatures[i], feature);
      pq.offer(new Neighbor(distance, trainLabels[i]));
      if (pq.size() > k) {
        pq.poll();
      }
    }

    // Aggregate neighbor labels weighted by inverse distance; EPSILON avoids division by zero.
    double[] aggregatedLabel = new double[trainLabels[0].length];
    double totalWeight = 0;
    for (Neighbor neighbor : pq) {
      double weight = 1 / (neighbor.distance + EPSILON);
      for (int i = 0; i < aggregatedLabel.length; i++) {
        aggregatedLabel[i] += neighbor.label[i] * weight;
      }
      totalWeight += weight;
    }
    for (int i = 0; i < aggregatedLabel.length; i++) {
      aggregatedLabel[i] /= totalWeight;
    }
    return aggregatedLabel;
  }

  private double calculateDistance(double[] a, double[] b) {
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      sum += Math.pow(a[i] - b[i], 2);
    }
    return Math.sqrt(sum);
  }

  private static class Neighbor {
    private double distance;
    private double[] label;

    public Neighbor(double distance, double[] label) {
      this.distance = distance;
      this.label = label;
    }
  }
}
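
For context, here is a minimal sketch of how the class above might be used end to end. It only relies on the constructor, train, predict, and evaluate shown in the source; the package name, class name, and toy data are illustrative assumptions, not part of the edux artifact.

package de.edux.example;

import de.edux.ml.knn.KnnClassifier;

// Illustrative example only; the dataset values below are hypothetical.
public class KnnExample {
  public static void main(String[] args) {
    // Toy dataset: two 2-D clusters, labels one-hot encoded as [class0, class1].
    double[][] trainFeatures = {
      {0.0, 0.1}, {0.2, 0.0}, {0.1, 0.2},  // cluster near the origin -> class 0
      {1.0, 1.1}, {0.9, 1.0}, {1.1, 0.9}   // cluster near (1, 1)     -> class 1
    };
    double[][] trainLabels = {
      {1.0, 0.0}, {1.0, 0.0}, {1.0, 0.0},
      {0.0, 1.0}, {0.0, 1.0}, {0.0, 1.0}
    };

    KnnClassifier knn = new KnnClassifier(3);  // consider the 3 nearest neighbors
    knn.train(trainFeatures, trainLabels);

    // Distance-weighted class scores for an unseen point near the (1, 1) cluster.
    double[] prediction = knn.predict(new double[] {0.95, 1.05});
    System.out.println("Scores: " + java.util.Arrays.toString(prediction));

    // Accuracy against a small hold-out set; targets must also be one-hot.
    double[][] testFeatures = {{0.05, 0.05}, {1.05, 0.95}};
    double[][] testTargets = {{1.0, 0.0}, {0.0, 1.0}};
    double accuracy = knn.evaluate(testFeatures, testTargets);
    System.out.println("Accuracy: " + accuracy);
  }
}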




