![JAR search and dependency download from the Maven repository](/logo.png)
spark.examples.JavaKMeans Maven / Gradle / Ivy
package spark.examples;
import scala.Tuple2;
import spark.api.java.JavaPairRDD;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.Function;
import spark.api.java.function.PairFunction;
import spark.util.Vector;
import java.util.List;
import java.util.Map;
/**
* K-means clustering using Java API.
*/
public class JavaKMeans {
/** Parses numbers split by whitespace to a vector */
static Vector parseVector(String line) {
String[] splits = line.split(" ");
double[] data = new double[splits.length];
int i = 0;
for (String s : splits)
data[i] = Double.parseDouble(splits[i++]);
return new Vector(data);
}
/** Computes the vector to which the input vector is closest using squared distance */
static int closestPoint(Vector p, List centers) {
int bestIndex = 0;
double closest = Double.POSITIVE_INFINITY;
for (int i = 0; i < centers.size(); i++) {
double tempDist = p.squaredDist(centers.get(i));
if (tempDist < closest) {
closest = tempDist;
bestIndex = i;
}
}
return bestIndex;
}
/** Computes the mean across all vectors in the input set of vectors */
static Vector average(List ps) {
int numVectors = ps.size();
Vector out = new Vector(ps.get(0).elements());
// start from i = 1 since we already copied index 0 above
for (int i = 1; i < numVectors; i++) {
out.addInPlace(ps.get(i));
}
return out.divide(numVectors);
}
public static void main(String[] args) throws Exception {
if (args.length < 4) {
System.err.println("Usage: JavaKMeans ");
System.exit(1);
}
JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
String path = args[1];
int K = Integer.parseInt(args[2]);
double convergeDist = Double.parseDouble(args[3]);
JavaRDD data = sc.textFile(path).map(
new Function() {
@Override
public Vector call(String line) throws Exception {
return parseVector(line);
}
}
).cache();
final List centroids = data.takeSample(false, K, 42);
double tempDist;
do {
// allocate each vector to closest centroid
JavaPairRDD closest = data.map(
new PairFunction() {
@Override
public Tuple2 call(Vector vector) throws Exception {
return new Tuple2(
closestPoint(vector, centroids), vector);
}
}
);
// group by cluster id and average the vectors within each cluster to compute centroids
JavaPairRDD> pointsGroup = closest.groupByKey();
Map newCentroids = pointsGroup.mapValues(
new Function, Vector>() {
public Vector call(List ps) throws Exception {
return average(ps);
}
}).collectAsMap();
tempDist = 0.0;
for (int i = 0; i < K; i++) {
tempDist += centroids.get(i).squaredDist(newCentroids.get(i));
}
for (Map.Entry t: newCentroids.entrySet()) {
centroids.set(t.getKey(), t.getValue());
}
System.out.println("Finished iteration (delta = " + tempDist + ")");
} while (tempDist > convergeDist);
System.out.println("Final centers:");
for (Vector c : centroids)
System.out.println(c);
System.exit(0);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy