All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.hive.functions.HiveUDAFFunction Maven / Gradle / Ivy

There is a newer version: 1.5.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.hive.functions;

import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.FunctionContext;

import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBridge;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver2;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * A Flink Aggregate Function wrapper which wraps a Hive UDAF.
 *
 * <p>The wrapped function is instantiated through the {@link HiveFunctionWrapper}. Three Hive
 * evaluators are kept, one per {@link GenericUDAFEvaluator.Mode} used here (PARTIAL1, PARTIAL2
 * and FINAL). They are created lazily because the argument types are only known once
 * {@link #accumulate} has seen real arguments; until then the evaluators are resolved against a
 * single placeholder LONG parameter and are replaced as soon as the real inspectors are known.
 */
public class HiveUDAFFunction
		extends AggregateFunction<BaseRow, GenericUDAFEvaluator.AggregationBuffer> {

	private final HiveFunctionWrapper hiveFunctionWrapper;
	private final boolean isUDAFBridgeRequired;

	// All state below is transient: after deserialization references are null and flags are false,
	// and everything is rebuilt lazily by the getters.
	private transient GenericUDAFResolver2 resolver;
	private transient GenericUDAFEvaluator finalEvaluator;
	// True while finalEvaluator was resolved against the placeholder parameter.
	private transient boolean finalEvaluatorByVoid;
	private transient GenericUDAFEvaluator partial1Evaluator;
	// True while partial1Evaluator was resolved against the placeholder parameter.
	private transient boolean partial1EvaluatorByVoid;
	private transient GenericUDAFEvaluator partial2Evaluator;
	// True while partial2Evaluator was resolved against the placeholder parameter.
	private transient boolean partial2EvaluatorByVoid;
	private transient ObjectInspector[] inputInspectors;
	// True while inputInspectors holds the placeholder rather than inspectors derived from real
	// arguments. Kept separately so that a non-null placeholder is never mistaken for real input.
	private transient boolean inputInspectorsByVoid;
	private transient ObjectInspector returnInspector;
	private transient ObjectInspector partialResultInspector;

	/**
	 * Creates the wrapper for the given Hive function.
	 *
	 * @param hiveFunctionWrapper factory for the wrapped Hive function instance
	 * @throws ClassNotFoundException propagated from {@code hiveFunctionWrapper.getUDFClass()}
	 */
	public HiveUDAFFunction(
			HiveFunctionWrapper hiveFunctionWrapper) throws ClassNotFoundException {

		this.hiveFunctionWrapper = hiveFunctionWrapper;
		// Every class derived from the legacy UDAF base class needs the bridge, not only
		// UDAF.class itself.
		this.isUDAFBridgeRequired =
				UDAF.class.isAssignableFrom(hiveFunctionWrapper.getUDFClass());
		this.finalEvaluatorByVoid = true;
		this.partial1EvaluatorByVoid = true;
		this.partial2EvaluatorByVoid = true;
		this.inputInspectors = null;
		this.inputInspectorsByVoid = false;
	}

	/**
	 * Instantiates the wrapped Hive function and exposes it as a resolver, going through
	 * {@link GenericUDAFBridge} when the function is a legacy {@link UDAF}.
	 */
	private GenericUDAFResolver2 newResolver()
			throws IllegalAccessException, InstantiationException, ClassNotFoundException {

		if (isUDAFBridgeRequired) {
			return new GenericUDAFBridge(
					(UDAF) hiveFunctionWrapper.createFunction());
		} else {
			return (GenericUDAFResolver2) hiveFunctionWrapper.createFunction();
		}
	}

	/**
	 * Feeds one row of arguments into the aggregation buffer.
	 *
	 * <p>When no real inspectors are known yet, or an evaluator is still based on the
	 * placeholder, the inspectors are derived from {@code params} and the affected evaluators
	 * are rebuilt before the PARTIAL1 evaluator iterates over the row.
	 *
	 * @param acc the aggregation buffer to update
	 * @param params the arguments of the current row
	 * @throws RuntimeException wrapping any {@link HiveException} raised by Hive
	 */
	public void accumulate(
			GenericUDAFEvaluator.AggregationBuffer acc,
			Object... params) {

		if (null == inputInspectors
				|| inputInspectorsByVoid
				|| finalEvaluatorByVoid
				|| partial1EvaluatorByVoid
				|| partial2EvaluatorByVoid) {

			// None of the runtime arguments is treated as a constant.
			List<Boolean> constants = new ArrayList<>(params.length);
			for (Object ignored : params) {
				constants.add(false);
			}
			inputInspectors = HiveInspectors.toInspectors(params, constants);
			inputInspectorsByVoid = false;

			try {
				// The getters discard placeholder-based evaluators themselves now that real
				// inspectors are available. PARTIAL1 must be refreshed before PARTIAL2.
				if (finalEvaluatorByVoid) {
					getFinalEvaluator();
				}
				if (partial1EvaluatorByVoid) {
					getPartial1Evaluator();
				}
				if (partial2EvaluatorByVoid) {
					getPartial2Evaluator();
				}
			} catch (HiveException e) {
				throw new RuntimeException(e);
			}
		}

		try {
			getPartial1Evaluator().iterate(acc, params);
		} catch (HiveException e) {
			throw new RuntimeException(e);
		}
	}

	// TODO open this block after BLINK-17364558 resolved.
/*	public void merge(
			GenericUDAFEvaluator.AggregationBuffer acc,
			Iterable<GenericUDAFEvaluator.AggregationBuffer> it) {

		try {
			for (GenericUDAFEvaluator.AggregationBuffer agg : it) {

				getPartial2Evaluator().merge(
						acc,
						getPartial2Evaluator().terminatePartial(agg));
			}
		} catch (HiveException e) {
			throw new RuntimeException(e);
		}
	}*/

	/**
	 * Resets the given aggregation buffer through each of the three evaluators.
	 *
	 * @param acc the aggregation buffer to reset
	 * @throws RuntimeException wrapping any {@link HiveException} raised by Hive
	 */
	public void resetAccumulator(GenericUDAFEvaluator.AggregationBuffer acc) {

		try {
			getPartial1Evaluator().reset(acc);
			getPartial2Evaluator().reset(acc);
			getFinalEvaluator().reset(acc);
		} catch (HiveException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Creates a fresh resolver and drops all cached evaluators so they are rebuilt on demand.
	 */
	@Override
	public void open(FunctionContext context) throws Exception {

		this.resolver = newResolver();
		this.partial1Evaluator = null;
		this.partial2Evaluator = null;
		this.finalEvaluator = null;
	}

	/**
	 * Returns the FINAL-mode evaluator, (re)building it when it is missing or was built against
	 * the placeholder although real inspectors are available by now. Building it also refreshes
	 * {@code returnInspector}.
	 */
	private GenericUDAFEvaluator getFinalEvaluator() throws HiveException {

		if (finalEvaluatorByVoid && !inputInspectorsByVoid) {
			// Built from the placeholder (or never built): must not be reused with real input.
			finalEvaluator = null;
		}
		if (null == finalEvaluator) {
			finalEvaluator = resolver.getEvaluator(currentParameterInfo());
			finalEvaluatorByVoid = inputInspectorsByVoid;
			// NOTE(review): Hive describes FINAL-mode init parameters as the partial result
			// inspectors, whereas the original input inspectors are passed here. Kept as in the
			// previous implementation - confirm against the UDAFs that must be supported.
			returnInspector = finalEvaluator.init(
					GenericUDAFEvaluator.Mode.FINAL,
					inputInspectors);
		}
		return this.finalEvaluator;
	}

	/**
	 * Returns the PARTIAL1-mode evaluator, (re)building it when it is missing or was built
	 * against the placeholder although real inspectors are available by now. Building it also
	 * refreshes {@code partialResultInspector}.
	 */
	private GenericUDAFEvaluator getPartial1Evaluator() throws HiveException {

		if (partial1EvaluatorByVoid && !inputInspectorsByVoid) {
			partial1Evaluator = null;
		}
		if (null == partial1Evaluator) {
			partial1Evaluator = resolver.getEvaluator(currentParameterInfo());
			partial1EvaluatorByVoid = inputInspectorsByVoid;
			partialResultInspector = partial1Evaluator.init(
					GenericUDAFEvaluator.Mode.PARTIAL1,
					inputInspectors);
		}
		return this.partial1Evaluator;
	}

	/**
	 * Returns the PARTIAL2-mode evaluator, (re)building it when it is missing or was built
	 * against the placeholder although real inspectors are available by now.
	 */
	private GenericUDAFEvaluator getPartial2Evaluator() throws HiveException {

		if (partial2EvaluatorByVoid && !inputInspectorsByVoid) {
			partial2Evaluator = null;
		}
		if (null == partial2Evaluator) {
			// The input of PARTIAL2 is the output of PARTIAL1, so make sure PARTIAL1 is
			// initialised against the current inspectors before its result inspector is used.
			getPartial1Evaluator();

			partial2Evaluator = resolver.getEvaluator(currentParameterInfo());
			partial2EvaluatorByVoid = inputInspectorsByVoid;

			ObjectInspector[] partial1Types = new ObjectInspector[1];
			partial1Types[0] = partialResultInspector;
			partial2Evaluator.init(
					GenericUDAFEvaluator.Mode.PARTIAL2,
					partial1Types);
		}
		return this.partial2Evaluator;
	}

	/**
	 * Describes the currently known input inspectors to the resolver. When no inspectors are
	 * known yet, the placeholder inspectors are installed and marked as such.
	 */
	private SimpleGenericUDAFParameterInfo currentParameterInfo() {

		if (null == inputInspectors) {
			inputInspectors = getLazyVoidOneParam().getParameterObjectInspectors();
			inputInspectorsByVoid = true;
		}
		return new SimpleGenericUDAFParameterInfo(
				inputInspectors, false, false, false);
	}

	/**
	 * Builds the placeholder parameter description: exactly one Java LONG parameter.
	 */
	private SimpleGenericUDAFParameterInfo getLazyVoidOneParam() {

		ObjectInspector[] objectInspectors = new ObjectInspector[1];
		objectInspectors[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
				PrimitiveObjectInspector.PrimitiveCategory.LONG);
		return new SimpleGenericUDAFParameterInfo(
				objectInspectors, false, false, false);
	}

	/**
	 * Creates a new aggregation buffer from the PARTIAL1 evaluator, creating the resolver first
	 * when {@link #open} has not run on this instance.
	 *
	 * @return a new Hive aggregation buffer
	 * @throws RuntimeException wrapping any failure to instantiate the function or evaluator
	 */
	@Override
	public GenericUDAFEvaluator.AggregationBuffer createAccumulator() {

		try {
			if (null == resolver) {
				// This method may be called at the client side.
				resolver = newResolver();
			}
			return getPartial1Evaluator().getNewAggregationBuffer();
		} catch (HiveException | IllegalAccessException | InstantiationException | ClassNotFoundException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Terminates the aggregation and converts the Hive result into a single-field row.
	 *
	 * @param accumulator the aggregation buffer holding the aggregated state
	 * @return a one-field {@link GenericRow} containing the unwrapped result
	 * @throws RuntimeException wrapping any {@link HiveException} raised by Hive
	 */
	@Override
	public BaseRow getValue(GenericUDAFEvaluator.AggregationBuffer accumulator) {

		try {
			Object result = getFinalEvaluator().terminate(accumulator);
			Object flinkResult = HiveInspectors.unwrap(result, returnInspector);
			GenericRow value = new GenericRow(1);
			value.update(0, flinkResult);
			return value;
		} catch (HiveException e) {
			throw new RuntimeException(e);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy