org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansThread Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.streaming.mapreduce;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class StreamingKMeansThread implements Callable> {
private static final Logger log = LoggerFactory.getLogger(StreamingKMeansThread.class);
private static final int NUM_ESTIMATE_POINTS = 1000;
private final Configuration conf;
private final Iterable dataPoints;
public StreamingKMeansThread(Path input, Configuration conf) {
this(StreamingKMeansUtilsMR.getCentroidsFromVectorWritable(
new SequenceFileValueIterable(input, false, conf)), conf);
}
public StreamingKMeansThread(Iterable dataPoints, Configuration conf) {
this.dataPoints = dataPoints;
this.conf = conf;
}
@Override
public Iterable call() {
UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
Iterator dataPointsIterator = dataPoints.iterator();
if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
List estimatePoints = new ArrayList<>(NUM_ESTIMATE_POINTS);
while (dataPointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) {
Centroid centroid = dataPointsIterator.next();
estimatePoints.add(centroid);
}
if (log.isInfoEnabled()) {
log.info("Estimated Points: {}", estimatePoints.size());
}
estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure());
}
StreamingKMeans streamingKMeans = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
// datapointsIterator could be empty if no estimate distance was initially provided
// hence creating the iterator again here for the clustering
if (!dataPointsIterator.hasNext()) {
dataPointsIterator = dataPoints.iterator();
}
while (dataPointsIterator.hasNext()) {
streamingKMeans.cluster(dataPointsIterator.next());
}
streamingKMeans.reindexCentroids();
return streamingKMeans;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy