// All downloads are free. Search and download functionality uses the official Maven repository.

// io.prestosql.operator.aggregation.minmaxby.AbstractMinMaxBy (Maven / Gradle / Ivy)

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.operator.aggregation.minmaxby;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.aggregation.AccumulatorCompiler;
import io.prestosql.operator.aggregation.AggregationMetadata;
import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.prestosql.operator.aggregation.GenericAccumulatorFactoryBinder;
import io.prestosql.operator.aggregation.InternalAggregationFunction;
import io.prestosql.operator.aggregation.state.StateCompiler;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.function.AccumulatorStateFactory;
import io.prestosql.spi.function.AccumulatorStateSerializer;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.gen.CallSiteBinder;
import io.prestosql.sql.gen.SqlTypeBytecodeExpression;
import io.prestosql.util.MinMaxCompare;

import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.bytecode.Access.FINAL;
import static io.airlift.bytecode.Access.PRIVATE;
import static io.airlift.bytecode.Access.PUBLIC;
import static io.airlift.bytecode.Access.STATIC;
import static io.airlift.bytecode.Access.a;
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.and;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean;
import static io.airlift.bytecode.expression.BytecodeExpressions.not;
import static io.airlift.bytecode.expression.BytecodeExpressions.or;
import static io.prestosql.metadata.FunctionKind.AGGREGATE;
import static io.prestosql.metadata.Signature.orderableTypeParameter;
import static io.prestosql.metadata.Signature.typeVariable;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.prestosql.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.prestosql.operator.aggregation.minmaxby.TwoNullableValueStateMapping.getStateClass;
import static io.prestosql.operator.aggregation.minmaxby.TwoNullableValueStateMapping.getStateSerializer;
import static io.prestosql.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.prestosql.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.prestosql.spi.function.InvocationConvention.simpleConvention;
import static io.prestosql.spi.function.OperatorType.COMPARISON;
import static io.prestosql.sql.gen.BytecodeUtils.loadConstant;
import static io.prestosql.sql.gen.SqlTypeBytecodeExpression.constantType;
import static io.prestosql.util.CompilerUtils.defineClass;
import static io.prestosql.util.CompilerUtils.makeClassName;
import static io.prestosql.util.Reflection.methodHandle;
import static java.util.Arrays.stream;

/**
 * Shared implementation of the {@code min_by} and {@code max_by} aggregation functions.
 *
 * <p>The function is declared over an orderable key type {@code K} and an arbitrary value type {@code V},
 * takes its arguments in the order {@code (V, K)} and returns {@code V}. The accumulator state holds one
 * key ("First") and the value paired with it ("Second"); a new key/value pair replaces the stored pair
 * whenever the state is still empty or the comparison handle obtained from {@link MinMaxCompare} returns
 * {@code true} for (new key, stored key).
 *
 * <p>The {@code input}, {@code combine} and {@code output} methods of the aggregation are generated as
 * bytecode for every concrete (key type, value type) pair in {@link #specialize}.
 */
public abstract class AbstractMinMaxBy
        extends SqlAggregationFunction
{
    // true for min_by, false for max_by; selects both the function name and the comparison handle
    private final boolean min;

    /**
     * Declares the function signature {@code min_by(V, K) -> V} or {@code max_by(V, K) -> V}.
     *
     * @param min {@code true} to create {@code min_by}, {@code false} to create {@code max_by}
     * @param description description text stored in the function metadata
     */
    protected AbstractMinMaxBy(boolean min, String description)
    {
        super(
                new FunctionMetadata(
                        new Signature(
                                (min ? "min" : "max") + "_by",
                                ImmutableList.of(orderableTypeParameter("K"), typeVariable("V")),
                                ImmutableList.of(),
                                new TypeSignature("V"),
                                ImmutableList.of(new TypeSignature("V"), new TypeSignature("K")),
                                false),
                        true,
                        // One definition per argument, in (value, key) order.
                        // NOTE(review): the flags presumably mark the value as nullable and the key as non-nullable,
                        // matching the NULLABLE_BLOCK_INPUT_CHANNEL / BLOCK_INPUT_CHANNEL parameters below - confirm.
                        ImmutableList.of(
                                new FunctionArgumentDefinition(true),
                                new FunctionArgumentDefinition(false)),
                        false,
                        true,
                        description,
                        AGGREGATE),
                true,
                false);
        this.min = min;
    }

    /**
     * Declares the only external dependency of this function: the {@code COMPARISON} operator on {@code (K, K)}.
     */
    @Override
    public FunctionDependencyDeclaration getFunctionDependencies()
    {
        return FunctionDependencyDeclaration.builder()
                .addOperatorSignature(COMPARISON, ImmutableList.of(new TypeSignature("K"), new TypeSignature("K")))
                .build();
    }

    /**
     * Returns the type signature of the single intermediate (serialized state) type for the bound
     * key and value types.
     *
     * <p>Mirrors the serializer selection in {@code generateAggregation}: a primitive Java value type uses
     * the serialized type computed by {@link StateCompiler}, any other value type uses the serializer
     * provided by {@code TwoNullableValueStateMapping}.
     *
     * @param functionBinding binding that supplies the concrete types of {@code K} and {@code V}
     * @return a one-element list with the intermediate type signature
     */
    @Override
    public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
    {
        Type keyType = functionBinding.getTypeVariable("K");
        Type valueType = functionBinding.getTypeVariable("V");

        Class<?> stateClass = getStateClass(keyType.getJavaType(), valueType.getJavaType());
        if (valueType.getJavaType().isPrimitive()) {
            Map<String, Type> stateFieldTypes = ImmutableMap.of("First", keyType, "Second", valueType);
            return ImmutableList.of(StateCompiler.getSerializedType(stateClass, stateFieldTypes).getTypeSignature());
        }

        return ImmutableList.of(getStateSerializer(keyType, valueType).getSerializedType().getTypeSignature());
    }

    /**
     * Builds the concrete aggregation for the key and value types bound to {@code K} and {@code V}.
     *
     * @param functionBinding binding that supplies the concrete types of {@code K} and {@code V}
     * @param functionDependencies resolved dependencies, used to obtain the key comparison handle
     * @return the specialized aggregation function
     */
    @Override
    public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
    {
        Type keyType = functionBinding.getTypeVariable("K");
        Type valueType = functionBinding.getTypeVariable("V");
        return generateAggregation(valueType, keyType, functionDependencies);
    }

    /**
     * Generates the state factory/serializer and a helper class containing static {@code input},
     * {@code combine} and {@code output} methods, then wires them into an {@link InternalAggregationFunction}.
     *
     * @param valueType type of the value argument and of the result
     * @param keyType type of the key argument that is compared
     * @param functionDependencies source of the comparison method handle for {@code keyType}
     * @return the aggregation function for this (value type, key type) pair
     */
    private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, FunctionDependencies functionDependencies)
    {
        Class<?> stateClass = getStateClass(keyType.getJavaType(), valueType.getJavaType());
        DynamicClassLoader classLoader = new DynamicClassLoader(getClass().getClassLoader());

        // Generate states and serializers:
        // For value that is a Block or Slice, we store them as Block/position combination
        // to avoid generating long-living objects through getSlice or getObject.
        // This can also help reducing cross-region reference in G1GC engine.
        // TODO: keys can have the same problem. But usually they are primitive types (given the nature of comparison).
        AccumulatorStateFactory<?> stateFactory;
        AccumulatorStateSerializer<?> stateSerializer;
        if (valueType.getJavaType().isPrimitive()) {
            Map<String, Type> stateFieldTypes = ImmutableMap.of("First", keyType, "Second", valueType);
            stateFactory = StateCompiler.generateStateFactory(stateClass, stateFieldTypes, classLoader);
            stateSerializer = StateCompiler.generateStateSerializer(stateClass, stateFieldTypes, classLoader);
        }
        else {
            // StateCompiler checks type compatibility.
            // Given "Second" in this case is always a Block, we only need to make sure the getter and setter of the Blocks are properly generated.
            // We deliberately make "SecondBlock" an array type so that the compiler will treat it as a block to workaround the sanity check.
            stateFactory = StateCompiler.generateStateFactory(stateClass, ImmutableMap.of("First", keyType, "SecondBlock", new ArrayType(valueType)), classLoader);

            // States can be generated by StateCompiler given that they are simply classes with getters and setters.
            // However, serializers have logic in them. Creating serializers is better than generating them.
            stateSerializer = getStateSerializer(keyType, valueType);
        }

        Type intermediateType = stateSerializer.getSerializedType();

        // Argument order is (value, key), matching the declared signature
        List<Type> inputTypes = ImmutableList.of(valueType, keyType);

        CallSiteBinder binder = new CallSiteBinder();
        MethodHandle compareMethod = MinMaxCompare.getMinMaxCompare(functionDependencies, keyType, simpleConvention(FAIL_ON_NULL, NEVER_NULL, NEVER_NULL), min);

        // Holder class for the three generated static methods; it is never instantiated
        ClassDefinition definition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName("processMaxOrMinBy"),
                type(Object.class));
        definition.declareDefaultConstructor(a(PRIVATE));
        generateInputMethod(definition, binder, compareMethod, keyType, valueType, stateClass);
        generateCombineMethod(definition, binder, compareMethod, valueType, stateClass);
        generateOutputMethod(definition, binder, valueType, stateClass);
        Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), classLoader);
        MethodHandle inputMethod = methodHandle(generatedClass, "input", stateClass, Block.class, Block.class, int.class);
        MethodHandle combineMethod = methodHandle(generatedClass, "combine", stateClass, stateClass);
        MethodHandle outputMethod = methodHandle(generatedClass, "output", stateClass, BlockBuilder.class);
        String name = getFunctionMetadata().getSignature().getName();
        AggregationMetadata aggregationMetadata = new AggregationMetadata(
                generateAggregationName(name, valueType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
                createInputParameterMetadata(valueType, keyType),
                inputMethod,
                Optional.empty(),
                combineMethod,
                outputMethod,
                ImmutableList.of(new AccumulatorStateDescriptor(
                        stateClass,
                        stateSerializer,
                        stateFactory)),
                valueType);
        GenericAccumulatorFactoryBinder factory = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, classLoader);
        return new InternalAggregationFunction(name, inputTypes, ImmutableList.of(intermediateType), valueType, factory);
    }

    /**
     * Describes the parameters of the generated {@code input} method, in order:
     * state, value block (nullable channel), key block, position.
     */
    private static List<ParameterMetadata> createInputParameterMetadata(Type value, Type key)
    {
        return ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(NULLABLE_BLOCK_INPUT_CHANNEL, value), new ParameterMetadata(BLOCK_INPUT_CHANNEL, key), new ParameterMetadata(BLOCK_INDEX));
    }

    /**
     * Generates {@code public static void input(state, Block value, Block key, int position)}.
     *
     * <p>The generated body is equivalent to:
     * <pre>{@code
     * if (state.isFirstNull() || (!key.isNull(position) && compare(key[position], state.getFirst()))) {
     *     state.setFirst(key[position]);
     *     state.setFirstNull(false);
     *     state.setSecondNull(value.isNull(position));
     *     if (!value.isNull(position)) {
     *         // primitive value:  state.setSecond(value[position]);
     *         // otherwise:        state.setSecondBlock(value); state.setSecondPosition(position);
     *     }
     * }
     * }</pre>
     *
     * @param definition class that receives the method
     * @param binder binder used to embed the type constants and the comparison handle
     * @param compareMethod handle invoked as (new key, stored key) returning {@code boolean}
     * @param keyType type used to read the key from its block
     * @param valueType type used to read the value from its block
     * @param stateClass accumulator state class whose accessors are invoked
     */
    private static void generateInputMethod(ClassDefinition definition, CallSiteBinder binder, MethodHandle compareMethod, Type keyType, Type valueType, Class<?> stateClass)
    {
        Parameter state = arg("state", stateClass);
        Parameter value = arg("value", Block.class);
        Parameter key = arg("key", Block.class);
        Parameter position = arg("position", int.class);
        MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "input", type(void.class), state, value, key, position);
        SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType);

        // Executed when the incoming key replaces the stored one
        BytecodeBlock ifBlock = new BytecodeBlock()
                .append(invokeMethod(stateClass, state, "setFirst", keySqlType.getValue(key, position)))
                .append(state.invoke("setFirstNull", void.class, constantBoolean(false)))
                .append(state.invoke("setSecondNull", void.class, value.invoke("isNull", boolean.class, position)));
        BytecodeNode setValueNode;
        if (valueType.getJavaType().isPrimitive()) {
            SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType);
            setValueNode = invokeMethod(stateClass, state, "setSecond", valueSqlType.getValue(value, position));
        }
        else {
            // Do not get value directly given it creates object overhead.
            // Such objects would live long enough in Block or SliceBigArray to cause GC pressure.
            setValueNode = new BytecodeBlock()
                    .append(state.invoke("setSecondBlock", void.class, value))
                    .append(state.invoke("setSecondPosition", void.class, position));
        }
        // Only read/store the value when it is not null; the null flag was already recorded above
        ifBlock.append(new IfStatement()
                .condition(value.invoke("isNull", boolean.class, position))
                .ifFalse(setValueNode));

        method.getBody().append(new IfStatement()
                .condition(or(
                        state.invoke("isFirstNull", boolean.class),
                        and(
                                not(key.invoke("isNull", boolean.class, position)),
                                // invokeExact requires the argument types to match the handle's type exactly, hence the casts
                                loadConstant(binder, compareMethod, MethodHandle.class).invoke(
                                        "invokeExact",
                                        boolean.class,
                                        keySqlType.getValue(key, position).cast(compareMethod.type().parameterType(0)),
                                        invokeMethod(stateClass, state, "getFirst").cast(compareMethod.type().parameterType(1))))))
                .ifTrue(ifBlock))
                .ret();
    }

    /**
     * Generates {@code public static void combine(state, otherState)}.
     *
     * <p>The generated body copies the key, both null flags and the value (or value block and position)
     * from {@code otherState} into {@code state} when {@code state} has no key yet, or when
     * {@code otherState} has a key and {@code compare(otherState.getFirst(), state.getFirst())} is true.
     *
     * @param definition class that receives the method
     * @param binder binder used to embed the comparison handle
     * @param compareMethod handle invoked as (other key, stored key) returning {@code boolean}
     * @param valueType value type; decides between the primitive and the block/position representation
     * @param stateClass accumulator state class whose accessors are invoked
     */
    private static void generateCombineMethod(ClassDefinition definition, CallSiteBinder binder, MethodHandle compareMethod, Type valueType, Class<?> stateClass)
    {
        Parameter state = arg("state", stateClass);
        Parameter otherState = arg("otherState", stateClass);
        MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "combine", type(void.class), state, otherState);

        // Executed when otherState's pair replaces the one held by state
        BytecodeBlock ifBlock = new BytecodeBlock()
                .append(invokeMethod(stateClass, state, "setFirst", invokeMethod(stateClass, otherState, "getFirst")))
                .append(state.invoke("setFirstNull", void.class, otherState.invoke("isFirstNull", boolean.class)))
                .append(state.invoke("setSecondNull", void.class, otherState.invoke("isSecondNull", boolean.class)));
        if (valueType.getJavaType().isPrimitive()) {
            ifBlock.append(invokeMethod(stateClass, state, "setSecond", otherState.invoke("getSecond", valueType.getJavaType())));
        }
        else {
            ifBlock.append(new BytecodeBlock()
                    .append(state.invoke("setSecondBlock", void.class, otherState.invoke("getSecondBlock", Block.class)))
                    .append(state.invoke("setSecondPosition", void.class, otherState.invoke("getSecondPosition", int.class))));
        }

        method.getBody()
                .append(new IfStatement()
                        .condition(or(
                                state.invoke("isFirstNull", boolean.class),
                                and(
                                        not(otherState.invoke("isFirstNull", boolean.class)),
                                        // invokeExact requires the argument types to match the handle's type exactly, hence the casts
                                        loadConstant(binder, compareMethod, MethodHandle.class).invoke(
                                                "invokeExact",
                                                boolean.class,
                                                invokeMethod(stateClass, otherState, "getFirst").cast(compareMethod.type().parameterType(0)),
                                                invokeMethod(stateClass, state, "getFirst").cast(compareMethod.type().parameterType(1))))))
                        .ifTrue(ifBlock))
                .ret();
    }

    /**
     * Builds an invocation of the named method on {@code instance}, casting each argument to the
     * parameter type declared by the method that {@link #getMethod} resolves.
     *
     * @param instanceType class in which the method is looked up by name
     * @param instance receiver of the invocation
     * @param methodName name of the method to invoke
     * @param arguments argument expressions, one per declared parameter
     * @return the invocation expression
     * @throws IllegalArgumentException if no method has that name or the argument count does not match
     */
    private static BytecodeExpression invokeMethod(Class<?> instanceType, Parameter instance, String methodName, BytecodeExpression... arguments)
    {
        Method method = getMethod(instanceType, methodName);
        Class<?>[] parameterTypes = method.getParameterTypes();
        checkArgument(parameterTypes.length == arguments.length, "Expected %s arguments, but got %s", parameterTypes.length, arguments.length);

        ImmutableList.Builder<BytecodeExpression> castedArguments = ImmutableList.builder();
        for (int i = 0; i < arguments.length; i++) {
            BytecodeExpression argument = arguments[i];
            Class<?> parameterType = parameterTypes[i];
            castedArguments.add(argument.cast(parameterType));
        }
        return instance.invoke(method, castedArguments.build());
    }

    /**
     * Generates {@code public static void output(state, BlockBuilder out)}.
     *
     * <p>The generated body appends null to {@code out} when either the key or the value of the state is
     * null; otherwise it writes the stored value, read either directly from the state (primitive value
     * type) or from the stored block at the stored position.
     *
     * @param definition class that receives the method
     * @param binder binder used to embed the value type constant
     * @param valueType type used to read and write the value
     * @param stateClass accumulator state class whose accessors are invoked
     */
    private static void generateOutputMethod(ClassDefinition definition, CallSiteBinder binder, Type valueType, Class<?> stateClass)
    {
        Parameter state = arg("state", stateClass);
        Parameter out = arg("out", BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "output", type(void.class), state, out);

        IfStatement ifStatement = new IfStatement()
                .condition(or(state.invoke("isFirstNull", boolean.class), state.invoke("isSecondNull", boolean.class)))
                // appendNull returns the builder; discard it so the stack stays balanced
                .ifTrue(new BytecodeBlock().append(out.invoke("appendNull", BlockBuilder.class)).pop());
        SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType);
        BytecodeExpression getValueExpression;
        if (valueType.getJavaType().isPrimitive()) {
            getValueExpression = state.invoke("getSecond", valueType.getJavaType());
        }
        else {
            getValueExpression = valueSqlType.getValue(state.invoke("getSecondBlock", Block.class), state.invoke("getSecondPosition", int.class));
        }
        ifStatement.ifFalse(valueSqlType.writeValue(out, getValueExpression));
        method.getBody().append(ifStatement).ret();
    }

    /**
     * Returns the first public method of {@code stateClass} with the given name.
     *
     * <p>Parameter types are not considered, so the result is only well defined when the name is not
     * overloaded in {@code stateClass}.
     *
     * @param stateClass class whose public methods are searched
     * @param name method name to look for
     * @return the first matching method
     * @throws IllegalArgumentException if no public method has that name
     */
    private static Method getMethod(Class<?> stateClass, String name)
    {
        return stream(stateClass.getMethods())
                .filter(method -> method.getName().equals(name))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("State class does not have a method named " + name));
    }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy