All downloads are free. Search and download functionality uses the official Maven repository.

tech.tablesaw.api.ml.clustering.Kmeans Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package tech.tablesaw.api.ml.clustering;

import smile.clustering.KMeans;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import tech.tablesaw.util.DoubleArrays;

/**
 * K-Means clustering over a set of numeric columns.
 *
 * <p>Thin wrapper around Smile's {@link KMeans}: the model is built in the constructor from the
 * supplied columns, and the remaining methods either delegate to that model or present its
 * results as {@link Table}s.
 */
public class Kmeans {

    /** The underlying Smile model, built once in the constructor. */
    private final KMeans kMeans;

    /** Private copy of the column array the model was built from; used to name centroid columns. */
    private final NumericColumn[] inputColumns;

    /**
     * Builds a model with {@code k} clusters from the given columns.
     *
     * @param k       the number of clusters requested from the underlying model
     * @param columns the numeric columns supplying the data to cluster
     */
    public Kmeans(int k, NumericColumn... columns) {
        // Defensive copy: a caller may pass an explicit array and modify it afterwards, which
        // would otherwise change the column names reported by labeledCentroids().
        this.inputColumns = columns.clone();
        this.kMeans = new KMeans(DoubleArrays.to2dArray(this.inputColumns), k);
    }

    /**
     * Builds a model with {@code k} clusters from the given columns, passing
     * {@code maxIterations} through to the underlying model.
     *
     * @param k             the number of clusters requested from the underlying model
     * @param maxIterations the iteration limit handed to the underlying model
     * @param columns       the numeric columns supplying the data to cluster
     */
    public Kmeans(int k, int maxIterations, NumericColumn... columns) {
        // Defensive copy, for the same reason as in the other constructor.
        this.inputColumns = columns.clone();
        this.kMeans = new KMeans(DoubleArrays.to2dArray(this.inputColumns), k, maxIterations);
    }

    /**
     * Returns the cluster label the underlying model assigns to the observation {@code x}.
     *
     * @param x the observation to classify
     * @return the predicted cluster label
     */
    public int predict(double[] x) {
        return kMeans.predict(x);
    }

    /**
     * Returns the centroids as reported by the underlying model, one row per cluster.
     *
     * <p>NOTE(review): the array is returned exactly as the model hands it out; whether it is a
     * copy or the model's internal state is not visible here, so treat it as read-only.
     */
    public double[][] centroids() {
        return kMeans.centroids();
    }

    /** Returns the distortion value reported by the underlying model. */
    public double distortion() {
        return kMeans.distortion();
    }

    /** Returns the number of clusters reported by the underlying model. */
    public int getClusterCount() {
        return kMeans.getNumClusters();
    }

    /** Returns the cluster labels reported by the underlying model. */
    public int[] getClusterLabels() {
        return kMeans.getClusterLabel();
    }

    /** Returns the cluster sizes reported by the underlying model. */
    public int[] getClusterSizes() {
        return kMeans.getClusterSize();
    }

    /**
     * Builds a two-column table ("Label", "Cluster") pairing the string value of each row of
     * {@code labels} with the cluster label the model holds at the same index, sorted ascending
     * by cluster and then by label.
     *
     * @param labels a column whose row {@code i} describes the {@code i}-th cluster label
     * @return a new table named "Clusters"
     * @throws IllegalArgumentException if {@code labels} has fewer rows than the model has
     *                                  cluster labels
     */
    public Table clustered(Column labels) {
        int[] clusters = kMeans.getClusterLabel();

        // Fail fast with a descriptive message instead of an index error part-way through the loop.
        if (labels.size() < clusters.length) {
            throw new IllegalArgumentException(
                    "labels has " + labels.size() + " rows but the model holds "
                            + clusters.length + " cluster labels");
        }

        Table table = Table.create("Clusters");
        CategoryColumn labelColumn = new CategoryColumn("Label");
        IntColumn clusterColumn = new IntColumn("Cluster");
        table.addColumn(labelColumn);
        table.addColumn(clusterColumn);

        for (int i = 0; i < clusters.length; i++) {
            labelColumn.appendCell(labels.getString(i));
            clusterColumn.append(clusters[i]);
        }
        return table.sortAscendingOn("Cluster", "Label");
    }

    /**
     * Builds a table with one row per centroid: a "Cluster" column holding the centroid's index
     * as a string, followed by one float column per input column (named after that input column)
     * holding the centroid's coordinate narrowed to {@code float}.
     *
     * @return a new table named "Centroids"
     */
    public Table labeledCentroids() {
        Table table = Table.create("Centroids");
        CategoryColumn labelColumn = new CategoryColumn("Cluster");
        table.addColumn(labelColumn);

        // Keep direct references to the coordinate columns so each cell can be appended without
        // looking the column up again by its position in the table.
        FloatColumn[] centroidColumns = new FloatColumn[inputColumns.length];
        for (int c = 0; c < inputColumns.length; c++) {
            centroidColumns[c] = new FloatColumn(inputColumns[c].name());
            table.addColumn(centroidColumns[c]);
        }

        double[][] centroids = kMeans.centroids();
        for (int i = 0; i < centroids.length; i++) {
            labelColumn.appendCell(String.valueOf(i));
            double[] values = centroids[i];
            for (int c = 0; c < values.length; c++) {
                centroidColumns[c].append((float) values[c]);
            }
        }
        return table;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy