All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.mahout.clustering.classify.ClusterClassificationMapper Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.classify;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.iterator.ClusteringPolicy;
import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;

/**
 * Mapper for classifying vectors into clusters.
 */
public class ClusterClassificationMapper extends
    Mapper,VectorWritable,IntWritable,WeightedVectorWritable> {
  
  private double threshold;
  private List clusterModels;
  private ClusterClassifier clusterClassifier;
  private IntWritable clusterId;
  private boolean emitMostLikely;
  
  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    super.setup(context);
    
    Configuration conf = context.getConfiguration();
    String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
    threshold = conf.getFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, 0.0f);
    emitMostLikely = conf.getBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, false);
    
    clusterModels = new ArrayList<>();
    
    if (clustersIn != null && !clustersIn.isEmpty()) {
      Path clustersInPath = new Path(clustersIn);
      clusterModels = populateClusterModels(clustersInPath, conf);
      ClusteringPolicy policy = ClusterClassifier
          .readPolicy(finalClustersPath(clustersInPath));
      clusterClassifier = new ClusterClassifier(clusterModels, policy);
    }
    clusterId = new IntWritable();
  }
  
  /**
   * Mapper which classifies the vectors to respective clusters.
   */
  @Override
  protected void map(WritableComparable key, VectorWritable vw, Context context)
    throws IOException, InterruptedException {
    if (!clusterModels.isEmpty()) {
      // Converting to NamedVectors to preserve the vectorId else its not obvious as to which point
      // belongs to which cluster - fix for MAHOUT-1410
      Class vectorClass = vw.get().getClass();
      Vector vector = vw.get();
      if (!vectorClass.equals(NamedVector.class)) {
        if (key.getClass().equals(Text.class)) {
          vector = new NamedVector(vector, key.toString());
        } else if (key.getClass().equals(IntWritable.class)) {
          vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
        }
      }
      Vector pdfPerCluster = clusterClassifier.classify(vector);
      if (shouldClassify(pdfPerCluster)) {
        if (emitMostLikely) {
          int maxValueIndex = pdfPerCluster.maxValueIndex();
          write(new VectorWritable(vector), context, maxValueIndex, 1.0);
        } else {
          writeAllAboveThreshold(new VectorWritable(vector), context, pdfPerCluster);
        }
      }
    }
  }
  
  private void writeAllAboveThreshold(VectorWritable vw, Context context,
      Vector pdfPerCluster) throws IOException, InterruptedException {
    for (Element pdf : pdfPerCluster.nonZeroes()) {
      if (pdf.get() >= threshold) {
        int clusterIndex = pdf.index();
        write(vw, context, clusterIndex, pdf.get());
      }
    }
  }
  
  private void write(VectorWritable vw, Context context, int clusterIndex, double weight)
    throws IOException, InterruptedException {
    Cluster cluster = clusterModels.get(clusterIndex);
    clusterId.set(cluster.getId());

    DistanceMeasureCluster distanceMeasureCluster = (DistanceMeasureCluster) cluster;
    DistanceMeasure distanceMeasure = distanceMeasureCluster.getMeasure();
    double distance = distanceMeasure.distance(cluster.getCenter(), vw.get());

    Map props = new HashMap<>();
    props.put(new Text("distance"), new Text(Double.toString(distance)));
    context.write(clusterId, new WeightedPropertyVectorWritable(weight, vw.get(), props));
  }
  
  public static List populateClusterModels(Path clusterOutputPath, Configuration conf) throws IOException {
    List clusters = new ArrayList<>();
    FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
    Iterator it = new SequenceFileDirValueIterator(
        clusterFiles[0].getPath(), PathType.LIST, PathFilters.partFilter(),
        null, false, conf);
    while (it.hasNext()) {
      ClusterWritable next = (ClusterWritable) it.next();
      Cluster cluster = next.getValue();
      cluster.configure(conf);
      clusters.add(cluster);
    }
    return clusters;
  }
  
  private boolean shouldClassify(Vector pdfPerCluster) {
    return pdfPerCluster.maxValue() >= threshold;
  }
  
  private static Path finalClustersPath(Path clusterOutputPath) throws IOException {
    FileSystem fileSystem = clusterOutputPath.getFileSystem(new Configuration());
    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
    return clusterFiles[0].getPath();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy