All downloads are free. Search and download functionality uses the official Maven repository.

examples.ml.example4.Example4 Maven / Gradle / Ivy

package examples.ml.example4;

import datasets.DenseMatrixSet;
import datasets.VectorDouble;
import datastructs.RowBuilder;
import datastructs.RowType;
import maths.functions.distances.EuclideanVectorCalculator;
import ml.classifiers.ThreadedKNNClassifier;
import parallel.partitioners.MatrixRowPartitionPolicy;
import parallel.partitioners.RangePartitioner;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import ml.classifiers.utils.ClassificationVoter;
import utils.Pair;
import utils.PairBuilder;
import utils.TableDataSetLoader;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;

import static java.util.concurrent.Executors.newFixedThreadPool;


/** Category: Machine Learning
 * ID: Example8
 * Description: Classification with vanilla ParallelKNN algorithm
 * Taken From:
 * Details:
 * TODO
 */
public class Example4 {

    public static Pair, List> createDataSet() throws IOException, IllegalArgumentException {

        // load the data
        Table dataSetTable = TableDataSetLoader.loadDataSet(new File("src/main/resources/datasets/iris_data.csv"));

        List labels = new ArrayList<>();

        Column species  = dataSetTable.column("species");

        for (int i = 0; i < species.size(); i++) {

            String label = (String) species.get(i);

            if(label.equals("Iris-setosa")){

                labels.add(0);
            }
            else if(label.equals("Iris-versicolor")){

                labels.add(1);
            }
            else if(label.equals("Iris-virginica")){

                labels.add(2);
            }
            else{
                throw new IllegalArgumentException("Unknown class");
            }
        }

        Table reducedDataSet = dataSetTable.removeColumns("species").first(dataSetTable.rowCount());
        DenseMatrixSet dataSet = new DenseMatrixSet(RowType.Type.DOUBLE_VECTOR, new RowBuilder());
        dataSet.initializeFrom(reducedDataSet);

        // partition the data set
        List> partitions = RangePartitioner.partition(0, dataSet.m(), 4);

        MatrixRowPartitionPolicy partitionPolicy = new MatrixRowPartitionPolicy(partitions);
        dataSet.setPartitionPolicy(partitionPolicy);

        return PairBuilder.makePair(dataSet, labels);
    }

    public static void main(String[] args) throws IOException, IllegalArgumentException{

        Pair, List> data = Example4.createDataSet();
        ExecutorService executorService = newFixedThreadPool(4);

        System.out.println("Number of rows: "+data.first.m());
        System.out.println("Number of labels: "+data.second.size());


        ThreadedKNNClassifier, EuclideanVectorCalculator,
                ClassificationVoter> classifier = new ThreadedKNNClassifier<>(3, false, executorService);

        classifier.setDistanceCalculator(new EuclideanVectorCalculator());
        classifier.setMajorityVoter(new ClassificationVoter());

        classifier.train(data.first, data.second);
        VectorDouble point = new VectorDouble(5.9,3.0,5.1,1.8);
        Integer classIdx = classifier.predict(point);

        System.out.println("Point "+ point +" has class index "+ classIdx);
        executorService.shutdown();

    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy