All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDAFContextNGrams Maven / Gradle / Ivy

There is a newer version: 4.0.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;

/**
 * Estimates the top-k contextual n-grams in arbitrary sequential data using a heuristic.
 */
@Description(name = "context_ngrams",
    value = "_FUNC_(expr, array, k, pf) estimates the top-k most " +
      "frequent n-grams that fit into the specified context. The second parameter specifies " +
      "a string of words that specify the positions of the n-gram elements, with a null value " +
      "standing in for a 'blank' that must be filled by an n-gram element.",
    extended = "The primary expression must be an array of strings, or an array of arrays of " +
      "strings, such as the return type of the sentences() UDF. The second parameter specifies " +
      "the context -- for example, array(\"i\", \"love\", null) -- which would estimate the top " +
      "'k' words that follow the phrase \"i love\" in the primary expression. The optional " +
      "fourth parameter 'pf' controls the memory used by the heuristic. Larger values will " +
      "yield better accuracy, but use more memory. Example usage:\n" +
      "  SELECT context_ngrams(sentences(lower(review)), array(\"i\", \"love\", null, null), 10)" +
      " FROM movies\n" +
      "would attempt to determine the 10 most common two-word phrases that follow \"i love\" " +
      "in a database of free-form natural language movie reviews.")
public class GenericUDAFContextNGrams implements GenericUDAFResolver {
  static final Log LOG = LogFactory.getLog(GenericUDAFContextNGrams.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 3 && parameters.length != 4) {
      throw new UDFArgumentTypeException(parameters.length-1,
          "Please specify either three or four arguments.");
    }

    // Validate the first parameter, which is the expression to compute over. This should be an
    // array of strings type, or an array of arrays of strings.
    PrimitiveTypeInfo pti;
    if (parameters[0].getCategory() != ObjectInspector.Category.LIST) {
      throw new UDFArgumentTypeException(0,
          "Only list type arguments are accepted but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }
    switch (((ListTypeInfo) parameters[0]).getListElementTypeInfo().getCategory()) {
    case PRIMITIVE:
      // Parameter 1 was an array of primitives, so make sure the primitives are strings.
      pti = (PrimitiveTypeInfo) ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
      break;

    case LIST:
      // Parameter 1 was an array of arrays, so make sure that the inner arrays contain
      // primitive strings.
      ListTypeInfo lti = (ListTypeInfo)
                         ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
      pti = (PrimitiveTypeInfo) lti.getListElementTypeInfo();
      break;

    default:
      throw new UDFArgumentTypeException(0,
          "Only arrays of strings or arrays of arrays of strings are accepted but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }
    if(pti.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
      throw new UDFArgumentTypeException(0,
          "Only array or array> is allowed, but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }

    // Validate the second parameter, which should be an array of strings
    if(parameters[1].getCategory() != ObjectInspector.Category.LIST ||
       ((ListTypeInfo) parameters[1]).getListElementTypeInfo().getCategory() !=
         ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but "
          + parameters[1].getTypeName() + " was passed as parameter 2.");
    }
    if(((PrimitiveTypeInfo) ((ListTypeInfo)parameters[1]).getListElementTypeInfo()).
        getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
      throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but "
          + parameters[1].getTypeName() + " was passed as parameter 2.");
    }

    // Validate the third parameter, which should be an integer to represent 'k'
    if(parameters[2].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(2, "Only integers are accepted but "
            + parameters[2].getTypeName() + " was passed as parameter 3.");
    }
    switch(((PrimitiveTypeInfo) parameters[2]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case TIMESTAMP:
      break;

    default:
      throw new UDFArgumentTypeException(2, "Only integers are accepted but "
            + parameters[2].getTypeName() + " was passed as parameter 3.");
    }

    // If the fourth parameter -- precision factor 'pf' -- has been specified, make sure it's
    // an integer.
    if(parameters.length == 4) {
      if(parameters[3].getCategory() != ObjectInspector.Category.PRIMITIVE) {
        throw new UDFArgumentTypeException(3, "Only integers are accepted but "
            + parameters[3].getTypeName() + " was passed as parameter 4.");
      }
      switch(((PrimitiveTypeInfo) parameters[3]).getPrimitiveCategory()) {
      case BYTE:
      case SHORT:
      case INT:
      case LONG:
      case TIMESTAMP:
        break;

      default:
        throw new UDFArgumentTypeException(3, "Only integers are accepted but "
            + parameters[3].getTypeName() + " was passed as parameter 4.");
      }
    }

    return new GenericUDAFContextNGramEvaluator();
  }

  /**
   * A constant-space heuristic to estimate the top-k contextual n-grams.
   */
  public static class GenericUDAFContextNGramEvaluator extends GenericUDAFEvaluator {
    // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
    private transient ListObjectInspector outerInputOI;
    private transient StandardListObjectInspector innerInputOI;
    private transient ListObjectInspector contextListOI;
    private PrimitiveObjectInspector contextOI;
    private PrimitiveObjectInspector inputOI;
    private transient PrimitiveObjectInspector kOI;
    private transient PrimitiveObjectInspector pOI;

    // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations
    private transient ListObjectInspector loi;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);

      // Init input object inspectors
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        outerInputOI = (ListObjectInspector) parameters[0];
        if(outerInputOI.getListElementObjectInspector().getCategory() ==
            ObjectInspector.Category.LIST) {
          // We're dealing with input that is an array of arrays of strings
          innerInputOI = (StandardListObjectInspector) outerInputOI.getListElementObjectInspector();
          inputOI = (PrimitiveObjectInspector) innerInputOI.getListElementObjectInspector();
        } else {
          // We're dealing with input that is an array of strings
          inputOI = (PrimitiveObjectInspector) outerInputOI.getListElementObjectInspector();
          innerInputOI = null;
        }
        contextListOI = (ListObjectInspector) parameters[1];
        contextOI = (PrimitiveObjectInspector) contextListOI.getListElementObjectInspector();
        kOI = (PrimitiveObjectInspector) parameters[2];
        if(parameters.length == 4) {
          pOI = (PrimitiveObjectInspector) parameters[3];
        } else {
          pOI = null;
        }
      } else {
          // Init the list object inspector for handling partial aggregations
          loi = (ListObjectInspector) parameters[0];
      }

      // Init output object inspectors.
      //
      // The return type for a partial aggregation is still a list of strings.
      //
      // The return type for FINAL and COMPLETE is a full aggregation result, which is
      // an array of structures containing the n-gram and its estimated frequency.
      if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableStringObjectInspector);
      } else {
        // Final return type that goes back to Hive: a list of structs with n-grams and their
        // estimated frequencies.
        ArrayList foi = new ArrayList();
        foi.add(ObjectInspectorFactory.getStandardListObjectInspector(
                  PrimitiveObjectInspectorFactory.writableStringObjectInspector));
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        ArrayList fname = new ArrayList();
        fname.add("ngram");
        fname.add("estfrequency");
        return ObjectInspectorFactory.getStandardListObjectInspector(
                 ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi) );
      }
    }

    @Override
    public void merge(AggregationBuffer agg, Object obj) throws HiveException {
      if(obj == null) {
        return;
      }
      NGramAggBuf myagg = (NGramAggBuf) agg;
      List partial = (List) loi.getList(obj);

      // remove the context words from the end of the list
      int contextSize = Integer.parseInt( partial.get(partial.size()-1).toString() );
      partial.remove(partial.size()-1);
      if(myagg.context.size() > 0)  {
        if(contextSize != myagg.context.size()) {
          throw new HiveException(getClass().getSimpleName() + ": found a mismatch in the" +
              " context string lengths. This is usually caused by passing a non-constant" +
              " expression for the context.");
        }
      } else {
        for(int i = partial.size()-contextSize; i < partial.size(); i++) {
          String word = partial.get(i).toString();
          if(word.equals("")) {
            myagg.context.add( null );
          } else {
            myagg.context.add( word );
          }
        }
        partial.subList(partial.size()-contextSize, partial.size()).clear();
        myagg.nge.merge(partial);
      }
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      NGramAggBuf myagg = (NGramAggBuf) agg;
      ArrayList result = myagg.nge.serialize();

      // push the context on to the end of the serialized n-gram estimation
      for(int i = 0; i < myagg.context.size(); i++) {
        if(myagg.context.get(i) == null) {
          result.add(new Text(""));
        } else {
          result.add(new Text(myagg.context.get(i)));
        }
      }
      result.add(new Text(Integer.toString(myagg.context.size())));

      return result;
    }

    // Finds all contextual n-grams in a sequence of words, and passes the n-grams to the
    // n-gram estimator object
    private void processNgrams(NGramAggBuf agg, ArrayList seq) throws HiveException {
      // generate n-grams wherever the context matches
      assert(agg.context.size() > 0);
      ArrayList ng = new ArrayList();
      for(int i = seq.size() - agg.context.size(); i >= 0; i--) {
        // check if the context matches
        boolean contextMatches = true;
        ng.clear();
        for(int j = 0; j < agg.context.size(); j++) {
          String contextWord = agg.context.get(j);
          if(contextWord == null) {
            ng.add(seq.get(i+j));
          } else {
            if(!contextWord.equals(seq.get(i+j))) {
              contextMatches = false;
              break;
            }
          }
        }

        // add to n-gram estimation only if the context matches
        if(contextMatches) {
          agg.nge.add(ng);
          ng = new ArrayList();
        }
      }
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      assert (parameters.length == 3 || parameters.length == 4);
      if(parameters[0] == null || parameters[1] == null || parameters[2] == null) {
        return;
      }
      NGramAggBuf myagg = (NGramAggBuf) agg;

      // Parse out the context and 'k' if we haven't already done so, and while we're at it,
      // also parse out the precision factor 'pf' if the user has supplied one.
      if(!myagg.nge.isInitialized()) {
        int k = PrimitiveObjectInspectorUtils.getInt(parameters[2], kOI);
        int pf = 0;
        if(k < 1) {
          throw new HiveException(getClass().getSimpleName() + " needs 'k' to be at least 1, "
                                  + "but you supplied " + k);
        }
        if(parameters.length == 4) {
          pf = PrimitiveObjectInspectorUtils.getInt(parameters[3], pOI);
          if(pf < 1) {
            throw new HiveException(getClass().getSimpleName() + " needs 'pf' to be at least 1, "
                + "but you supplied " + pf);
          }
        } else {
          pf = 1; // placeholder; minimum pf value is enforced in NGramEstimator
        }

        // Parse out the context and make sure it isn't empty
        myagg.context.clear();
        List context = (List) contextListOI.getList(parameters[1]);
        int contextNulls = 0;
        for(int i = 0; i < context.size(); i++) {
          String word = PrimitiveObjectInspectorUtils.getString(context.get(i), contextOI);
          if(word == null) {
            contextNulls++;
          }
          myagg.context.add(word);
        }
        if(context.size() == 0) {
          throw new HiveException(getClass().getSimpleName() + " needs a context array " +
            "with at least one element.");
        }
        if(contextNulls == 0) {
          throw new HiveException(getClass().getSimpleName() + " the context array needs to " +
            "contain at least one 'null' value to indicate what should be counted.");
        }

        // Set parameters in the n-gram estimator object
        myagg.nge.initialize(k, pf, contextNulls);
      }

      // get the input expression
      List outer = (List) outerInputOI.getList(parameters[0]);
      if(innerInputOI != null) {
        // we're dealing with an array of arrays of strings
        for(int i = 0; i < outer.size(); i++) {
          List inner = (List) innerInputOI.getList(outer.get(i));
          ArrayList words = new ArrayList();
          for(int j = 0; j < inner.size(); j++) {
            String word = PrimitiveObjectInspectorUtils.getString(inner.get(j), inputOI);
            words.add(word);
          }

          // parse out n-grams, update frequency counts
          processNgrams(myagg, words);
        }
      } else {
        // we're dealing with an array of strings
        ArrayList words = new ArrayList();
        for(int i = 0; i < outer.size(); i++) {
          String word = PrimitiveObjectInspectorUtils.getString(outer.get(i), inputOI);
          words.add(word);
        }

        // parse out n-grams, update frequency counts
        processNgrams(myagg, words);
      }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      NGramAggBuf myagg = (NGramAggBuf) agg;
      return myagg.nge.getNGrams();
    }


    // Aggregation buffer methods.
    static class NGramAggBuf extends AbstractAggregationBuffer {
      ArrayList context;
      NGramEstimator nge;
    };

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      NGramAggBuf result = new NGramAggBuf();
      result.nge = new NGramEstimator();
      result.context = new ArrayList();
      reset(result);
      return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      NGramAggBuf result = (NGramAggBuf) agg;
      result.context.clear();
      result.nge.reset();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy