All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.utils.clustering.AbstractClusterWriter Maven / Gradle / Ivy

Go to download

Optional components of Mahout which generally support interaction with third party systems, formats, APIs, etc.

There is a newer version: 0.13.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.utils.clustering;

import java.io.IOException;
import java.io.Writer;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Lists;

/**
 * Base class for implementing ClusterWriter
 */
public abstract class AbstractClusterWriter implements ClusterWriter {

  private static final Logger log = LoggerFactory.getLogger(AbstractClusterWriter.class);

  protected final Writer writer;
  protected final Map> clusterIdToPoints;
  protected final DistanceMeasure measure;

  /**
   *
   * @param writer The underlying {@link java.io.Writer} to use
   * @param clusterIdToPoints The map between cluster ids {@link org.apache.mahout.clustering.Cluster#getId()} and the
   *                          points in the cluster
   * @param measure The {@link org.apache.mahout.common.distance.DistanceMeasure} used to calculate the distance.
   *                Some writers may wish to use it for calculating weights for display.  May be null.
   */
  protected AbstractClusterWriter(Writer writer, Map> clusterIdToPoints,
      DistanceMeasure measure) {
    this.writer = writer;
    this.clusterIdToPoints = clusterIdToPoints;
    this.measure = measure;
  }

  protected Writer getWriter() {
    return writer;
  }

  protected Map> getClusterIdToPoints() {
    return clusterIdToPoints;
  }

  public static String getTopFeatures(Vector vector, String[] dictionary, int numTerms) {

    StringBuilder sb = new StringBuilder(100);

    for (Pair item : getTopPairs(vector, dictionary, numTerms)) {
      String term = item.getFirst();
      sb.append("\n\t\t");
      sb.append(StringUtils.rightPad(term, 40));
      sb.append("=>");
      sb.append(StringUtils.leftPad(item.getSecond().toString(), 20));
    }
    return sb.toString();
  }

  public static String getTopTerms(Vector vector, String[] dictionary, int numTerms) {

    StringBuilder sb = new StringBuilder(100);

    for (Pair item : getTopPairs(vector, dictionary, numTerms)) {
      String term = item.getFirst();
      sb.append(term).append('_');
    }
    sb.deleteCharAt(sb.length() - 1);
    return sb.toString();
  }

  @Override
  public long write(Iterable iterable) throws IOException {
    return write(iterable, Long.MAX_VALUE);
  }

  @Override
  public void close() throws IOException {
    writer.close();
  }

  @Override
  public long write(Iterable iterable, long maxDocs) throws IOException {
    long result = 0;
    Iterator iterator = iterable.iterator();
    while (result < maxDocs && iterator.hasNext()) {
      write(iterator.next());
      result++;
    }
    return result;
  }

  private static Collection> getTopPairs(Vector vector, String[] dictionary, int numTerms) {
    List vectorTerms = Lists.newArrayList();

    for (Vector.Element elt : vector.nonZeroes()) {
      vectorTerms.add(new TermIndexWeight(elt.index(), elt.get()));
    }

    // Sort results in reverse order (ie weight in descending order)
    Collections.sort(vectorTerms, new Comparator() {
      @Override
      public int compare(TermIndexWeight one, TermIndexWeight two) {
        return Double.compare(two.weight, one.weight);
      }
    });

    Collection> topTerms = Lists.newLinkedList();

    for (int i = 0; i < vectorTerms.size() && i < numTerms; i++) {
      int index = vectorTerms.get(i).index;
      String dictTerm = dictionary[index];
      if (dictTerm == null) {
        log.error("Dictionary entry missing for {}", index);
        continue;
      }
      topTerms.add(new Pair<>(dictTerm, vectorTerms.get(i).weight));
    }

    return topTerms;
  }

  private static class TermIndexWeight {
    private final int index;
    private final double weight;

    TermIndexWeight(int index, double weight) {
      this.index = index;
      this.weight = weight;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy