org.apache.mahout.utils.clustering.ClusterDumper

Optional components of Mahout which generally support interaction with third-party systems, formats, APIs, etc.

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.utils.clustering;

import java.io.File;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.apache.commons.io.Charsets;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.clustering.cdbw.CDbwEvaluator;
import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable;
import org.apache.mahout.clustering.evaluation.ClusterEvaluator;
import org.apache.mahout.clustering.evaluation.RepresentativePointsDriver;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.utils.vectors.VectorHelper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

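/**
 * Dumps the contents of the cluster sequence files written by Mahout's
 * clustering jobs in a human-readable form. Output can be plain text, CSV,
 * JSON or GraphML; optionally the points belonging to each cluster, the top
 * terms per cluster (resolved against a dictionary), and cluster quality
 * metrics (ClusterEvaluator and CDbwEvaluator) are included.
 */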
public final class ClusterDumper extends AbstractJob {

  public static final String SAMPLE_POINTS = "samplePoints";
  DistanceMeasure measure;

  public enum OUTPUT_FORMAT {
    TEXT,
    CSV,
    GRAPH_ML,
    JSON,
  }

  public static final String DICTIONARY_TYPE_OPTION = "dictionaryType";
  public static final String DICTIONARY_OPTION = "dictionary";
  public static final String POINTS_DIR_OPTION = "pointsDir";
  public static final String NUM_WORDS_OPTION = "numWords";
  public static final String SUBSTRING_OPTION = "substring";
  public static final String EVALUATE_CLUSTERS = "evaluate";

  public static final String OUTPUT_FORMAT_OPT = "outputFormat";

  private static final Logger log = LoggerFactory.getLogger(ClusterDumper.class);
  private Path seqFileDir;
  private Path pointsDir;
  private long maxPointsPerCluster = Long.MAX_VALUE;
  private String termDictionary;
  private String dictionaryFormat;
  private int subString = Integer.MAX_VALUE;
  private int numTopFeatures = 10;
  private Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints;
  private OUTPUT_FORMAT outputFormat = OUTPUT_FORMAT.TEXT;
  private boolean runEvaluation;

  public ClusterDumper(Path seqFileDir, Path pointsDir) {
    this.seqFileDir = seqFileDir;
    this.pointsDir = pointsDir;
    init();
  }

  public ClusterDumper() {
    setConf(new Configuration());
  }

  public static void main(String[] args) throws Exception {
    new ClusterDumper().run(args);
  }

  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    addOption(OUTPUT_FORMAT_OPT, "of", "The optional output format for the results.  Options: TEXT, CSV, JSON or GRAPH_ML",
        "TEXT");
    addOption(SUBSTRING_OPTION, "b", "The number of chars of the asFormatString() to print");
    addOption(NUM_WORDS_OPTION, "n", "The number of top terms to print");
    addOption(POINTS_DIR_OPTION, "p",
            "The directory containing points sequence files mapping input vectors to their cluster.  "
                    + "If specified, then the program will output the points associated with a cluster");
    addOption(SAMPLE_POINTS, "sp", "Specifies the maximum number of points to include _per_ cluster.  The default "
        + "is to include all points");
    addOption(DICTIONARY_OPTION, "d", "The dictionary file");
    addOption(DICTIONARY_TYPE_OPTION, "dt", "The dictionary file type (text|sequencefile)", "text");
    addOption(buildOption(EVALUATE_CLUSTERS, "e", "Run ClusterEvaluator and CDbwEvaluator over the input.  "
        + "The output will be appended to the rest of the output at the end.", false, false, null));
    addOption(DefaultOptionCreator.distanceMeasureOption().create());

    // output is optional; prints to System.out by default
    if (parseArguments(args, false, true) == null) {
      return -1;
    }

    seqFileDir = getInputPath();
    if (hasOption(POINTS_DIR_OPTION)) {
      pointsDir = new Path(getOption(POINTS_DIR_OPTION));
    }
    outputFile = getOutputFile();
    if (hasOption(SUBSTRING_OPTION)) {
      int sub = Integer.parseInt(getOption(SUBSTRING_OPTION));
      if (sub >= 0) {
        subString = sub;
      }
    }
    termDictionary = getOption(DICTIONARY_OPTION);
    dictionaryFormat = getOption(DICTIONARY_TYPE_OPTION);
    if (hasOption(NUM_WORDS_OPTION)) {
      numTopFeatures = Integer.parseInt(getOption(NUM_WORDS_OPTION));
    }
    if (hasOption(OUTPUT_FORMAT_OPT)) {
      outputFormat = OUTPUT_FORMAT.valueOf(getOption(OUTPUT_FORMAT_OPT));
    }
    if (hasOption(SAMPLE_POINTS)) {
      maxPointsPerCluster = Long.parseLong(getOption(SAMPLE_POINTS));
    } else {
      maxPointsPerCluster = Long.MAX_VALUE;
    }
    runEvaluation = hasOption(EVALUATE_CLUSTERS);
    String distanceMeasureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
    measure = ClassUtils.instantiateAs(distanceMeasureClass, DistanceMeasure.class);

    init();
    printClusters(null);
    return 0;
  }

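  /**
   * Writes all clusters (and, when a points directory was supplied, their
   * member points) using the configured writer and format. If a term
   * dictionary was configured, it is loaded here and overrides the
   * {@code dictionary} argument. With evaluation enabled, quality metrics
   * are appended after the cluster dump.
   */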
  public void printClusters(String[] dictionary) throws Exception {
    Configuration conf = new Configuration();

    if (this.termDictionary != null) {
      if ("text".equals(dictionaryFormat)) {
        dictionary = VectorHelper.loadTermDictionary(new File(this.termDictionary));
      } else if ("sequencefile".equals(dictionaryFormat)) {
        dictionary = VectorHelper.loadTermDictionary(conf, this.termDictionary);
      } else {
        throw new IllegalArgumentException("Invalid dictionary format");
      }
    }

    Writer writer;
    boolean shouldClose;
    if (this.outputFile == null) {
      shouldClose = false;
      writer = new OutputStreamWriter(System.out, Charsets.UTF_8);
    } else {
      shouldClose = true;
      // check the full output location; File.getName() returns only the last
      // path segment, so an "s3n://" prefix could never match against it
      if (outputPath != null && outputPath.toString().startsWith("s3n://")) {
        Path p = outputPath;
        FileSystem fs = FileSystem.get(p.toUri(), conf);
        writer = new OutputStreamWriter(fs.create(p), Charsets.UTF_8);
      } else {
        Files.createParentDirs(outputFile);
        writer = Files.newWriter(this.outputFile, Charsets.UTF_8);
      }
    }
    ClusterWriter clusterWriter = createClusterWriter(writer, dictionary);
    try {
      long numWritten = clusterWriter.write(new SequenceFileDirValueIterable<ClusterWritable>(new Path(seqFileDir,
          "part-*"), PathType.GLOB, conf));

      writer.flush();
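      // Evaluation requires representative points, so run the
      // RepresentativePointsDriver first and point both evaluators at its
      // output through the Configuration.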
      if (runEvaluation) {
        HadoopUtil.delete(conf, new Path("tmp/representative"));
        int numIters = 5;
        RepresentativePointsDriver.main(new String[]{
          "--input", seqFileDir.toString(),
          "--output", "tmp/representative",
          "--clusteredPoints", pointsDir.toString(),
          "--distanceMeasure", measure.getClass().getName(),
          "--maxIter", String.valueOf(numIters)
        });
        conf.set(RepresentativePointsDriver.DISTANCE_MEASURE_KEY, measure.getClass().getName());
        conf.set(RepresentativePointsDriver.STATE_IN_KEY, "tmp/representative/representativePoints-" + numIters);
        ClusterEvaluator ce = new ClusterEvaluator(conf, seqFileDir);
        writer.append("\n");
        writer.append("Inter-Cluster Density: ").append(String.valueOf(ce.interClusterDensity())).append("\n");
        writer.append("Intra-Cluster Density: ").append(String.valueOf(ce.intraClusterDensity())).append("\n");
        CDbwEvaluator cdbw = new CDbwEvaluator(conf, seqFileDir);
        writer.append("CDbw Inter-Cluster Density: ").append(String.valueOf(cdbw.interClusterDensity())).append("\n");
        writer.append("CDbw Intra-Cluster Density: ").append(String.valueOf(cdbw.intraClusterDensity())).append("\n");
        writer.append("CDbw Separation: ").append(String.valueOf(cdbw.separation())).append("\n");
        writer.flush();
      }
      log.info("Wrote {} clusters", numWritten);
    } finally {
      if (shouldClose) {
        Closeables.close(clusterWriter, false);
      } else {
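        // GraphML needs an explicit close so the closing XML elements are
        // written, even when the underlying writer wraps System.out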
        if (clusterWriter instanceof GraphMLClusterWriter) {
          clusterWriter.close();
        }
      }
    }
  }

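  /**
   * Creates the {@link ClusterWriter} implementation matching the configured
   * output format.
   */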
  ClusterWriter createClusterWriter(Writer writer, String[] dictionary) throws IOException {
    ClusterWriter result;

    switch (outputFormat) {
      case TEXT:
        result = new ClusterDumperWriter(writer, clusterIdToPoints, measure, numTopFeatures, dictionary, subString);
        break;
      case CSV:
        result = new CSVClusterWriter(writer, clusterIdToPoints, measure);
        break;
      case GRAPH_ML:
        result = new GraphMLClusterWriter(writer, clusterIdToPoints, measure, numTopFeatures, dictionary, subString);
        break;
      case JSON:
        result = new JsonClusterWriter(writer, clusterIdToPoints, measure, numTopFeatures, dictionary);
        break;
      default:
        throw new IllegalStateException("Unknown output format: " + outputFormat);
    }
    return result;
  }

  /**
   * Convenience function to set the output format during testing.
   */
  public void setOutputFormat(OUTPUT_FORMAT of) {
    outputFormat = of;
  }

  private void init() {
    if (this.pointsDir != null) {
      Configuration conf = new Configuration();
      // read in the points
      clusterIdToPoints = readPoints(this.pointsDir, maxPointsPerCluster, conf);
    } else {
      clusterIdToPoints = Collections.emptyMap();
    }
  }


  public int getSubString() {
    return subString;
  }

  public void setSubString(int subString) {
    this.subString = subString;
  }

  public Map<Integer, List<WeightedPropertyVectorWritable>> getClusterIdToPoints() {
    return clusterIdToPoints;
  }

  public String getTermDictionary() {
    return termDictionary;
  }

  public void setTermDictionary(String termDictionary, String dictionaryType) {
    this.termDictionary = termDictionary;
    this.dictionaryFormat = dictionaryType;
  }

  public void setNumTopFeatures(int num) {
    this.numTopFeatures = num;
  }

  public int getNumTopFeatures() {
    return this.numTopFeatures;
  }

  public long getMaxPointsPerCluster() {
    return maxPointsPerCluster;
  }

  public void setMaxPointsPerCluster(long maxPointsPerCluster) {
    this.maxPointsPerCluster = maxPointsPerCluster;
  }

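  /**
   * Reads the clustered-points sequence files under {@code pointsPathDir} and
   * groups the vectors by cluster id, keeping at most
   * {@code maxPointsPerCluster} points for each cluster.
   */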
  public static Map<Integer, List<WeightedPropertyVectorWritable>> readPoints(Path pointsPathDir,
                                                                              long maxPointsPerCluster,
                                                                              Configuration conf) {
    Map<Integer, List<WeightedPropertyVectorWritable>> result = new TreeMap<>();
    for (Pair<IntWritable, WeightedPropertyVectorWritable> record
        : new SequenceFileDirIterable<IntWritable, WeightedPropertyVectorWritable>(pointsPathDir, PathType.LIST,
            PathFilters.logsCRCFilter(), conf)) {
      // key is the cluster id as an IntWritable, value is the weighted vector;
      // the vector's own name/id doesn't matter here because we only collect
      // the points for printing
      int keyValue = record.getFirst().get();
      List<WeightedPropertyVectorWritable> pointList = result.get(keyValue);
      if (pointList == null) {
        pointList = new ArrayList<>();
        result.put(keyValue, pointList);
      }
      if (pointList.size() < maxPointsPerCluster) {
        pointList.add(record.getSecond());
      }
    }
    return result;
  }
}
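
For reference, this class backs Mahout's clusterdump command-line utility, but it can also be driven programmatically. Below is a minimal usage sketch: the input and output paths are placeholders for the artifacts of a real clustering run, and CosineDistanceMeasure is just one example of a valid distance measure.

// Minimal usage sketch. The paths below are placeholders for the output of a
// real clustering run (e.g. k-means); adjust them to your own job's output.
String[] args = {
    "--input", "output/clusters-10-final",      // directory of cluster part files
    "--pointsDir", "output/clusteredPoints",    // mapping of input vectors to clusters
    "--output", "/tmp/cluster-dump.txt",        // omit to print to System.out
    "--numWords", "20",                         // top terms to print per cluster
    "--outputFormat", "TEXT",                   // TEXT, CSV, JSON or GRAPH_ML
    "--distanceMeasure", "org.apache.mahout.common.distance.CosineDistanceMeasure"
};
new ClusterDumper().run(args);                  // throws Exception; handle or declare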