All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.lwhite1.tablesaw.api.ml.clustering.Xmeans Maven / Gradle / Ivy

package com.github.lwhite1.tablesaw.api.ml.clustering;

import com.github.lwhite1.tablesaw.api.CategoryColumn;
import com.github.lwhite1.tablesaw.api.FloatColumn;
import com.github.lwhite1.tablesaw.api.NumericColumn;
import com.github.lwhite1.tablesaw.api.Table;
import com.github.lwhite1.tablesaw.util.DoubleArrays;
import smile.clustering.XMeans;

/**
 *
 */
public class Xmeans {

  private final XMeans model;

  private final NumericColumn[] inputColumns;


  public Xmeans(int maxK, NumericColumn... columns) {

    double[][] data = DoubleArrays.to2dArray(columns);
    this.model = new XMeans(data, maxK);
    this.inputColumns = columns;
  }

  public int predict(double[] x) {
    return model.predict(x);
  }

  public double[][] centroids() {
    return model.centroids();
  }

  public double distortion() {
    return model.distortion();
  }

  public int getClusterCount() {
    return model.getNumClusters();
  }

  public int[] getClusterLabels() {
    return model.getClusterLabel();
  }

  public int[] getClusterSizes() {
    return model.getClusterSize();
  }

  public Table labeledCentroids() {
    Table table = Table.create("Centroids");
    CategoryColumn labelColumn = CategoryColumn.create("Cluster");
    table.addColumn(labelColumn);

    for (int i = 0; i < inputColumns.length; i++) {
      FloatColumn centroid = FloatColumn.create(inputColumns[i].name());
      table.addColumn(centroid);
    }

    double[][] centroids = model.centroids();

    for (int i = 0 ; i < centroids.length; i++) {
      labelColumn.addCell(String.valueOf(i));
      double[] values = centroids[i];
      for (int k = 0; k < values.length; k++) {
        table.floatColumn(k + 1).add((float) values[k]);
      }
    }
    return table;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy