All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDFExtractUnion Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.SettableListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.SettableMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;

import io.prestosql.hive.$internal.com.google.common.annotations.VisibleForTesting;

@Description(
    name = "extract_union",
    value = "_FUNC_(union[, tag])" + " - Recursively explodes unions into structs or simply extracts the given tag.",
    extended = "  > SELECT _FUNC_({0:\"foo\"}).tag_0 FROM src;\n  foo\n"
        + "  > SELECT _FUNC_({0:\"foo\"}).tag_1 FROM src;\n  null\n"
        + "  > SELECT _FUNC_({0:\"foo\"}, 0) FROM src;\n  foo\n"
        + "  > SELECT _FUNC_({0:\"foo\"}, 1) FROM src;\n  null")
/**
 * UDF that either explodes every union found in a (possibly nested) value into a struct with one
 * {@code tag_N} field per union alternative, or extracts a single alternative of a union.
 *
 * <ul>
 * <li>One argument: the argument type is walked recursively and every union is replaced by a
 * struct; the type must contain at least one union.</li>
 * <li>Two arguments (a union and an int constant): the value held by the union is returned when
 * its tag equals the constant, otherwise {@code null}.</li>
 * </ul>
 */
public class GenericUDFExtractUnion extends GenericUDF {

  /** Sentinel for {@link #tag} meaning "explode all tags" (one-argument form). */
  private static final int ALL_TAGS = -1;

  private final ObjectInspectorConverter objectInspectorConverter;
  private final ValueConverter valueConverter;

  /** Tag to extract, or {@link #ALL_TAGS} when the one-argument form is in use. */
  private int tag = ALL_TAGS;
  /** Inspector of the union argument; only set by the two-argument form. */
  private UnionObjectInspector unionOI;
  /** Inspector of the sole argument; only set by the one-argument form. */
  private ObjectInspector sourceOI;

  public GenericUDFExtractUnion() {
    this(new ObjectInspectorConverter(), new ValueConverter());
  }

  @VisibleForTesting
  GenericUDFExtractUnion(ObjectInspectorConverter objectInspectorConverter, ValueConverter valueConverter) {
    this.objectInspectorConverter = objectInspectorConverter;
    this.valueConverter = valueConverter;
  }

  /**
   * Validates the arguments and works out the result type.
   *
   * @param arguments either a single inspector whose type contains a union, or a union inspector
   *        followed by a constant int inspector holding the tag to extract
   * @return the converted inspector (one argument) or the inspector of the selected union
   *         alternative (two arguments)
   * @throws UDFArgumentException if the tag is out of range or the argument combination is not
   *         supported
   * @throws IllegalArgumentException if the single argument's type contains no union
   */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (arguments.length == 1) {
      // Drop anything left over from an earlier two-argument initialisation so that evaluate()
      // cannot take the single-tag path with stale state.
      tag = ALL_TAGS;
      unionOI = null;
      sourceOI = arguments[0];
      return objectInspectorConverter.convert(sourceOI);
    }

    if (arguments.length == 2 && (arguments[0] instanceof UnionObjectInspector)
        && (arguments[1] instanceof WritableConstantIntObjectInspector)) {
      int requestedTag = ((WritableConstantIntObjectInspector) arguments[1]).getWritableConstantValue().get();
      UnionObjectInspector candidateOI = (UnionObjectInspector) arguments[0];
      List<ObjectInspector> fieldOIs = candidateOI.getObjectInspectors();
      if (requestedTag < 0 || requestedTag >= fieldOIs.size()) {
        throw new UDFArgumentException(
            "int constant must be a valid union tag for " + candidateOI.getTypeName() + ". Expected 0-"
                + (fieldOIs.size() - 1) + " got: " + requestedTag);
      }
      // Only commit the state once the tag is known to be valid.
      tag = requestedTag;
      unionOI = candidateOI;
      return fieldOIs.get(requestedTag);
    }

    String argumentTypes = "nothing";
    if (arguments.length > 0) {
      List<String> typeNames = new ArrayList<>(arguments.length);
      for (ObjectInspector oi : arguments) {
        typeNames.add(oi.getTypeName());
      }
      argumentTypes = typeNames.toString();
    }
    throw new UDFArgumentException(
        "Unsupported arguments. Expected a type containing a union or a union and an int constant, got: "
            + argumentTypes);
  }

  /**
   * Evaluates the function for one row.
   *
   * @param arguments the deferred arguments; only the first one is read here because the tag was
   *        captured as a constant in {@link #initialize(ObjectInspector[])}
   * @return the exploded value (one-argument form), the union's value when its tag matches the
   *         requested tag, or {@code null} when the tag does not match
   * @throws HiveException if the deferred argument cannot be obtained
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    Object value = arguments[0].get();
    if (tag == ALL_TAGS) {
      return valueConverter.convert(value, sourceOI);
    } else if (tag == unionOI.getTag(value)) {
      return unionOI.getField(value);
    } else {
      return null;
    }
  }

  /**
   * Renders the call as {@code extract_union(child0,child1,...)}.
   *
   * @param children display strings of the arguments
   * @return the display string
   */
  @Override
  public String getDisplayString(String[] children) {
    StringBuilder sb = new StringBuilder();
    sb.append("extract_union(");
    for (int i = 0; i < children.length; i++) {
      if (i > 0) {
        sb.append(',');
      }
      sb.append(children[i]);
    }
    sb.append(')');
    return sb.toString();
  }

  /**
   * Rewrites an object inspector so that every union becomes a standard struct with one
   * {@code tag_N} field per alternative. Lists, maps and structs are rebuilt as their standard
   * counterparts around the converted children; primitives are returned unchanged.
   */
  static class ObjectInspectorConverter {

    private static final String TAG_FIELD_PREFIX = "tag_";

    /**
     * Converts the given inspector.
     *
     * @param inspector the inspector to convert
     * @return the converted inspector
     * @throws IllegalArgumentException if no union occurs anywhere in the inspected type
     */
    ObjectInspector convert(ObjectInspector inspector) {
      // The flag is threaded through the recursion and flipped by the UNION case.
      AtomicBoolean foundUnion = new AtomicBoolean(false);
      ObjectInspector result = convert(inspector, foundUnion);
      if (!foundUnion.get()) {
        throw new IllegalArgumentException("No unions found in " + inspector.getTypeName());
      }
      return result;
    }

    /** Dispatches on the inspector's category, recording in {@code foundUnion} whether a union was seen. */
    private ObjectInspector convert(ObjectInspector inspector, AtomicBoolean foundUnion) {
      Category category = inspector.getCategory();
      switch (category) {
      case PRIMITIVE:
        return inspector;
      case LIST:
        return convertList(inspector, foundUnion);
      case MAP:
        return convertMap(inspector, foundUnion);
      case STRUCT:
        return convertStruct(inspector, foundUnion);
      case UNION:
        foundUnion.set(true);
        return convertUnion(inspector, foundUnion);
      default:
        throw new IllegalStateException("Unknown category: " + category);
      }
    }

    /** Builds a standard list inspector around the converted element inspector. */
    private ObjectInspector convertList(ObjectInspector inspector, AtomicBoolean foundUnion) {
      ListObjectInspector listOI = (ListObjectInspector) inspector;
      ObjectInspector elementOI = convert(listOI.getListElementObjectInspector(), foundUnion);
      return ObjectInspectorFactory.getStandardListObjectInspector(elementOI);
    }

    /** Builds a standard map inspector around the converted key and value inspectors. */
    private ObjectInspector convertMap(ObjectInspector inspector, AtomicBoolean foundUnion) {
      MapObjectInspector mapOI = (MapObjectInspector) inspector;
      ObjectInspector keyOI = convert(mapOI.getMapKeyObjectInspector(), foundUnion);
      ObjectInspector valueOI = convert(mapOI.getMapValueObjectInspector(), foundUnion);
      return ObjectInspectorFactory.getStandardMapObjectInspector(keyOI, valueOI);
    }

    /** Builds a standard struct inspector with the same field names and converted field inspectors. */
    private ObjectInspector convertStruct(ObjectInspector inspector, AtomicBoolean foundUnion) {
      StructObjectInspector structOI = (StructObjectInspector) inspector;
      List<? extends StructField> fields = structOI.getAllStructFieldRefs();
      List<String> names = new ArrayList<>(fields.size());
      List<ObjectInspector> inspectors = new ArrayList<>(fields.size());
      for (StructField field : fields) {
        names.add(field.getFieldName());
        inspectors.add(convert(field.getFieldObjectInspector(), foundUnion));
      }
      return ObjectInspectorFactory.getStandardStructObjectInspector(names, inspectors);
    }

    /** Replaces a union by a standard struct whose fields are named {@code tag_0}, {@code tag_1}, ... */
    private ObjectInspector convertUnion(ObjectInspector inspector, AtomicBoolean foundUnion) {
      UnionObjectInspector unionOI = (UnionObjectInspector) inspector;
      List<ObjectInspector> fieldOIs = unionOI.getObjectInspectors();
      int tags = fieldOIs.size();
      List<String> names = new ArrayList<>(tags);
      List<ObjectInspector> inspectors = new ArrayList<>(tags);
      for (int i = 0; i < tags; i++) {
        names.add(TAG_FIELD_PREFIX + i);
        inspectors.add(convert(fieldOIs.get(i), foundUnion));
      }
      return ObjectInspectorFactory.getStandardStructObjectInspector(names, inspectors);
    }

  }

  /**
   * Rewrites a value so that every union it contains becomes a list with one slot per union
   * alternative, mirroring the type produced by {@link ObjectInspectorConverter}. Lists, maps and
   * structs are rebuilt through the source inspector's settable API, so their inspectors must
   * implement the corresponding {@code Settable*ObjectInspector} interface.
   */
  static class ValueConverter {

    /**
     * Converts {@code value} as described by {@code inspector}.
     *
     * @param value the value to convert; primitives are returned as-is
     * @param inspector inspector describing {@code value}
     * @return the converted value
     * @throws IllegalStateException if the inspector reports an unknown category
     */
    Object convert(Object value, ObjectInspector inspector) {
      Category category = inspector.getCategory();
      switch (category) {
      case PRIMITIVE:
        return value;
      case LIST:
        return convertList(value, inspector);
      case MAP:
        return convertMap(value, inspector);
      case STRUCT:
        return convertStruct(value, inspector);
      case UNION:
        return convertUnion(value, inspector);
      default:
        throw new IllegalStateException("Unknown category: " + category);
      }
    }

    /** Creates a new list of the same length holding the converted elements; a null list stays null. */
    private Object convertList(Object list, ObjectInspector inspector) {
      if (list == null) {
        // getListLength is documented to return -1 for null data, which must not reach create(size).
        return null;
      }
      SettableListObjectInspector listOI = (SettableListObjectInspector) inspector;
      int size = listOI.getListLength(list);
      Object result = listOI.create(size);
      for (int i = 0; i < size; i++) {
        listOI.set(result, i, convert(listOI.getListElement(list, i), listOI.getListElementObjectInspector()));
      }
      return result;
    }

    /** Creates a new map holding the converted keys and values; a null map stays null. */
    private Object convertMap(Object map, ObjectInspector inspector) {
      if (map == null) {
        // getMap is documented to return null for null data, so there would be no key set to walk.
        return null;
      }
      SettableMapObjectInspector mapOI = (SettableMapObjectInspector) inspector;
      Object result = mapOI.create();
      for (Object key : mapOI.getMap(map).keySet()) {
        Object value = mapOI.getMapValueElement(map, key);
        mapOI.put(
            result,
            convert(key, mapOI.getMapKeyObjectInspector()),
            convert(value, mapOI.getMapValueObjectInspector()));
      }
      return result;
    }

    /** Creates a new struct and fills each field with the converted value of the source field. */
    private Object convertStruct(Object struct, ObjectInspector inspector) {
      SettableStructObjectInspector structOI = (SettableStructObjectInspector) inspector;
      Object result = structOI.create();
      for (StructField field : structOI.getAllStructFieldRefs()) {
        Object value = structOI.getStructFieldData(struct, field);
        structOI.setStructFieldData(result, field, convert(value, field.getFieldObjectInspector()));
      }
      return result;
    }

    /**
     * Explodes a union into a list with one slot per alternative: the slot whose index equals the
     * union's tag holds the converted value, every other slot is {@code null}.
     */
    private Object convertUnion(Object union, ObjectInspector inspector) {
      UnionObjectInspector unionOI = (UnionObjectInspector) inspector;
      List<ObjectInspector> childOIs = unionOI.getObjectInspectors();
      byte tag = unionOI.getTag(union);
      Object value = unionOI.getField(union);
      List<Object> result = new ArrayList<>(childOIs.size());
      for (int i = 0; i < childOIs.size(); i++) {
        if (i == tag) {
          result.add(convert(value, childOIs.get(i)));
        } else {
          result.add(null);
        }
      }
      return result;
    }
  }

}