All Downloads are FREE. Search and download functionality uses the official Maven repository.

co.cask.cdap.examples.wordcount.AssociationTable Maven / Gradle / Ivy

/*
 * Copyright © 2014 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package co.cask.cdap.examples.wordcount;

import co.cask.cdap.api.annotation.ReadOnly;
import co.cask.cdap.api.annotation.WriteOnly;
import co.cask.cdap.api.common.Bytes;
import co.cask.cdap.api.dataset.DatasetSpecification;
import co.cask.cdap.api.dataset.lib.AbstractDataset;
import co.cask.cdap.api.dataset.module.EmbeddedDataset;
import co.cask.cdap.api.dataset.table.Get;
import co.cask.cdap.api.dataset.table.Row;
import co.cask.cdap.api.dataset.table.Table;

import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * A dataset that counts how often pairs of words occur together.
 *
 * <p>Counts live in an embedded {@link Table}: each word is a row key, each word it was seen
 * with is a column in that row, and the cell holds the co-occurrence count, maintained through
 * {@link Table#increment}. Every pair is recorded symmetrically, so the pair (a, b) is counted
 * both in row {@code a} / column {@code b} and in row {@code b} / column {@code a}.
 */
public class AssociationTable extends AbstractDataset {

  private final Table table;

  /**
   * Creates the dataset on top of its embedded table.
   *
   * @param spec the dataset specification; its name becomes this dataset's name
   * @param table the embedded table (declared as {@code word_assoc}) holding the counts
   */
  public AssociationTable(DatasetSpecification spec, @EmbeddedDataset("word_assoc") Table table) {
    super(spec.getName(), table);
    this.table = table;
  }

  /**
   * Stores associations between the specified set of words. That is, for every
   * word in the set, an association will be stored for each of the other words
   * in the set.
   *
   * @param words words to store associations between; sets with fewer than two
   *              words are ignored because they contain no pairs
   */
  @WriteOnly
  public void writeWordAssocs(Set<String> words) {
    // For sets of fewer than 2 words, there are no associations
    int n = words.size();
    if (n < 2) {
      return;
    }

    // Every word gets (n-1) increments of 1 (one for each of the other words).
    // The same array is reused for every row since it is only read.
    long[] values = new long[n - 1];
    Arrays.fill(values, 1L);

    // Convert all words to bytes once, up front
    byte[][] wordBytes = new byte[n][];
    int i = 0;
    for (String word : words) {
      wordBytes[i++] = Bytes.toBytes(word);
    }

    // For each word, increment its row in every column except its own:
    // columns = wordBytes[0..j) followed by wordBytes(j..n)
    for (int j = 0; j < n; j++) {
      byte[] row = wordBytes[j];
      byte[][] columns = new byte[n - 1][];
      System.arraycopy(wordBytes, 0, columns, 0, j);
      System.arraycopy(wordBytes, j + 1, columns, j, n - j - 1);
      table.increment(row, columns, values);
    }
  }

  /**
   * Returns the top words associated with the specified word and the number
   * of times the words have appeared together.
   *
   * @param word the word of interest
   * @param limit the number of associations to return, at most; a non-positive
   *              limit yields an empty map
   * @return a map of the top associated words to their co-occurrence count
   */
  @ReadOnly
  public Map<String, Long> readWordAssocs(String word, int limit) {
    // A non-positive limit asks for nothing. Return early so no entry is ever
    // offered to a collector that cannot hold one.
    if (limit <= 0) {
      return new TreeMap<>();
    }

    // Retrieve all columns of the word's row
    Row result = table.get(new Get(word));
    TopKCollector collector = new TopKCollector(limit);
    if (!result.isEmpty()) {
      // Each column key is an associated word; each value is its co-occurrence count
      for (Map.Entry<byte[], byte[]> entry : result.getColumns().entrySet()) {
        collector.add(Bytes.toLong(entry.getValue()), Bytes.toString(entry.getKey()));
      }
    }
    return collector.getTopK();
  }

  /**
   * Returns how many times two words occurred together.
   *
   * @param word1 the first word
   * @param word2 the other word
   * @return how many times word1 and word2 occurred together, or 0 if no count is stored
   */
  @ReadOnly
  public long getAssoc(String word1, String word2) {
    Long val = table.get(new Get(word1, word2)).getLong(word2);
    return val == null ? 0L : val;
  }
}

/**
 * Collects the (word, count) pairs with the highest counts, keeping at most
 * {@code limit} of them.
 *
 * <p>Entries are held in a {@link TreeSet} ordered by count, then by word. When the set is
 * full, a new pair replaces the current minimum only if its count is strictly greater. A
 * non-positive limit collects nothing.
 */
class TopKCollector {

  /** A (count, word) pair ordered by count ascending, ties broken by word. */
  static final class Entry implements Comparable<Entry> {
    final long count;
    final String word;

    Entry(long count, String word) {
      this.count = count;
      this.word = word;
    }

    @Override
    public int compareTo(Entry other) {
      // Long.compare avoids the overflow of subtracting two arbitrary longs
      int byCount = Long.compare(count, other.count);
      return byCount != 0 ? byCount : word.compareTo(other.word);
    }
  }

  final int limit;
  TreeSet<Entry> entries = new TreeSet<>();

  TopKCollector(int limit) {
    this.limit = limit;
  }

  /**
   * Offers a (count, word) pair to the collector.
   *
   * @param count the pair's count
   * @param word the pair's word
   */
  void add(long count, String word) {
    // With no capacity there is nothing to keep (and no minimum to compare against)
    if (limit <= 0) {
      return;
    }
    if (entries.size() < limit) {
      entries.add(new Entry(count, word));
    } else if (entries.first().count < count) {
      entries.pollFirst();
      entries.add(new Entry(count, word));
    }
  }

  /**
   * Returns the collected pairs as a map from word to count. The collector's state is
   * not modified, so this may be called repeatedly.
   *
   * @return a map, sorted by word, of at most {@code limit} top words to their counts
   */
  Map<String, Long> getTopK() {
    Map<String, Long> topK = new TreeMap<>();
    // Highest first, matching the original drain order. entries never exceeds limit.
    for (Entry entry : entries.descendingSet()) {
      topK.put(entry.word, entry.count);
    }
    return topK;
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy