All downloads are free. The search and download features use the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.NGramEstimator Maven / Gradle / Ivy

There is a newer version: 4.0.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.udf.generic;

import java.util.List;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Collections;
import java.util.Iterator;
import java.util.Comparator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * A generic, re-usable n-gram estimation class that supports partial aggregations.
 * The algorithm is based on the heuristic from the following paper:
 * Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm",
 * J. Machine Learning Research 11 (2010), pp. 849--872.
 *
 * In particular, it is guaranteed that frequencies will be under-counted. With large
 * data and a reasonable precision factor, this undercounting appears to be on the order
 * of 5%.
 */
public class NGramEstimator {
  /* Class private variables */
  private int k;
  private int pf;
  private int n;
  private HashMap, Double> ngrams;


  /**
   * Creates a new n-gram estimator object. The 'n' for n-grams is computed dynamically
   * when data is fed to the object.
   */
  public NGramEstimator() {
    k  = 0;
    pf = 0;
    n  = 0;
    ngrams = new HashMap, Double>();
  }

  /**
   * Returns true if the 'k' and 'pf' parameters have been set.
   */
  public boolean isInitialized() {
    return (k != 0);
  }

  /**
   * Sets the 'k' and 'pf' parameters.
   */
  public void initialize(int pk, int ppf, int pn) throws HiveException {
    assert(pk > 0 && ppf > 0 && pn > 0);
    k = pk;
    pf = ppf;
    n = pn;

    // enforce a minimum precision factor
    if(k * pf < 1000) {
      pf = 1000 / k;
    }
  }

  /**
   * Resets an n-gram estimator object to its initial state.
   */
  public void reset() {
    ngrams.clear();
    n = pf = k = 0;
  }

  /**
   * Returns the final top-k n-grams in a format suitable for returning to Hive.
   */
  public ArrayList getNGrams() throws HiveException {
    trim(true);
    if(ngrams.size() < 1) { // SQL standard - return null for zero elements
      return null;
    }

    // Sort the n-gram list by frequencies in descending order
    ArrayList result = new ArrayList();
    ArrayList, Double>> list = new ArrayList(ngrams.entrySet());
    Collections.sort(list, new Comparator, Double>>() {
      public int compare(Map.Entry, Double> o1,
                         Map.Entry, Double> o2) {
        int result = o2.getValue().compareTo(o1.getValue());
        if (result != 0)
          return result;
        
        ArrayList key1 = o1.getKey();
        ArrayList key2 = o2.getKey();
        for (int i = 0; i < key1.size() && i < key2.size(); i++) {
          result = key1.get(i).compareTo(key2.get(i));
          if (result != 0)
            return result;
        }
        
        return key1.size() - key2.size();
      }
    });

    // Convert the n-gram list to a format suitable for Hive
    for(int i = 0; i < list.size(); i++) {
      ArrayList key = list.get(i).getKey();
      Double val = list.get(i).getValue();

      Object[] curGram = new Object[2];
      ArrayList ng = new ArrayList();
      for(int j = 0; j < key.size(); j++) {
        ng.add(new Text(key.get(j)));
      }
      curGram[0] = ng;
      curGram[1] = new DoubleWritable(val.doubleValue());
      result.add(curGram);
    }

    return result;
  }

  /**
   * Returns the number of n-grams in our buffer.
   */
  public int size() {
    return ngrams.size();
  }

  /**
   * Adds a new n-gram to the estimation.
   *
   * @param ng The n-gram to add to the estimation
   */
  public void add(ArrayList ng) throws HiveException {
    assert(ng != null && ng.size() > 0 && ng.get(0) != null);
    Double curFreq = ngrams.get(ng);
    if(curFreq == null) {
      // new n-gram
      curFreq = new Double(1.0);
    } else {
      // existing n-gram, just increment count
      curFreq++;
    }
    ngrams.put(ng, curFreq);

    // set 'n' if we haven't done so before
    if(n == 0) {
      n = ng.size();
    } else {
      if(n != ng.size()) {
        throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
            + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
            + ng.size() + "'.");
      }
    }

    // Trim down the total number of n-grams if we've exceeded the maximum amount of memory allowed
    //
    // NOTE: Although 'k'*'pf' specifies the size of the estimation buffer, we don't want to keep
    //       performing N.log(N) trim operations each time the maximum hashmap size is exceeded.
    //       To handle this, we *actually* maintain an estimation buffer of size 2*'k'*'pf', and
    //       trim down to 'k'*'pf' whenever the hashmap size exceeds 2*'k'*'pf'. This really has
    //       a significant effect when 'k'*'pf' is very high.
    if(ngrams.size() > k * pf * 2) {
      trim(false);
    }
  }

  /**
   * Trims an n-gram estimation down to either 'pf' * 'k' n-grams, or 'k' n-grams if
   * finalTrim is true.
   */
  private void trim(boolean finalTrim) throws HiveException {
    ArrayList,Double>> list = new ArrayList(ngrams.entrySet());
    Collections.sort(list, new Comparator,Double>>() {
      public int compare(Map.Entry,Double> o1,
                         Map.Entry,Double> o2) {
        return o1.getValue().compareTo(o2.getValue());
      }
    });
    for(int i = 0; i < list.size() - (finalTrim ? k : pf*k); i++) {
      ngrams.remove( list.get(i).getKey() );
    }
  }

  /**
   * Takes a serialized n-gram estimator object created by the serialize() method and merges
   * it with the current n-gram object.
   *
   * @param other A serialized n-gram object created by the serialize() method
   */
  public void merge(List other) throws HiveException {
    if(other == null) {
      return;
    }

    // Get estimation parameters
    int otherK = Integer.parseInt(other.get(0).toString());
    int otherN = Integer.parseInt(other.get(1).toString());
    int otherPF = Integer.parseInt(other.get(2).toString());
    if(k > 0 && k != otherK) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'k'"
          + ", which usually is caused by a non-constant expression. Found '"+k+"' and '"
          + otherK + "'.");
    }
    if(n > 0 && otherN != n) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
          + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
          + otherN + "'.");
    }
    if(pf > 0 && otherPF != pf) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'pf'"
          + ", which usually is caused by a non-constant expression. Found '"+pf+"' and '"
          + otherPF + "'.");
    }
    k = otherK;
    pf = otherPF;
    n = otherN;

    // Merge the other estimation into the current one
    for(int i = 3; i < other.size(); i++) {
      ArrayList key = new ArrayList();
      for(int j = 0; j < n; j++) {
        Text word = other.get(i+j);
        key.add(word.toString());
      }
      i += n;
      double val = Double.parseDouble( other.get(i).toString() );
      Double myval = ngrams.get(key);
      if(myval == null) {
        myval = new Double(val);
      } else {
        myval += val;
      }
      ngrams.put(key, myval);
    }

    trim(false);
  }


  /**
   * In preparation for a Hive merge() call, serializes the current n-gram estimator object into an
   * ArrayList of Text objects. This list is deserialized and merged by the
   * merge method.
   *
   * @return An ArrayList of Hadoop Text objects that represents the current
   * n-gram estimation.
   * @see #merge
   */
  public ArrayList serialize() throws HiveException {
    ArrayList result = new ArrayList();
    result.add(new Text(Integer.toString(k)));
    result.add(new Text(Integer.toString(n)));
    result.add(new Text(Integer.toString(pf)));
    for(Iterator > it = ngrams.keySet().iterator(); it.hasNext(); ) {
      ArrayList mykey = it.next();
      assert(mykey.size() > 0);
      for(int i = 0; i < mykey.size(); i++) {
        result.add(new Text(mykey.get(i)));
      }
      Double myval = ngrams.get(mykey);
      result.add(new Text(myval.toString()));
    }

    return result;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy