// org.apache.hadoop.hive.ql.udf.generic.GenericUDAFLeadLag (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;
import com.facebook.presto.hive.$internal.org.slf4j.Logger;
import com.facebook.presto.hive.$internal.org.slf4j.LoggerFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.io.IntWritable;
/**
 * Abstract resolver shared by the Lead and Lag UDAFs.
 *
 * <p>Validates an invocation of the form {@code _FUNC_(expr, amt, default)}, extracts the
 * constant offset {@code amt} (1 when the argument is omitted) and passes it to the evaluator
 * created by the concrete subclass via {@link #createLLEvaluator()}.
 */
public abstract class GenericUDAFLeadLag extends AbstractGenericUDAFResolver {

  // Named after this base class so its log output is not attributed to a single subclass.
  static final Logger LOG = LoggerFactory.getLogger(GenericUDAFLeadLag.class.getName());

  /**
   * Checks the argument list and builds the evaluator for this call.
   *
   * @param parameters description of the arguments the function was invoked with
   * @return the subclass-provided evaluator, configured with the offset and function name
   * @throws SemanticException (as {@link UDFArgumentTypeException}) when the argument count is
   *         not between 1 and 3, or when the offset is not a constant primitive INT, is null,
   *         or is negative
   */
  @Override
  public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo parameters)
      throws SemanticException {
    ObjectInspector[] paramOIs = parameters.getParameterObjectInspectors();
    String fNm = functionName();

    if (!(paramOIs.length >= 1 && paramOIs.length <= 3)) {
      throw new UDFArgumentTypeException(paramOIs.length - 1, "Incorrect invocation of " + fNm
          + ": _FUNC_(expr, amt, default)");
    }

    // Offset defaults to 1 when only the expression is supplied.
    int amt = 1;
    if (paramOIs.length > 1) {
      ObjectInspector amtOI = paramOIs[1];

      if (!ObjectInspectorUtils.isConstantObjectInspector(amtOI)
          || (amtOI.getCategory() != ObjectInspector.Category.PRIMITIVE)
          || ((PrimitiveObjectInspector) amtOI).getPrimitiveCategory()
              != PrimitiveObjectInspector.PrimitiveCategory.INT) {
        throw new UDFArgumentTypeException(1, fNm + " amount must be a integer value "
            + amtOI.getTypeName() + " was passed as parameter 1.");
      }

      Object o = ((ConstantObjectInspector) amtOI).getWritableConstantValue();
      // A null constant would otherwise surface as a NullPointerException on get() below;
      // report it as an argument error instead.
      if (o == null) {
        throw new UDFArgumentTypeException(1, fNm + " amount must not be null.");
      }
      amt = ((IntWritable) o).get();
      if (amt < 0) {
        throw new UDFArgumentTypeException(1, fNm + " amount can not be nagative. Specified: " + amt );
      }
    }

    if (paramOIs.length == 3) {
      // The result is discarded: the call is made only so that a failure to obtain a converter
      // from the default-value type to the expression type shows up here rather than in init().
      ObjectInspectorConverters.getConverter(paramOIs[2], paramOIs[0]);
    }

    GenericUDAFLeadLagEvaluator eval = createLLEvaluator();
    eval.setAmt(amt);
    // Without a name the "Only COMPLETE mode supported" message in init() would print "null".
    // A name already assigned by the subclass is left untouched.
    if (eval.getFnName() == null) {
      eval.setFnName(fNm);
    }
    return eval;
  }

  /**
   * @return the function name used in error messages produced by this resolver
   */
  protected abstract String functionName();

  /**
   * @return a new, unconfigured evaluator for the concrete function
   */
  protected abstract GenericUDAFLeadLagEvaluator createLLEvaluator();

  /**
   * Base evaluator for Lead/Lag. It only runs in {@code Mode.COMPLETE}: partial aggregation
   * ({@link #terminatePartial} / {@link #merge}) is rejected with a {@link HiveException}.
   * Row storage and result assembly are delegated to a {@link LeadLagBuffer}.
   */
  public static abstract class GenericUDAFLeadLagEvaluator extends GenericUDAFEvaluator {

    /** Inspectors for the call arguments, as received by {@link #init}. */
    private transient ObjectInspector[] inputOI;

    /** Offset handed to {@link LeadLagBuffer#initialize}. */
    private int amt;

    /** Function name used in error messages. */
    String fnName;

    /** Converts the default-value argument to the expression's type; set only for 3 arguments. */
    private transient Converter defaultValueConverter;

    public GenericUDAFLeadLagEvaluator() {
    }

    /*
     * used to initialize Streaming Evaluator.
     */
    protected GenericUDAFLeadLagEvaluator(GenericUDAFLeadLagEvaluator src) {
      this.inputOI = src.inputOI;
      this.amt = src.amt;
      this.fnName = src.fnName;
      this.defaultValueConverter = src.defaultValueConverter;
      this.mode = src.mode;
    }

    /**
     * Records the argument inspectors and, when a default value is supplied, prepares the
     * converter from the default's type to the expression's type.
     *
     * @param m evaluation mode; must be {@code Mode.COMPLETE}
     * @param parameters inspectors for {@code (expr[, amt[, default]])}
     * @return a standard list inspector whose elements use the standard inspector of the
     *         expression argument
     * @throws HiveException if {@code m} is not {@code Mode.COMPLETE}
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);
      if (m != Mode.COMPLETE) {
        throw new HiveException("Only COMPLETE mode supported for " + fnName + " function");
      }

      inputOI = parameters;
      if (parameters.length == 3) {
        defaultValueConverter = ObjectInspectorConverters
            .getConverter(parameters[2], parameters[0]);
      }

      return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils
          .getStandardObjectInspector(parameters[0]));
    }

    public int getAmt() {
      return amt;
    }

    public void setAmt(int amt) {
      this.amt = amt;
    }

    public String getFnName() {
      return fnName;
    }

    public void setFnName(String fnName) {
      this.fnName = fnName;
    }

    /**
     * @return a new buffer; callers in this class initialize it with the configured offset
     * @throws HiveException if the buffer cannot be created
     */
    protected abstract LeadLagBuffer getNewLLBuffer() throws HiveException;

    /** Creates a buffer through {@link #getNewLLBuffer()} and initializes it with the offset. */
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      LeadLagBuffer lb = getNewLLBuffer();
      lb.initialize(amt);
      return lb;
    }

    /** Re-initializes the given buffer with the configured offset. */
    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((LeadLagBuffer) agg).initialize(amt);
    }

    /**
     * Adds one row to the buffer. The expression value is copied to a standard object; when a
     * third argument is present it is converted to the expression's type and copied the same
     * way, otherwise the default passed to the buffer is {@code null}.
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      Object rowExprVal = ObjectInspectorUtils.copyToStandardObject(parameters[0], inputOI[0]);
      Object defaultVal = parameters.length > 2 ? ObjectInspectorUtils.copyToStandardObject(
          defaultValueConverter.convert(parameters[2]), inputOI[0]) : null;
      ((LeadLagBuffer) agg).addRow(rowExprVal, defaultVal);
    }

    /** Always throws: partial aggregation is not supported. */
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      throw new HiveException("terminatePartial not supported");
    }

    /** Always throws: partial aggregation is not supported. */
    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      throw new HiveException("merge not supported");
    }

    /** Returns whatever the buffer's own {@code terminate()} produces. */
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      return ((LeadLagBuffer) agg).terminate();
    }
  }
}