All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDAFnGrams Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;

import com.facebook.presto.hive.$internal.org.apache.commons.logging.Log;
import com.facebook.presto.hive.$internal.org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;

/**
 * Estimates the top-k n-grams in arbitrary sequential data using a heuristic.
 *
 * <p>The first argument is either an {@code array<string>} (one sequence per row) or an
 * {@code array<array<string>>} (several sequences per row). The remaining arguments are the
 * integer n-gram length {@code n}, the number of n-grams to return {@code k}, and an optional
 * integer precision factor {@code pf}. The estimation itself is delegated to
 * {@link NGramEstimator}.
 */
@Description(name = "ngrams",
    value = "_FUNC_(expr, n, k, pf) - Estimates the top-k n-grams in rows that consist of "
            + "sequences of strings, represented as arrays of strings, or arrays of arrays of "
            + "strings. 'pf' is an optional precision factor that controls memory usage.",
    extended = "The parameter 'n' specifies what type of n-grams are being estimated. Unigrams "
             + "are n = 1, and bigrams are n = 2. Generally, n will not be greater than about 5. "
             + "The 'k' parameter specifies how many of the highest-frequency n-grams will be "
             + "returned by the UDAF. The optional precision factor 'pf' specifies how much "
             + "memory to use for estimation; more memory will give more accurate frequency "
             + "counts, but could crash the JVM. The default value is 20, which internally "
             + "maintains 20*k n-grams, but only returns the k highest frequency ones. "
             + "The output is an array of structs with the top-k n-grams. It might be convenient "
             + "to explode() the output of this UDAF.")
public class GenericUDAFnGrams implements GenericUDAFResolver {
  static final Log LOG = LogFactory.getLog(GenericUDAFnGrams.class.getName());

  /**
   * Validates the argument types and returns the n-gram evaluator.
   *
   * @param parameters argument types: the sequence expression, {@code n}, {@code k} and,
   *                   optionally, {@code pf}
   * @return a new {@link GenericUDAFnGramEvaluator}
   * @throws SemanticException a {@link UDFArgumentTypeException} when the argument count is not
   *         3 or 4, or when any argument has an unsupported type
   */
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 3 && parameters.length != 4) {
      throw new UDFArgumentTypeException(parameters.length-1,
          "Please specify either three or four arguments.");
    }

    checkSequenceParameter(parameters);

    // 'n' and 'k' are mandatory; the optional 'pf' is validated only when it was supplied.
    // Arguments are checked in positional order so the first offending one is reported.
    for (int index = 1; index < parameters.length; index++) {
      checkIntegerParameter(parameters, index);
    }

    return new GenericUDAFnGramEvaluator();
  }

  /**
   * Checks that the first argument is {@code array<string>} or {@code array<array<string>>}.
   *
   * @param parameters all argument types; only {@code parameters[0]} is inspected
   * @throws UDFArgumentTypeException if the first argument has any other type
   */
  private static void checkSequenceParameter(TypeInfo[] parameters)
      throws UDFArgumentTypeException {
    if (parameters[0].getCategory() != ObjectInspector.Category.LIST) {
      throw new UDFArgumentTypeException(0,
          "Only list type arguments are accepted but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }

    TypeInfo elementType = ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
    switch (elementType.getCategory()) {
    case PRIMITIVE:
      // array<primitive>: the element itself must be a string (checked below).
      break;

    case LIST:
      // array<array<...>>: descend one level; the inner element must be a string.
      elementType = ((ListTypeInfo) elementType).getListElementTypeInfo();
      break;

    default:
      throw new UDFArgumentTypeException(0,
          "Only arrays of strings or arrays of arrays of strings are accepted but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }

    // Check the category before casting so that deeper nesting (e.g. array<array<array<..>>>)
    // yields a type error rather than a ClassCastException.
    if (elementType.getCategory() != ObjectInspector.Category.PRIMITIVE
        || ((PrimitiveTypeInfo) elementType).getPrimitiveCategory()
            != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
      throw new UDFArgumentTypeException(0,
          "Only array<string> or array<array<string>> is allowed, but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }
  }

  /**
   * Checks that the argument at {@code index} has an integer-like primitive type.
   *
   * <p>TIMESTAMP is accepted alongside the integral types. This is preserved for compatibility;
   * NOTE(review): that looks unintentional, confirm before tightening.
   *
   * @param parameters all argument types
   * @param index zero-based position of the argument to check
   * @throws UDFArgumentTypeException if the argument is not BYTE, SHORT, INT, LONG or TIMESTAMP
   */
  private static void checkIntegerParameter(TypeInfo[] parameters, int index)
      throws UDFArgumentTypeException {
    if (parameters[index].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(index, "Only integers are accepted but "
          + parameters[index].getTypeName() + " was passed as parameter " + (index + 1) + ".");
    }
    switch (((PrimitiveTypeInfo) parameters[index]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case TIMESTAMP:
      return;

    default:
      throw new UDFArgumentTypeException(index, "Only integers are accepted but "
          + parameters[index].getTypeName() + " was passed as parameter " + (index + 1) + ".");
    }
  }

  /**
   * A constant-space heuristic to estimate the top-k n-grams.
   *
   * <p>A partial aggregation is a list of strings. It contains the serialized
   * {@link NGramEstimator} state followed by the value of {@code n} as the last element. The
   * final result is a list of structs with the fields {@code ngram} (list of strings) and
   * {@code estfrequency} (double).
   */
  public static class GenericUDAFnGramEvaluator extends GenericUDAFEvaluator {
    // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
    private transient ListObjectInspector outerInputOI;
    // Inspector for the inner arrays of array<array<string>> input, or null for array<string>.
    // Typed as the ListObjectInspector interface so that list inspectors other than
    // StandardListObjectInspector are accepted; only getList()/getListElementObjectInspector()
    // are needed.
    private transient ListObjectInspector innerInputOI;
    private transient PrimitiveObjectInspector inputOI;
    private transient PrimitiveObjectInspector nOI;
    private transient PrimitiveObjectInspector kOI;
    private transient PrimitiveObjectInspector pOI;

    // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations
    private transient ListObjectInspector loi;

    /**
     * Captures the input object inspectors for {@code m} and returns the output inspector.
     *
     * @param m the aggregation mode
     * @param parameters raw-argument inspectors (PARTIAL1/COMPLETE) or the single partial
     *                   aggregation inspector (PARTIAL2/FINAL)
     * @return a list-of-string inspector for partial results, or a list-of-struct inspector
     *         for the final result
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);

      // Init input object inspectors
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        outerInputOI = (ListObjectInspector) parameters[0];
        ObjectInspector elementOI = outerInputOI.getListElementObjectInspector();
        if (elementOI.getCategory() == ObjectInspector.Category.LIST) {
          // We're dealing with input that is an array of arrays of strings
          innerInputOI = (ListObjectInspector) elementOI;
          inputOI = (PrimitiveObjectInspector) innerInputOI.getListElementObjectInspector();
        } else {
          // We're dealing with input that is an array of strings
          innerInputOI = null;
          inputOI = (PrimitiveObjectInspector) elementOI;
        }
        nOI = (PrimitiveObjectInspector) parameters[1];
        kOI = (PrimitiveObjectInspector) parameters[2];
        pOI = (parameters.length == 4) ? (PrimitiveObjectInspector) parameters[3] : null;
      } else {
        // Init the list object inspector for handling partial aggregations
        loi = (ListObjectInspector) parameters[0];
      }

      // A partial aggregation is still a list of strings.
      if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableStringObjectInspector);
      }

      // FINAL and COMPLETE: a list of structs with n-grams and their estimated frequencies.
      ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
      foi.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableStringObjectInspector));
      foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
      ArrayList<String> fname = new ArrayList<String>();
      fname.add("ngram");
      fname.add("estfrequency");
      return ObjectInspectorFactory.getStandardListObjectInspector(
               ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi));
    }

    /**
     * Merges a partial aggregation produced by {@link #terminatePartial} into {@code agg}.
     *
     * @throws HiveException if the partial was computed with a different {@code n}
     */
    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial == null) {
        return;
      }
      NGramAggBuf myagg = (NGramAggBuf) agg;
      @SuppressWarnings("unchecked")
      List<Text> partialNGrams = (List<Text>) loi.getList(partial);
      if (partialNGrams == null || partialNGrams.isEmpty()) {
        return;
      }
      int last = partialNGrams.size() - 1;
      int n = Integer.parseInt(partialNGrams.get(last).toString());

      // A value of 0 for n indicates that the mapper processed data that does not meet
      // filter criteria, so merge() should be NO-OP.
      if (n == 0) {
        return;
      }

      if(myagg.n > 0 && myagg.n != n) {
        throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
            + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
            + myagg.n + "'.");
      }
      myagg.n = n;
      // Hand the estimator a view without the trailing 'n' instead of removing it in place.
      // The list returned by the inspector belongs to the caller and may be shared or
      // fixed-size, depending on the inspector implementation.
      myagg.nge.merge(partialNGrams.subList(0, last));
    }

    /**
     * Serializes the estimator state and appends {@code n} as the final element.
     *
     * @return the partial aggregation as a list of {@link Text}
     */
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      NGramAggBuf myagg = (NGramAggBuf) agg;
      ArrayList<Text> result = myagg.nge.serialize();
      result.add(new Text(Integer.toString(myagg.n)));
      return result;
    }

    /**
     * Adds every contiguous window of length {@code agg.n} in {@code seq} to the estimator.
     * Sequences shorter than {@code n} contribute nothing. Windows are added from the end of
     * the sequence towards the start, which keeps the original insertion order.
     */
    private void processNgrams(NGramAggBuf agg, ArrayList<String> seq) throws HiveException {
      for (int i = seq.size() - agg.n; i >= 0; i--) {
        agg.nge.add(new ArrayList<String>(seq.subList(i, i + agg.n)));
      }
    }

    /** Converts each element of {@code tokens} to a Java string using {@code inputOI}. */
    private ArrayList<String> toWords(List<?> tokens) {
      ArrayList<String> words = new ArrayList<String>(tokens.size());
      for (Object token : tokens) {
        words.add(PrimitiveObjectInspectorUtils.getString(token, inputOI));
      }
      return words;
    }

    /**
     * Parses and validates {@code n}, {@code k} and the optional {@code pf}, then initializes
     * the estimator in {@code myagg}.
     *
     * @throws HiveException if {@code n}, {@code k} or a supplied {@code pf} is less than 1
     */
    private void initializeEstimator(NGramAggBuf myagg, Object[] parameters)
        throws HiveException {
      int n = PrimitiveObjectInspectorUtils.getInt(parameters[1], nOI);
      int k = PrimitiveObjectInspectorUtils.getInt(parameters[2], kOI);
      if(n < 1) {
        throw new HiveException(getClass().getSimpleName() + " needs 'n' to be at least 1, "
                                + "but you supplied " + n);
      }
      if(k < 1) {
        throw new HiveException(getClass().getSimpleName() + " needs 'k' to be at least 1, "
                                + "but you supplied " + k);
      }
      int pf;
      if(parameters.length == 4) {
        pf = PrimitiveObjectInspectorUtils.getInt(parameters[3], pOI);
        if(pf < 1) {
          throw new HiveException(getClass().getSimpleName() + " needs 'pf' to be at least 1, "
              + "but you supplied " + pf);
        }
      } else {
        pf = 1; // placeholder; minimum pf value is enforced in NGramEstimator
      }

      myagg.n = n;
      myagg.nge.initialize(k, pf, n);
    }

    /**
     * Extracts the n-grams of one input row and adds them to the estimate. Rows where the
     * sequence, {@code n} or {@code k} is NULL are ignored, as are NULL inner arrays.
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      assert (parameters.length == 3 || parameters.length == 4);
      if(parameters[0] == null || parameters[1] == null || parameters[2] == null) {
        return;
      }
      NGramAggBuf myagg = (NGramAggBuf) agg;

      // 'n', 'k' and 'pf' are parsed from the first qualifying row only.
      if(!myagg.nge.isInitialized()) {
        initializeEstimator(myagg, parameters);
      }

      List<?> outer = outerInputOI.getList(parameters[0]);
      if (outer == null) {
        return;
      }
      if(innerInputOI != null) {
        // array<array<string>>: every inner array is an independent sequence.
        for (Object innerArray : outer) {
          List<?> inner = (innerArray == null) ? null : innerInputOI.getList(innerArray);
          if (inner == null) {
            continue; // a NULL inner array contributes no n-grams
          }
          processNgrams(myagg, toWords(inner));
        }
      } else {
        // array<string>: the whole row is one sequence.
        processNgrams(myagg, toWords(outer));
      }
    }

    /** Returns the estimator's current top-k n-grams with their estimated frequencies. */
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      NGramAggBuf myagg = (NGramAggBuf) agg;
      return myagg.nge.getNGrams();
    }

    /** Aggregation buffer: the estimator plus the n-gram length it was initialized with. */
    static class NGramAggBuf extends AbstractAggregationBuffer {
      NGramEstimator nge;
      // n-gram length; reset() sets it to 0, meaning "not set yet"
      int n;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      NGramAggBuf result = new NGramAggBuf();
      result.nge = new NGramEstimator();
      reset(result);
      return result;
    }

    /** Clears the estimator and marks {@code n} as unset. */
    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      NGramAggBuf result = (NGramAggBuf) agg;
      result.nge.reset();
      result.n = 0;
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy