All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.classifiers.ThreadedKNNClassifier Maven / Gradle / Ivy

The newest version!
package ml.classifiers;

import datastructs.I2DDataSet;
import datastructs.IVector;
import maths.functions.distances.DistanceCalculator;
import parallel.partitioners.IPartitionPolicy;
import parallel.tasks.TaskBase;
import ml.classifiers.utils.ClassificationVoter;
import utils.Pair;
import utils.PairBuilder;

import java.util.*;
import java.util.concurrent.*;

public class ThreadedKNNClassifier>,
                                   DistanceType extends DistanceCalculator,
                                   VoterType extends ClassificationVoter> extends KNNClassifier {

    /**
     * Creates a KNN classifier whose distance computations are submitted to the
     * given executor, one task per partition of the data set (see {@code predict}).
     *
     * @param k forwarded unchanged to the {@code KNNClassifier} constructor
     *          (presumably the number of nearest neighbours to consult -- confirm
     *          against {@code KNNClassifier})
     * @param copyDataset forwarded unchanged to the {@code KNNClassifier} constructor
     *          (presumably whether the classifier keeps its own copy of the
     *          training data set -- confirm against {@code KNNClassifier})
     * @param executorService the executor the per-partition tasks are submitted to;
     *          must not be {@code null}
     * @throws NullPointerException if {@code executorService} is {@code null}
     */
    public ThreadedKNNClassifier(int k, boolean copyDataset, ExecutorService executorService){
        super(k, copyDataset);
        // Fail fast: a null executor would otherwise only surface as an NPE
        // in the middle of predict(), far away from the actual mistake.
        this.executorService = Objects.requireNonNull(executorService, "executorService must not be null");
    }

    /**
     * Predict the class of the given data point. This class blocks until all
     * computations are completed
     */

    @Override
    public  Integer  predict(PointType point){

        if(this.majorityVoter == null){
            throw new IllegalStateException(" Majority voter has not been set");
        }

        if(this.distanceCalculator == null){
            throw new IllegalStateException("Distance calculator has not been set");
        }

        if(this.dataSet.getPartitionPolicy().numPartitions() == 0){
            throw new IllegalStateException("Dataset does not have partitions set");
        }

        if(this.tasks == null){
            // generate tasks as many as the partitions of the dataset
            this.tasks = new ArrayList<>(this.dataSet.getPartitionPolicy().numPartitions());

        }

        CountDownLatch countDownLatch = new CountDownLatch(this.dataSet.getPartitionPolicy().numPartitions());

        // let's create the tasks and add them to the List
        for (int i = 0; i < this.dataSet.getPartitionPolicy().numPartitions(); i++) {

            this.tasks.add(new KNNTask(i, point, this.dataSet, this.distanceCalculator, countDownLatch));
            this.executorService.submit((Callable>>)this.tasks.get(i));
        }

        // the main thread waits here
        try {
            countDownLatch.await();
        }
        catch(InterruptedException e){
        }

        // all tasks have finished let's collect the distances
        for (int t = 0; t < this.tasks.size(); t++) {

            List> taskResult = ((KNNTask)this.tasks.get(t)).getResult();
            this.majorityVoter.addItems(taskResult);
        }

        return this.getTopResult();
    }

    /**
     * The executor that the per-partition tasks created in {@code predict}
     * are submitted to. It is assigned in the constructor.
     */
    private ExecutorService executorService;

    /**
     * Private list that holds the tasks created by {@code predict}, one per
     * partition of the data set. It is allocated lazily on the first call to
     * {@code predict} and is {@code null} before that.
     */
    private List tasks;

    /**
     * The class that represents the task to submit to the executor
     *
     */
    private class KNNTask extends TaskBase>>
    {

        /**
         * Creates a task that computes the distances between {@code point} and
         * the rows of one partition of {@code dataSet}. The task starts with an
         * empty result list.
         *
         * @param taskId the id of the task; {@code run()} uses it to look up the
         *               partition of rows it has to process
         * @param point the point the distances are computed against
         * @param dataSet the data set that provides the partition policy and the rows
         * @param distanceCalculator computes the distance between a row and {@code point}
         * @param countDownLatch the latch this task counts down from {@code run()}
         */
        public KNNTask(int taskId, PointType point, DataSetType dataSet, DistanceType distanceCalculator,   CountDownLatch countDownLatch){

            this.setTaskId(taskId);
            // the diamond lets the compiler take the element type from setResult()
            this.setResult(new ArrayList<>());
            this.point = point;
            this.dataSet = dataSet;
            this.distanceCalculator = distanceCalculator;
            this.countDownLatch = countDownLatch;
        }

        /**
         * Computes the distance between the task's point and every row of the
         * partition identified by the task id, appending one
         * (row index, distance) pair per row to the task result.
         *
         * <p>The task is flagged as finished only when all rows were processed.
         * The latch is counted down in every case, including when the
         * computation throws, so a thread waiting on it is always released.
         */
        @Override
        public void run(){

            try {
                // We do not loop over all rows of the data set but only over
                // the rows attached to this task. Which rows these are is
                // known from the partitioning of the data set.
                IPartitionPolicy partitionPolicy = this.dataSet.getPartitionPolicy();
                List<Integer> rows = partitionPolicy.getParition(this.getTaskId());

                for (Integer row : rows) {

                    this.getResult().add(PairBuilder.makePair(row, this.distanceCalculator.calculate(this.dataSet.getRow(row), this.point)));
                }

                this.setFinished( true );
            }
            finally {
                // Release the waiting thread even on failure; otherwise an
                // exception above would leave it blocked on the latch forever.
                this.countDownLatch.countDown();
            }
        }


        /**
         * The point the task is working on: run() computes the distance
         * between this point and each row of the task's partition.
         */
        PointType point;

        /**
         * The dataset the task is working on. It supplies the partition
         * policy and the rows read in run().
         */
        private DataSetType dataSet;

        /**
         * The distance calculator used in run() to compute the distance
         * between a row of the data set and the point.
         */
        private DistanceType distanceCalculator;

        /**
         * Each task counts down this latch from run(). The latch is handed
         * in through the constructor.
         */
        CountDownLatch countDownLatch;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy