All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansUtilsMR Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.streaming.mapreduce;

import java.io.IOException;
import java.util.Iterator;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;

public final class StreamingKMeansUtilsMR {

  private StreamingKMeansUtilsMR() {
  }

  /**
   * Instantiates a searcher from a given configuration.
   *
   * <p>The distance measure is read from {@link DefaultOptionCreator#DISTANCE_MEASURE_OPTION} and must name a class
   * with a public no-argument constructor. The searcher class is read from
   * {@link StreamingKMeansDriver#SEARCHER_CLASS_OPTION}. Projection-based searchers are additionally configured with
   * {@link StreamingKMeansDriver#NUM_PROJECTIONS_OPTION} (default 20) and
   * {@link StreamingKMeansDriver#SEARCH_SIZE_OPTION} (default 10). The locality sensitive hash searcher only uses
   * the search size.
   *
   * @param conf the configuration
   * @return the instantiated searcher
   * @throws RuntimeException if the distance measure class is missing or cannot be instantiated
   * @throws IllegalStateException if no searcher class is configured or an unknown searcher class was requested
   */
  public static UpdatableSearcher searcherFromConfiguration(Configuration conf) {
    DistanceMeasure distanceMeasure;
    String distanceMeasureClass = conf.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
    try {
      distanceMeasure = (DistanceMeasure) Class.forName(distanceMeasureClass).getConstructor().newInstance();
    } catch (Exception e) {
      // Boundary wrapper: any reflective failure (including a missing option, i.e. a null name) is reported uniformly.
      throw new RuntimeException("Failed to instantiate distanceMeasure " + distanceMeasureClass, e);
    }

    int numProjections = conf.getInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, 20);
    int searchSize = conf.getInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, 10);

    String searcherClass = conf.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION);
    if (searcherClass == null) {
      // Previously this surfaced as a bare NullPointerException from searcherClass.equals(...).
      throw new IllegalStateException("No searcher class configured under "
          + StreamingKMeansDriver.SEARCHER_CLASS_OPTION);
    }

    if (searcherClass.equals(BruteSearch.class.getName())) {
      return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
          new Class<?>[]{DistanceMeasure.class}, new Object[]{distanceMeasure});
    } else if (searcherClass.equals(FastProjectionSearch.class.getName())
        || searcherClass.equals(ProjectionSearch.class.getName())) {
      return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
          new Class<?>[]{DistanceMeasure.class, int.class, int.class},
          new Object[]{distanceMeasure, numProjections, searchSize});
    } else if (searcherClass.equals(LocalitySensitiveHashSearch.class.getName())) {
      return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
          new Class<?>[]{DistanceMeasure.class, int.class},
          new Object[]{distanceMeasure, searchSize});
    } else {
      throw new IllegalStateException("Unknown class instantiation requested: " + searcherClass);
    }
  }

  /**
   * Returns an Iterable of centroids from an Iterable of VectorWritables by creating a new Centroid containing
   * a RandomAccessSparseVector as a delegate for each VectorWritable.
   *
   * <p>The conversion is lazy. Centroid keys are assigned 0, 1, 2, ... in iteration order. The numbering restarts
   * at 0 on every new iteration of the returned Iterable, so repeated passes produce the same keys.
   *
   * @param inputIterable VectorWritable Iterable to get Centroids from
   * @return the new Centroids
   * @throws NullPointerException if {@code inputIterable} is null, or (during iteration) if it yields a null element
   */
  public static Iterable<Centroid> getCentroidsFromVectorWritable(final Iterable<VectorWritable> inputIterable) {
    Preconditions.checkNotNull(inputIterable);
    return new Iterable<Centroid>() {
      @Override
      public Iterator<Centroid> iterator() {
        // A fresh Function (and therefore a fresh counter) per iterator keeps the numbering stable across passes.
        return Iterables.transform(inputIterable, new Function<VectorWritable, Centroid>() {
          private int numVectors = 0;

          @Override
          public Centroid apply(VectorWritable writable) {
            Preconditions.checkNotNull(writable);
            return new Centroid(numVectors++, new RandomAccessSparseVector(writable.get()), 1);
          }
        }).iterator();
      }
    };
  }

  /**
   * Returns an Iterable of Centroid from an Iterable of Vector by either casting each Vector to Centroid (if the
   * instance extends Centroid) or creating a new Centroid based on that Vector.
   *
   * <p>The conversion is lazy. Newly created Centroids are keyed 0, 1, 2, ... in iteration order. The counter
   * restarts on every new iteration of the returned Iterable. The implicit expectation is that the input does not
   * interleave Centroids and plain Vectors; otherwise the keys of the newly created Centroids may collide with
   * those of existing ones.
   *
   * @param input Iterable of Vectors to cast
   * @return the new Centroids
   * @throws NullPointerException if {@code input} is null, or (during iteration) if it yields a null element
   */
  public static Iterable<Centroid> castVectorsToCentroids(final Iterable<? extends Vector> input) {
    Preconditions.checkNotNull(input);
    return new Iterable<Centroid>() {
      @Override
      public Iterator<Centroid> iterator() {
        // A fresh Function (and therefore a fresh counter) per iterator keeps the numbering stable across passes.
        return Iterables.transform(input, new Function<Vector, Centroid>() {
          private int numVectors = 0;

          @Override
          public Centroid apply(Vector vector) {
            Preconditions.checkNotNull(vector);
            if (vector instanceof Centroid) {
              return (Centroid) vector;
            }
            return new Centroid(numVectors++, vector, 1);
          }
        }).iterator();
      }
    };
  }

  /**
   * Writes centroids to a sequence file, keyed by their 0-based position in {@code centroids}.
   *
   * @param centroids the centroids to write.
   * @param path the path of the output file; its scheme selects the file system that is written to.
   * @param conf the Hadoop configuration used to resolve the file system of {@code path}.
   * @throws java.io.IOException if the file cannot be created or written
   */
  public static void writeCentroidsToSequenceFile(Iterable<? extends Centroid> centroids, Path path,
                                                  Configuration conf) throws IOException {
    // path.getFileSystem honours a fully-qualified path (e.g. file:// vs hdfs://) instead of always using the
    // default file system, which would fail with "Wrong FS" for paths on another scheme.
    FileSystem fs = path.getFileSystem(conf);
    try (SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, path, IntWritable.class,
        CentroidWritable.class)) {
      int i = 0;
      for (Centroid centroid : centroids) {
        writer.append(new IntWritable(i++), new CentroidWritable(centroid));
      }
    }
  }

  /**
   * Writes vectors to a sequence file, keyed by their 0-based position in {@code datapoints}.
   *
   * @param datapoints the vectors to write.
   * @param path the path of the output file; its scheme selects the file system that is written to.
   * @param conf the Hadoop configuration used to resolve the file system of {@code path}.
   * @throws java.io.IOException if the file cannot be created or written
   */
  public static void writeVectorsToSequenceFile(Iterable<? extends Vector> datapoints, Path path, Configuration conf)
      throws IOException {
    // See writeCentroidsToSequenceFile: resolve the file system from the path rather than the default one.
    FileSystem fs = path.getFileSystem(conf);
    try (SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, path, IntWritable.class,
        VectorWritable.class)) {
      int i = 0;
      for (Vector vector : datapoints) {
        writer.append(new IntWritable(i++), new VectorWritable(vector));
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy