// All downloads are free. Search and download functionality uses the official Maven repository.

// spark.examples.JavaKMeans (Maven / Gradle / Ivy)

package spark.examples;

import scala.Tuple2;
import spark.api.java.JavaPairRDD;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.Function;
import spark.api.java.function.PairFunction;
import spark.util.Vector;

import java.util.List;
import java.util.Map;

/**
 * K-means clustering using Java API.
 */
public class JavaKMeans {

  /** Parses numbers split by whitespace to a vector */
  static Vector parseVector(String line) {
    String[] splits = line.split(" ");
    double[] data = new double[splits.length];
    int i = 0;
    for (String s : splits)
      data[i] = Double.parseDouble(splits[i++]);
    return new Vector(data);
  }

  /** Computes the vector to which the input vector is closest using squared distance */
  static int closestPoint(Vector p, List centers) {
    int bestIndex = 0;
    double closest = Double.POSITIVE_INFINITY;
    for (int i = 0; i < centers.size(); i++) {
      double tempDist = p.squaredDist(centers.get(i));
      if (tempDist < closest) {
        closest = tempDist;
        bestIndex = i;
      }
    }
    return bestIndex;
  }

  /** Computes the mean across all vectors in the input set of vectors */
  static Vector average(List ps) {
    int numVectors = ps.size();
    Vector out = new Vector(ps.get(0).elements());
    // start from i = 1 since we already copied index 0 above
    for (int i = 1; i < numVectors; i++) {
      out.addInPlace(ps.get(i));
    }
    return out.divide(numVectors);
  }

  public static void main(String[] args) throws Exception {
    if (args.length < 4) {
      System.err.println("Usage: JavaKMeans    ");
      System.exit(1);
    }
    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
      System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
    String path = args[1];
    int K = Integer.parseInt(args[2]);
    double convergeDist = Double.parseDouble(args[3]);

    JavaRDD data = sc.textFile(path).map(
      new Function() {
        @Override
        public Vector call(String line) throws Exception {
          return parseVector(line);
        }
      }
    ).cache();

    final List centroids = data.takeSample(false, K, 42);

    double tempDist;
    do {
      // allocate each vector to closest centroid
      JavaPairRDD closest = data.map(
        new PairFunction() {
          @Override
          public Tuple2 call(Vector vector) throws Exception {
            return new Tuple2(
              closestPoint(vector, centroids), vector);
          }
        }
      );

      // group by cluster id and average the vectors within each cluster to compute centroids
      JavaPairRDD> pointsGroup = closest.groupByKey();
      Map newCentroids = pointsGroup.mapValues(
        new Function, Vector>() {
          public Vector call(List ps) throws Exception {
            return average(ps);
          }
        }).collectAsMap();
      tempDist = 0.0;
      for (int i = 0; i < K; i++) {
        tempDist += centroids.get(i).squaredDist(newCentroids.get(i));
      }
      for (Map.Entry t: newCentroids.entrySet()) {
        centroids.set(t.getKey(), t.getValue());
      }
      System.out.println("Finished iteration (delta = " + tempDist + ")");
    } while (tempDist > convergeDist);

    System.out.println("Final centers:");
    for (Vector c : centroids)
      System.out.println(c);

    System.exit(0);

  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy