All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.operator.scalar.RowToRowCast Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.scalar;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.Hashing;
import com.google.common.io.BaseEncoding;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.trino.metadata.SqlScalarFunction;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.SqlRow;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.function.TypeVariableConstraint;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.gen.CallSiteBinder;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Objects;

import static io.airlift.bytecode.Access.FINAL;
import static io.airlift.bytecode.Access.PRIVATE;
import static io.airlift.bytecode.Access.PUBLIC;
import static io.airlift.bytecode.Access.STATIC;
import static io.airlift.bytecode.Access.a;
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic;
import static io.airlift.bytecode.expression.BytecodeExpressions.newArray;
import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.function.OperatorType.CAST;
import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType;
import static io.trino.type.UnknownType.UNKNOWN;
import static io.trino.util.CompilerUtils.defineClass;
import static io.trino.util.CompilerUtils.makeClassName;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.collectArguments;
import static java.lang.invoke.MethodHandles.dropArguments;
import static java.lang.invoke.MethodHandles.explicitCastArguments;
import static java.lang.invoke.MethodHandles.guardWithTest;
import static java.lang.invoke.MethodHandles.lookup;
import static java.lang.invoke.MethodHandles.zero;
import static java.lang.invoke.MethodType.methodType;
import static java.nio.charset.StandardCharsets.UTF_8;

public class RowToRowCast
        extends SqlScalarFunction
{
    private static final MethodHandle OBJECT_IS_NULL;
    private static final MethodHandle BLOCK_IS_NULL;
    private static final MethodHandle WRITE_BOOLEAN;
    private static final MethodHandle WRITE_LONG;
    private static final MethodHandle WRITE_DOUBLE;
    private static final MethodHandle WRITE_OBJECT;
    private static final MethodHandle APPEND_NULL;

    static {
        try {
            OBJECT_IS_NULL = lookup().findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class));
            BLOCK_IS_NULL = lookup().findVirtual(Block.class, "isNull", methodType(boolean.class, int.class));

            WRITE_BOOLEAN = lookup().findVirtual(Type.class, "writeBoolean", methodType(void.class, BlockBuilder.class, boolean.class));
            WRITE_LONG = lookup().findVirtual(Type.class, "writeLong", methodType(void.class, BlockBuilder.class, long.class));
            WRITE_DOUBLE = lookup().findVirtual(Type.class, "writeDouble", methodType(void.class, BlockBuilder.class, double.class));
            WRITE_OBJECT = lookup().findVirtual(Type.class, "writeObject", methodType(void.class, BlockBuilder.class, Object.class));
            APPEND_NULL = lookup().findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class));
        }
        catch (ReflectiveOperationException e) {
            throw new AssertionError(e);
        }
    }

    public static final RowToRowCast ROW_TO_ROW_CAST = new RowToRowCast();

    private RowToRowCast()
    {
        super(FunctionMetadata.operatorBuilder(CAST)
                .signature(Signature.builder()
                        .typeVariableConstraint(
                                // this is technically a recursive constraint for cast, but SignatureBinder has explicit handling for row-to-row cast
                                TypeVariableConstraint.builder("F")
                                        .variadicBound("row")
                                        .castableTo(new TypeSignature("T"))
                                        .build())
                        .variadicTypeParameter("T", "row")
                        .returnType(new TypeSignature("T"))
                        .argumentType(new TypeSignature("F"))
                        .build())
                .build());
    }

    @Override
    public FunctionDependencyDeclaration getFunctionDependencies(BoundSignature boundSignature)
    {
        List toTypes = boundSignature.getReturnType().getTypeParameters();
        List fromTypes = boundSignature.getArgumentType(0).getTypeParameters();

        FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder();
        for (int i = 0; i < toTypes.size(); i++) {
            Type fromElementType = fromTypes.get(i);
            Type toElementType = toTypes.get(i);
            builder.addCast(fromElementType, toElementType);
        }
        return builder.build();
    }

    @Override
    public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies)
    {
        Type fromType = boundSignature.getArgumentType(0);
        Type toType = boundSignature.getReturnType();
        if (fromType.getTypeParameters().size() != toType.getTypeParameters().size()) {
            throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "the size of fromType and toType must match");
        }
        Class castOperatorClass = generateRowCast(fromType, toType, functionDependencies);
        MethodHandle methodHandle = methodHandle(castOperatorClass, "castRow", ConnectorSession.class, SqlRow.class);
        return new ChoicesSpecializedSqlScalarFunction(
                boundSignature,
                FAIL_ON_NULL,
                ImmutableList.of(NEVER_NULL),
                methodHandle);
    }

    private static Class generateRowCast(Type fromType, Type toType, FunctionDependencies functionDependencies)
    {
        List toTypes = toType.getTypeParameters();
        List fromTypes = fromType.getTypeParameters();

        CallSiteBinder binder = new CallSiteBinder();

        // Embed the hash code of input and output types into the generated class name instead of the raw type names,
        // which ensures the class name does not hit the length limitation or invalid characters.
        byte[] hashSuffix = Hashing.goodFastHash(128).hashBytes((fromType + "$" + toType).getBytes(UTF_8)).asBytes();

        ClassDefinition definition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName(Joiner.on("$").join("RowCast", BaseEncoding.base16().encode(hashSuffix))),
                type(Object.class));
        definition.declareDefaultConstructor(a(PRIVATE));

        Parameter session = arg("session", ConnectorSession.class);
        Parameter row = arg("row", SqlRow.class);

        MethodDefinition method = definition.declareMethod(
                a(PUBLIC, STATIC),
                "castRow",
                type(SqlRow.class),
                session,
                row);

        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();

        Variable fieldBlocks = scope.declareVariable("fieldBlocks", body, newArray(type(Block[].class), toTypes.size()));
        Variable rawIndex = scope.declareVariable("rawIndex", body, row.invoke("getRawIndex", int.class));

        Variable fieldBuilder = scope.declareVariable(BlockBuilder.class, "fieldBuilder");
        for (int i = 0; i < toTypes.size(); i++) {
            Type fromElementType = fromTypes.get(i);
            Type toElementType = toTypes.get(i);

            body.append(fieldBuilder.set(constantType(binder, toElementType).invoke(
                    "createBlockBuilder",
                    BlockBuilder.class,
                    constantNull(BlockBuilderStatus.class),
                    constantInt(1))));

            if (fromElementType.equals(UNKNOWN)) {
                body.append(fieldBuilder.invoke("appendNull", BlockBuilder.class).pop());
            }
            else {
                MethodHandle castMethod = getNullSafeCast(functionDependencies, fromElementType, toElementType);
                MethodHandle writeMethod = getNullSafeWrite(toElementType);
                MethodHandle castAndWrite = collectArguments(writeMethod, 1, castMethod);
                body.append(invokeDynamic(
                        BOOTSTRAP_METHOD,
                        ImmutableList.of(binder.bind(castAndWrite).getBindingId()),
                        "castAndWriteField",
                        castAndWrite.type(),
                        fieldBuilder,
                        session,
                        row.invoke("getRawFieldBlock", Block.class, constantInt(i)),
                        rawIndex));
            }
            body.append(fieldBlocks.setElement(i, fieldBuilder.invoke("build", Block.class)));
        }

        body.append(newInstance(SqlRow.class, constantInt(0), fieldBlocks).ret());
        return defineClass(definition, Object.class, binder.getBindings(), RowToRowCast.class.getClassLoader());
    }

    private static MethodHandle getNullSafeWrite(Type type)
    {
        MethodHandle writeMethod;
        if (type.getJavaType() == boolean.class) {
            writeMethod = WRITE_BOOLEAN;
        }
        else if (type.getJavaType() == long.class) {
            writeMethod = WRITE_LONG;
        }
        else if (type.getJavaType() == double.class) {
            writeMethod = WRITE_DOUBLE;
        }
        else {
            writeMethod = WRITE_OBJECT;
        }
        writeMethod = writeMethod.bindTo(type);
        writeMethod = explicitCastArguments(writeMethod, methodType(void.class, BlockBuilder.class, Object.class));
        MethodHandle isNull = dropArguments(OBJECT_IS_NULL, 0, BlockBuilder.class);
        MethodHandle appendNull = dropArguments(APPEND_NULL, 1, Object.class).asType(writeMethod.type());
        return guardWithTest(isNull, appendNull, writeMethod);
    }

    private static MethodHandle getNullSafeCast(FunctionDependencies functionDependencies, Type fromElementType, Type toElementType)
    {
        MethodHandle castMethod = functionDependencies.getCastImplementation(
                fromElementType,
                toElementType,
                new InvocationConvention(ImmutableList.of(BLOCK_POSITION_NOT_NULL), NULLABLE_RETURN, true, false))
                .getMethodHandle();

        // normalize so cast always has a session
        if (!castMethod.type().parameterType(0).equals(ConnectorSession.class)) {
            castMethod = dropArguments(castMethod, 0, ConnectorSession.class);
        }

        // change return to Object
        castMethod = castMethod.asType(methodType(Object.class, ConnectorSession.class, Block.class, int.class));

        // if block is null, return null. otherwise execute the cast
        return guardWithTest(
                dropArguments(BLOCK_IS_NULL, 0, ConnectorSession.class),
                dropArguments(zero(Object.class), 0, castMethod.type().parameterList()),
                castMethod);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy