All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.facebook.presto.hive.functions.aggregation.HiveAggregationFunction Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.hive.functions.aggregation;

import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.hive.functions.HiveFunction;
import com.facebook.presto.hive.functions.type.ObjectInspectors;
import com.facebook.presto.hive.functions.type.PrestoTypes;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBridge;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver2;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;

import java.util.List;
import java.util.stream.Collectors;

import static com.facebook.presto.hive.functions.HiveFunctionErrorCode.initializationError;
import static com.facebook.presto.hive.functions.HiveFunctionErrorCode.unsupportedFunctionType;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;

/**
 * Presto aggregate function backed by a Hive UDAF class: either a {@link GenericUDAFResolver}
 * or a legacy {@link UDAF} (adapted through a {@link GenericUDAFBridge}).
 */
public class HiveAggregationFunction
        extends HiveFunction
{
    private final FunctionMetadata functionMetadata;
    private final HiveAggregationFunctionImplementationFactory factory;

    /**
     * Builds an aggregate function for the Hive UDAF class {@code cls}.
     *
     * <p>The intermediate type is derived by initializing an evaluator in {@link Mode#PARTIAL1}
     * over the argument inspectors. The output type is derived by initializing another evaluator
     * in {@link Mode#FINAL} over the intermediate inspector. Evaluators used at execution time
     * are created lazily, one per thread, through {@link ThreadLocal}s handed to the
     * implementation factory.
     *
     * @param cls Hive UDAF class; must implement {@link GenericUDAFResolver} or extend {@link UDAF}
     * @param name fully qualified function name
     * @param argumentTypes argument type signatures, resolved through {@code typeManager}
     * @param typeManager resolves argument types and maps Hive object inspectors to Presto types
     * @return the wrapped aggregate function
     * @throws RuntimeException the exception built by {@code initializationError} when the UDAF
     *         cannot be resolved or initialized, or by {@code unsupportedFunctionType} when
     *         {@code cls} is neither supported kind
     */
    public static HiveAggregationFunction createHiveAggregateFunction(Class<?> cls,
                                                                      QualifiedObjectName name,
                                                                      List<TypeSignature> argumentTypes,
                                                                      TypeManager typeManager)
    {
        try {
            List<Type> inputTypes = argumentTypes.stream()
                    .map(typeManager::getType)
                    .collect(Collectors.toList());
            ObjectInspector[] inputInspectors = inputTypes.stream()
                    .map(inputType -> ObjectInspectors.create(inputType, typeManager))
                    .toArray(ObjectInspector[]::new);

            // Dedicated probe evaluators, each initialized exactly once, are used only to discover
            // the intermediate and output inspectors. Execution-time evaluators are therefore never
            // re-initialized, and resolution/initialization failures still surface here rather than
            // on first use.
            ObjectInspector intermediateInspector = getGenericUDAFEvaluator(cls, inputInspectors)
                    .init(Mode.PARTIAL1, inputInspectors);
            ObjectInspector outputInspector = getGenericUDAFEvaluator(cls, inputInspectors)
                    .init(Mode.FINAL, new ObjectInspector[] {intermediateInspector});

            Type intermediateType = PrestoTypes.fromObjectInspector(intermediateInspector, typeManager);
            Type outputType = PrestoTypes.fromObjectInspector(outputInspector, typeManager);

            // One evaluator per thread, presumably because Hive evaluators are not safe to share
            // across threads. The FINAL-mode array is rebuilt per evaluator so no instance can
            // observe another's arguments array.
            ThreadLocal<GenericUDAFEvaluator> partialEvaluators = ThreadLocal.withInitial(
                    () -> newInitializedEvaluator(cls, inputInspectors, Mode.PARTIAL1, inputInspectors));
            ThreadLocal<GenericUDAFEvaluator> finalEvaluators = ThreadLocal.withInitial(
                    () -> newInitializedEvaluator(cls, inputInspectors, Mode.FINAL, new ObjectInspector[] {intermediateInspector}));

            Signature signature = new Signature(
                    name,
                    FunctionKind.AGGREGATE,
                    outputType.getTypeSignature(),
                    argumentTypes);
            HiveAggregationFunctionImplementationFactory factory = new HiveAggregationFunctionImplementationFactory(
                    signature,
                    inputTypes,
                    intermediateType,
                    outputType,
                    partialEvaluators::get,
                    finalEvaluators::get,
                    inputInspectors,
                    intermediateInspector,
                    outputInspector);
            return new HiveAggregationFunction(name, signature, "", factory);
        }
        catch (HiveException e) {
            throw initializationError(e);
        }
    }

    private HiveAggregationFunction(QualifiedObjectName name,
                                    Signature signature,
                                    String description,
                                    HiveAggregationFunctionImplementationFactory factory)
    {
        super(name, signature, false, true, true, description);
        this.factory = factory;
        // NOTE(review): the two trailing booleans are presumably deterministic / calledOnNullInput;
        // confirm against the FunctionMetadata constructor.
        this.functionMetadata = new FunctionMetadata(name,
                signature.getArgumentTypes(),
                signature.getReturnType(),
                FunctionKind.AGGREGATE,
                FunctionImplementationType.JAVA,
                true,
                true);
    }

    @Override
    public FunctionMetadata getFunctionMetadata()
    {
        return this.functionMetadata;
    }

    /**
     * Creates an aggregation implementation through the factory built in
     * {@link #createHiveAggregateFunction}.
     */
    public HiveAggregationFunctionImplementation getImplementation()
    {
        return factory.create();
    }

    /**
     * Resolves a fresh evaluator for {@code cls} and initializes it in {@code mode}.
     * The checked {@link HiveException} is rethrown as an initialization error so that this helper
     * can be called from {@link ThreadLocal#withInitial} suppliers.
     *
     * @param cls Hive UDAF class
     * @param resolverInspectors argument inspectors used to resolve the evaluator
     * @param mode evaluation mode to initialize
     * @param modeInspectors inspectors passed to {@code init} for {@code mode}
     */
    private static GenericUDAFEvaluator newInitializedEvaluator(Class<?> cls,
                                                                ObjectInspector[] resolverInspectors,
                                                                Mode mode,
                                                                ObjectInspector[] modeInspectors)
    {
        try {
            GenericUDAFEvaluator evaluator = getGenericUDAFEvaluator(cls, resolverInspectors);
            evaluator.init(mode, modeInspectors);
            return evaluator;
        }
        catch (HiveException e) {
            throw initializationError(e);
        }
    }

    /**
     * Resolves an evaluator for {@code cls} over {@code arguments}. Uses the
     * {@link GenericUDAFResolver2} parameter-info API when available, and otherwise falls back
     * to the deprecated type-info API.
     */
    @SuppressWarnings("deprecation")
    private static GenericUDAFEvaluator getGenericUDAFEvaluator(Class<?> cls, ObjectInspector[] arguments)
            throws HiveException
    {
        GenericUDAFResolver resolver = createGenericUDAFResolver(cls);
        GenericUDAFParameterInfo info = new SimpleGenericUDAFParameterInfo(arguments, false, false, false);
        if (resolver instanceof GenericUDAFResolver2) {
            return ((GenericUDAFResolver2) resolver).getEvaluator(info);
        }
        return resolver.getEvaluator(info.getParameters());
    }

    /**
     * Instantiates {@code cls} through its public no-arg constructor and adapts it to a
     * {@link GenericUDAFResolver}. Legacy {@link UDAF} classes are wrapped in a
     * {@link GenericUDAFBridge}.
     *
     * @throws HiveException if reflective instantiation fails
     */
    @SuppressWarnings("deprecation")
    private static GenericUDAFResolver createGenericUDAFResolver(Class<?> cls)
            throws HiveException
    {
        try {
            if (GenericUDAFResolver.class.isAssignableFrom(cls)) {
                return ((GenericUDAFResolver) cls.getConstructor().newInstance());
            }
            else if (UDAF.class.isAssignableFrom(cls)) {
                Object udaf = cls.getConstructor().newInstance();
                verify(udaf instanceof UDAF);
                return new GenericUDAFBridge(((UDAF) udaf));
            }
        }
        catch (Exception e) {
            throw new HiveException(format("Instantiating %s error", cls), e);
        }
        // Neither a GenericUDAFResolver nor a UDAF: not a supported aggregate implementation.
        throw unsupportedFunctionType(cls);
    }

    /**
     * Returns {@link SqlFunctionVisibility#PUBLIC}. A {@code null} visibility would leave callers
     * that inspect it with no usable value.
     */
    @Override
    public SqlFunctionVisibility getVisibility()
    {
        return SqlFunctionVisibility.PUBLIC;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy