All downloads are free. The search and download features use the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDFLeadLag Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.PTFPartition.PTFPartitionIterator;
import org.apache.hadoop.hive.ql.exec.PTFUtils;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.io.IntWritable;

/**
 * Common base for the lead/lag style UDFs used inside partitioned table functions.
 *
 * <p>The function is invoked as {@code _FUNC_(expr, amt, default)} where only {@code expr} is
 * mandatory. {@code amt} must be a non-negative integer constant (it defaults to 1) and
 * {@code default} is the value returned when the target row falls outside the partition.
 *
 * <p>Subclasses decide which row is targeted by implementing {@link #getIndex(int)} and
 * {@link #getRow(int)}, and supply their name through {@link #_getFnName()}.
 *
 * <p>The expression evaluator and the partition iterator are not created here; they must be
 * injected through {@link #setExprEvaluator(ExprNodeEvaluator)} and
 * {@link #setpItr(PTFPartitionIterator)} before {@link #evaluate(DeferredObject[])} is called.
 */
public abstract class GenericUDFLeadLag extends GenericUDF {
  /** Evaluates the first argument ({@code expr}) against a row of the partition. */
  transient ExprNodeEvaluator exprEvaluator;
  /** Iterator over the partition that is currently being processed. */
  transient PTFPartitionIterator pItr;
  /** Inspector of the first argument; also determines the return type. */
  transient ObjectInspector firstArgOI;
  /** Inspector of the optional third (default value) argument. */
  transient ObjectInspector defaultArgOI;
  /** Converts the default value from its own inspector to the first argument's inspector. */
  transient Converter defaultValueConverter;
  /** Offset passed to {@link #getIndex(int)} / {@link #getRow(int)}; 1 unless overridden. */
  int amt;

  static {
    PTFUtils.makeTransient(GenericUDFLeadLag.class, "exprEvaluator", "pItr", "firstArgOI",
            "defaultArgOI", "defaultValueConverter");
  }

  /**
   * Evaluates the expression on the row selected by the subclass, or returns the default value
   * when that row lies outside the current partition.
   *
   * @param arguments the deferred call arguments; only the third one (the default value, when
   *                  present) is read here, the expression itself is evaluated through
   *                  {@link #exprEvaluator}
   * @return a standard writable copy of the expression value on the target row, the converted
   *         default value when the target index is out of range, or {@code null} when it is
   *         out of range and no default was supplied
   * @throws HiveException if evaluating the expression or moving the iterator fails
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    Object defaultVal = null;
    if (arguments.length == 3) {
      // NOTE(review): the converted value is copied using defaultArgOI although the converter
      // targets the first argument's inspector - kept as in the original; confirm this is
      // intended when the two types differ.
      defaultVal = ObjectInspectorUtils.copyToStandardObject(
              defaultValueConverter.convert(arguments[2].get()), defaultArgOI);
    }

    // Position to restore once the look-up is done (see the finally block below).
    int idx = pItr.getIndex() - 1;
    int start = 0;
    int end = pItr.getPartition().size();
    try {
      Object ret = null;
      int newIdx = getIndex(amt);

      if (newIdx >= end || newIdx < start) {
        // Target row is outside [0, partition size): fall back to the default value.
        ret = defaultVal;
      } else {
        Object row = getRow(amt);
        ret = exprEvaluator.evaluate(row);
        // Copy the value so the result does not alias evaluator/row internals.
        ret = ObjectInspectorUtils.copyToStandardObject(ret, firstArgOI,
                ObjectInspectorCopyOption.WRITABLE);
      }
      return ret;
    } finally {
      Object currRow = pItr.resetToIndex(idx);
      // reevaluate expression on current Row, to trigger the Lazy object
      // caches to be reset to the current row.
      exprEvaluator.evaluate(currRow);
    }

  }

  /**
   * Validates the call signature and captures the inspectors needed at evaluation time.
   *
   * @param arguments inspectors for {@code expr}, and optionally {@code amt} and {@code default}
   * @return the standard writable inspector matching the first argument
   * @throws UDFArgumentException if the argument count is not between 1 and 3, or if
   *         {@code amt} is not a non-null, non-negative integer constant
   */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (!(arguments.length >= 1 && arguments.length <= 3)) {
      throw new UDFArgumentTypeException(arguments.length - 1, "Incorrect invocation of "
              + _getFnName() + ": _FUNC_(expr, amt, default)");
    }

    amt = 1;
    if (arguments.length > 1) {
      ObjectInspector amtOI = arguments[1];
      if (!ObjectInspectorUtils.isConstantObjectInspector(amtOI)
              || (amtOI.getCategory() != ObjectInspector.Category.PRIMITIVE)
              || ((PrimitiveObjectInspector) amtOI).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.INT) {
        throw new UDFArgumentTypeException(1, _getFnName() + " amount must be a integer value "
                + amtOI.getTypeName() + " was passed as parameter 1.");
      }
      Object o = ((ConstantObjectInspector) amtOI).getWritableConstantValue();
      if (o == null) {
        // A NULL constant carries no writable value; report it as a bad argument rather than
        // failing with a NullPointerException on the cast/dereference below.
        throw new UDFArgumentTypeException(1, _getFnName() + " amount can not be null.");
      }
      amt = ((IntWritable) o).get();
      if (amt < 0) {
        throw new UDFArgumentTypeException(1, _getFnName()
                + " amount can not be negative. Specified: " + amt);
      }
    }

    if (arguments.length == 3) {
      defaultArgOI = arguments[2];
      defaultValueConverter = ObjectInspectorConverters.getConverter(arguments[2], arguments[0]);
    }

    firstArgOI = arguments[0];
    return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
            ObjectInspectorCopyOption.WRITABLE);
  }

  /** @return the evaluator used for the first argument */
  public ExprNodeEvaluator getExprEvaluator() {
    return exprEvaluator;
  }

  /** @param exprEvaluator the evaluator to use for the first argument */
  public void setExprEvaluator(ExprNodeEvaluator exprEvaluator) {
    this.exprEvaluator = exprEvaluator;
  }

  /** @return the partition iterator this function reads rows from */
  public PTFPartitionIterator getpItr() {
    return pItr;
  }

  /** @param pItr the partition iterator this function should read rows from */
  public void setpItr(PTFPartitionIterator pItr) {
    this.pItr = pItr;
  }

  /** @return the inspector of the first argument */
  public ObjectInspector getFirstArgOI() {
    return firstArgOI;
  }

  /** @param firstArgOI the inspector of the first argument */
  public void setFirstArgOI(ObjectInspector firstArgOI) {
    this.firstArgOI = firstArgOI;
  }

  /** @return the inspector of the default value argument, or {@code null} if none was given */
  public ObjectInspector getDefaultArgOI() {
    return defaultArgOI;
  }

  /** @param defaultArgOI the inspector of the default value argument */
  public void setDefaultArgOI(ObjectInspector defaultArgOI) {
    this.defaultArgOI = defaultArgOI;
  }

  /** @return the converter applied to the default value, or {@code null} if none was given */
  public Converter getDefaultValueConverter() {
    return defaultValueConverter;
  }

  /** @param defaultValueConverter the converter to apply to the default value */
  public void setDefaultValueConverter(Converter defaultValueConverter) {
    this.defaultValueConverter = defaultValueConverter;
  }

  /** @return the row offset this function was configured with */
  public int getAmt() {
    return amt;
  }

  /** @param amt the row offset to use */
  public void setAmt(int amt) {
    this.amt = amt;
  }

  /**
   * Builds the display form of the call from the function name and its rendered children.
   *
   * @param children display strings of the arguments (between one and three, matching what
   *                 {@link #initialize(ObjectInspector[])} accepts)
   */
  @Override
  public String getDisplayString(String[] children) {
    // initialize() accepts 1 to 3 arguments, so the display path must accept the same range.
    assert (children.length >= 1 && children.length <= 3);
    return getStandardDisplayString(_getFnName(), children);
  }

  /** @return the function name used in error messages and in the display string */
  protected abstract String _getFnName();

  /**
   * Fetches the target row for the given offset. Only called after {@link #getIndex(int)}
   * reported an index inside the partition.
   *
   * @param amt the configured row offset
   * @return the row the expression should be evaluated on
   * @throws HiveException if the row cannot be retrieved
   */
  protected abstract Object getRow(int amt) throws HiveException;

  /**
   * Computes the partition index of the target row for the given offset. The result is compared
   * against {@code [0, partition size)} to decide whether the default value must be used.
   *
   * @param amt the configured row offset
   * @return the index of the target row; may lie outside the partition
   */
  protected abstract int getIndex(int amt);

}