All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf.generic;

import java.util.HashSet;
import java.util.Set;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ReturnObjectInspectorResolver;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BooleanWritable;

import com.esotericsoftware.minlog.Log;

/**
 * GenericUDFIn
 *
 * Example usage:
 * SELECT key FROM src WHERE key IN ("238", "1");
 *
 * From MySQL page on IN(): To comply with the SQL standard, IN returns NULL
 * not only if the expression on the left hand side is NULL, but also if no
 * match is found in the list and one of the expressions in the list is NULL.
 *
 * Also noteworthy: type conversion behavior is different from MySQL. With
 * expr IN expr1, expr2... in MySQL, each exprN is converted into the same
 * type as expr. In the Hive implementation, all expr(N) are converted into
 * a common type, for conversion consistency with other UDFs and to prevent
 * conversions from a big type to a small type (e.g. int to tinyint).
 */
@Description(name = "in",
    value = "test _FUNC_(val1, val2...) - returns true if test equals any valN ")

public class GenericUDFIn extends GenericUDF {

  // Object inspectors for every argument, captured by initialize(). Index 0
  // describes the tested expression; indexes 1..n describe the IN-list values.
  private transient ObjectInspector[] argumentOIs;
  // Lookup set holding a copy of the IN-list argument objects. It is transient
  // to avoid serializing it, and evaluate() builds it on first use (via
  // prepareInSet) while it is still null.
  private transient Set constantInSet;
  // True while the IN(...) list is treated as all-constant; evaluate() then
  // answers with a set lookup instead of comparing against each argument.
  // NOTE(review): presumably cleared during initialization when any list
  // argument is not a constant -- that code is not visible in this excerpt.
  private boolean isInSetConstant = true;

  // Single reusable result object; evaluate() mutates and returns this same
  // instance on every call rather than allocating a new writable.
  private final BooleanWritable bw = new BooleanWritable();

  // Resolves a common type across all arguments and converts values to it so
  // that e.g. 238 and "238" can be compared meaningfully.
  private transient ReturnObjectInspectorResolver conversionHelper;
  // Inspector used on both sides of every comparison in evaluate().
  // NOTE(review): presumably the common type produced by conversionHelper --
  // the assignment is not visible in this excerpt.
  private transient ObjectInspector compareOI;

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments)
      throws UDFArgumentException {
    if (arguments.length < 2) {
      throw new UDFArgumentLengthException(
          "The function IN requires at least two arguments, got "
          + arguments.length);
    }
    argumentOIs = arguments;

    // We want to use the ReturnObjectInspectorResolver because otherwise
    // ObjectInspectorUtils.compare() will return != for two objects that have
    // different object inspectors, e.g. 238 and "238". The ROIR will help convert
    // both values to a common type so that they can be compared reasonably.
    conversionHelper = new GenericUDFUtils.ReturnObjectInspectorResolver(true);

    for (ObjectInspector oi : arguments) {
      if(!conversionHelper.update(oi)) {
        StringBuilder sb = new StringBuilder();
        sb.append("The arguments for IN should be the same type! Types are: {");
        sb.append(arguments[0].getTypeName());
        sb.append(" IN (");
        for(int i=1; i();
    if (compareOI.getCategory().equals(ObjectInspector.Category.PRIMITIVE)) {
      for (int i = 1; i < arguments.length; ++i) {
        constantInSet.add(((PrimitiveObjectInspector) compareOI)
            .getPrimitiveJavaObject(conversionHelper
                .convertIfNecessary(arguments[i].get(), argumentOIs[i])));
      }
    } else {
      for (int i = 1; i < arguments.length; ++i) {
        constantInSet.add(((ConstantObjectInspector) argumentOIs[i]).getWritableConstantValue());
      }
    }
  }

  /**
   * Evaluates {@code arguments[0] IN (arguments[1], ..., arguments[n])}.
   *
   * @param arguments deferred argument values; index 0 is the tested
   *          expression and the remaining entries are the IN-list values
   * @return {@code null} when the tested value is null, or when nothing
   *         matched and the IN list contains a null; otherwise the shared
   *         {@link BooleanWritable}, set to true on a match and false if
   *         there is none
   * @throws HiveException if obtaining an argument value fails
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    bw.set(false);

    // A NULL on the left-hand side always yields NULL.
    if (arguments[0].get() == null) {
      return null;
    }

    if (!isInSetConstant) {
      // Variable IN list: convert both sides to the common type and compare
      // the tested value against each list entry in turn.
      for (int pos = 1; pos < arguments.length; pos++) {
        Object probe = conversionHelper.convertIfNecessary(
            arguments[0].get(), argumentOIs[0]);
        Object candidate = conversionHelper.convertIfNecessary(
            arguments[pos].get(), argumentOIs[pos], false);
        if (ObjectInspectorUtils.compare(probe, compareOI, candidate, compareOI) == 0) {
          bw.set(true);
          return bw;
        }
      }
      // No entry matched: the result is NULL if any list entry is NULL,
      // otherwise false.
      for (int pos = 1; pos < arguments.length; pos++) {
        if (arguments[pos].get() == null) {
          return null;
        }
      }
      return bw;
    }

    // Constant IN list: build the lookup set once, on first use.
    if (constantInSet == null) {
      prepareInSet(arguments);
    }

    // Turn the tested value into the representation used for set lookup,
    // which depends on the category of the common comparison type.
    Object lookupKey;
    switch (compareOI.getCategory()) {
    case PRIMITIVE:
      lookupKey = ((PrimitiveObjectInspector) compareOI).getPrimitiveJavaObject(
          conversionHelper.convertIfNecessary(arguments[0].get(), argumentOIs[0]));
      break;
    case LIST:
      lookupKey = ((ListObjectInspector) compareOI).getList(
          conversionHelper.convertIfNecessary(arguments[0].get(), argumentOIs[0]));
      break;
    case MAP:
      lookupKey = ((MapObjectInspector) compareOI).getMap(
          conversionHelper.convertIfNecessary(arguments[0].get(), argumentOIs[0]));
      break;
    case STRUCT:
      lookupKey = ((StructObjectInspector) compareOI).getStructFieldsDataAsList(
          conversionHelper.convertIfNecessary(arguments[0].get(), argumentOIs[0]));
      break;
    default:
      throw new RuntimeException("Compare of unsupported constant type: "
          + compareOI.getCategory());
    }

    if (constantInSet.contains(lookupKey)) {
      bw.set(true);
      return bw;
    }
    // No match: a NULL among the constants makes the result NULL, not false.
    return constantInSet.contains(null) ? null : bw;
  }

  @Override
  public String getDisplayString(String[] children) {
    assert (children.length >= 2);
    StringBuilder sb = new StringBuilder();

    sb.append("(");
    sb.append(children[0]);
    sb.append(") ");
    sb.append("IN (");
    for(int i=1; i