All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.utils.clustering.JsonClusterWriter Maven / Gradle / Ivy

Go to download

Optional components of Mahout that generally support interaction with third-party systems, formats, APIs, etc.

The newest version!
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.utils.clustering;

import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dump cluster info to JSON formatted lines. Heavily inspired by
 * ClusterDumperWriter.java and CSVClusterWriter.java
 *
 */
public class JsonClusterWriter extends AbstractClusterWriter {
  private final String[] dictionary;
  private final int numTopFeatures;
  private final ObjectMapper jxn;

  private static final Logger log = LoggerFactory.getLogger(JsonClusterWriter.class);
  private static final Pattern VEC_PATTERN = Pattern.compile("\\{|\\:|\\,|\\}");

  public JsonClusterWriter(Writer writer,
      Map> clusterIdToPoints,
      DistanceMeasure measure, int numTopFeatures, String[] dictionary) {
    super(writer, clusterIdToPoints, measure);
    this.numTopFeatures = numTopFeatures;
    this.dictionary = dictionary;
    jxn = new ObjectMapper();
  }

  /**
   * Generate HashMap with cluster info and write as a single JSON formatted
   * line
   */
  @Override
  public void write(ClusterWritable clusterWritable) throws IOException {
    Map res = new HashMap<>();

    // get top terms
    if (dictionary != null) {
      List topTerms = getTopFeaturesList(clusterWritable.getValue()
          .getCenter(), dictionary, numTopFeatures);
      res.put("top_terms", topTerms);
    } else {
      res.put("top_terms", new ArrayList<>());
    }

    // get human-readable cluster representation
    Cluster cluster = clusterWritable.getValue();
    res.put("cluster_id", cluster.getId());

    if (dictionary != null) {
      Map fmtStr = cluster.asJson(dictionary);
      res.put("cluster", fmtStr);

      // get points
      List points = getPoints(cluster, dictionary);
      res.put("points", points);
    } else {
      res.put("cluster", new HashMap<>());
      res.put("points", new ArrayList<>());
    }

    // write JSON
    Writer writer = getWriter();
    writer.write(jxn.writeValueAsString(res) + "\n");
  }

  /**
   * Create a List of HashMaps containing top terms information
   *
   * @return List
   */
  public List getTopFeaturesList(Vector vector, String[] dictionary,
      int numTerms) {

    List vectorTerms = new ArrayList<>();

    for (Vector.Element elt : vector.nonZeroes()) {
      vectorTerms.add(new TermIndexWeight(elt.index(), elt.get()));
    }

    // Sort results in reverse order (i.e. weight in descending order)
    Collections.sort(vectorTerms, new Comparator() {
      @Override
      public int compare(TermIndexWeight one, TermIndexWeight two) {
        return Double.compare(two.weight, one.weight);
      }
    });

    List topTerms = new ArrayList<>();

    for (int i = 0; i < vectorTerms.size() && i < numTerms; i++) {
      int index = vectorTerms.get(i).index;
      String dictTerm = dictionary[index];
      if (dictTerm == null) {
        log.error("Dictionary entry missing for {}", index);
        continue;
      }
      Map term_entry = new HashMap<>();
      term_entry.put(dictTerm, vectorTerms.get(i).weight);
      topTerms.add(term_entry);
    }

    return topTerms;
  }

  /**
   * Create a List of HashMaps containing Vector point information
   *
   * @return List
   */
  public List getPoints(Cluster cluster, String[] dictionary) {
    List vectorObjs = new ArrayList<>();
    List points = getClusterIdToPoints().get(
        cluster.getId());

    if (points != null) {
      for (WeightedPropertyVectorWritable point : points) {
        Map entry = new HashMap<>();
        Vector theVec = point.getVector();
        if (theVec instanceof NamedVector) {
          entry.put("vector_name", ((NamedVector) theVec).getName());
        } else {
          String vecStr = theVec.asFormatString();
          // do some basic manipulations for display
          vecStr = VEC_PATTERN.matcher(vecStr).replaceAll("_");
          entry.put("vector_name", vecStr);
        }
        entry.put("weight", String.valueOf(point.getWeight()));
        try {
          entry.put("point",
                  AbstractCluster.formatVectorAsJson(point.getVector(), dictionary));
        } catch (IOException e) {
          log.error("IOException:  ", e);
        }
        vectorObjs.add(entry);
      }
    }
    return vectorObjs;
  }

  /**
   * Convenience class for sorting terms
   *
   */
  private static class TermIndexWeight {
    private final int index;
    private final double weight;

    TermIndexWeight(int index, double weight) {
      this.index = index;
      this.weight = weight;
    }
  }

}