All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.operator.scalar.MapToMapCast Maven / Gradle / Ivy

There is a newer version: 468
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.metadata.SqlScalarFunction;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.DuplicateMapKeyException;
import io.trino.spi.block.SqlMap;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.type.MapType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.block.MapHashTables.HashBuildMode.STRICT_NOT_DISTINCT_FROM;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.function.OperatorType.CAST;
import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.util.Failures.internalError;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.invoke.MethodHandles.dropArguments;
import static java.lang.invoke.MethodHandles.foldArguments;
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.lang.invoke.MethodType.methodType;
import static java.util.Objects.requireNonNull;

public final class MapToMapCast
        extends SqlScalarFunction
{
    private static final MethodHandle METHOD_HANDLE = methodHandle(
            MapToMapCast.class,
            "mapCast",
            MethodHandle.class,
            MethodHandle.class,
            MapType.class,
            BlockPositionIsDistinctFrom.class,
            BlockPositionHashCode.class,
            ConnectorSession.class,
            SqlMap.class);

    private static final MethodHandle CHECK_LONG_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkLongIsNotNull", Long.class);
    private static final MethodHandle CHECK_DOUBLE_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkDoubleIsNotNull", Double.class);
    private static final MethodHandle CHECK_BOOLEAN_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkBooleanIsNotNull", Boolean.class);
    private static final MethodHandle CHECK_SLICE_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkSliceIsNotNull", Slice.class);
    private static final MethodHandle CHECK_BLOCK_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkBlockIsNotNull", Block.class);

    private static final MethodHandle WRITE_LONG = methodHandle(Type.class, "writeLong", BlockBuilder.class, long.class);
    private static final MethodHandle WRITE_DOUBLE = methodHandle(Type.class, "writeDouble", BlockBuilder.class, double.class);
    private static final MethodHandle WRITE_BOOLEAN = methodHandle(Type.class, "writeBoolean", BlockBuilder.class, boolean.class);
    private static final MethodHandle WRITE_OBJECT = methodHandle(Type.class, "writeObject", BlockBuilder.class, Object.class);

    private final BlockTypeOperators blockTypeOperators;

    public MapToMapCast(BlockTypeOperators blockTypeOperators)
    {
        super(FunctionMetadata.operatorBuilder(CAST)
                .signature(Signature.builder()
                        .castableToTypeParameter("FK", new TypeSignature("TK"))
                        .castableToTypeParameter("FV", new TypeSignature("TV"))
                        .typeVariable("TK")
                        .typeVariable("TV")
                        .returnType(mapType(new TypeSignature("TK"), new TypeSignature("TV")))
                        .argumentType(mapType(new TypeSignature("FK"), new TypeSignature("FV")))
                        .build())
                .nullable()
                .build());
        this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
    }

    @Override
    public FunctionDependencyDeclaration getFunctionDependencies()
    {
        return FunctionDependencyDeclaration.builder()
                .addCastSignature(new TypeSignature("FK"), new TypeSignature("TK"))
                .addCastSignature(new TypeSignature("FV"), new TypeSignature("TV"))
                .build();
    }

    @Override
    public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies)
    {
        checkArgument(boundSignature.getArity() == 1, "Expected arity to be 1");
        MapType fromMapType = (MapType) boundSignature.getArgumentType(0);
        Type fromKeyType = fromMapType.getKeyType();
        Type fromValueType = fromMapType.getValueType();
        MapType toMapType = (MapType) boundSignature.getReturnType();
        Type toKeyType = toMapType.getKeyType();
        Type toValueType = toMapType.getValueType();

        MethodHandle keyProcessor = buildProcessor(functionDependencies, fromKeyType, toKeyType, true);
        MethodHandle valueProcessor = buildProcessor(functionDependencies, fromValueType, toValueType, false);
        BlockPositionIsDistinctFrom keyEqual = blockTypeOperators.getDistinctFromOperator(toKeyType);
        BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(toKeyType);
        MethodHandle target = MethodHandles.insertArguments(METHOD_HANDLE, 0, keyProcessor, valueProcessor, toMapType, keyEqual, keyHashCode);
        return new ChoicesSpecializedSqlScalarFunction(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), target);
    }

    /**
     * The signature of the returned MethodHandle is (Block fromMap, int position, ConnectorSession session, BlockBuilder mapBlockBuilder)void.
     * The processor will get the value from fromMap, cast it and write to toBlock.
     */
    private static MethodHandle buildProcessor(FunctionDependencies functionDependencies, Type fromType, Type toType, boolean isKey)
    {
        // Get block position cast, with optional connector session
        FunctionNullability functionNullability = functionDependencies.getCastNullability(fromType, toType);
        InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(BLOCK_POSITION), functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, true, false);
        MethodHandle cast = functionDependencies.getCastImplementation(fromType, toType, invocationConvention).getMethodHandle();
        // Normalize cast to have connector session as first argument
        if (cast.type().parameterArray()[0] != ConnectorSession.class) {
            cast = dropArguments(cast, 0, ConnectorSession.class);
        }
        // Change cast signature to (Block.class, int.class, ConnectorSession.class):T
        cast = permuteArguments(cast, methodType(cast.type().returnType(), Block.class, int.class, ConnectorSession.class), 2, 0, 1);

        // If the key cast function is nullable, check the result is not null
        if (isKey && functionNullability.isReturnNullable()) {
            MethodHandle target = nullChecker(cast.type().returnType());
            cast = foldArguments(dropArguments(target, 1, cast.type().parameterList()), cast);
        }

        // get write method with signature: (T, BlockBuilder.class):void
        MethodHandle writer = nativeValueWriter(toType);
        writer = permuteArguments(writer, methodType(void.class, writer.type().parameterArray()[1], BlockBuilder.class), 1, 0);

        // ensure cast function returns the type expected by the writer
        cast = cast.asType(methodType(writer.type().parameterType(0), cast.type().parameterArray()));

        return foldArguments(dropArguments(writer, 1, cast.type().parameterList()), cast);
    }

    /**
     * Returns a null checker MethodHandle that only returns the value when it is not null.
     * 

* The signature of the returned MethodHandle could be one of the following depending on javaType: *

    *
  • (Long value)long *
  • (Double value)double *
  • (Boolean value)boolean *
  • (Slice value)Slice *
  • (Block value)Block *
*/ private static MethodHandle nullChecker(Class javaType) { if (javaType == Long.class) { return CHECK_LONG_IS_NOT_NULL; } if (javaType == Double.class) { return CHECK_DOUBLE_IS_NOT_NULL; } if (javaType == Boolean.class) { return CHECK_BOOLEAN_IS_NOT_NULL; } if (javaType == Slice.class) { return CHECK_SLICE_IS_NOT_NULL; } if (javaType == Block.class) { return CHECK_BLOCK_IS_NOT_NULL; } throw new IllegalArgumentException("Unknown java type " + javaType); } @UsedByGeneratedCode public static long checkLongIsNotNull(Long value) { if (value == null) { throw new TrinoException(INVALID_CAST_ARGUMENT, "map key is null"); } return value; } @UsedByGeneratedCode public static double checkDoubleIsNotNull(Double value) { if (value == null) { throw new TrinoException(INVALID_CAST_ARGUMENT, "map key is null"); } return value; } @UsedByGeneratedCode public static boolean checkBooleanIsNotNull(Boolean value) { if (value == null) { throw new TrinoException(INVALID_CAST_ARGUMENT, "map key is null"); } return value; } @UsedByGeneratedCode public static Slice checkSliceIsNotNull(Slice value) { if (value == null) { throw new TrinoException(INVALID_CAST_ARGUMENT, "map key is null"); } return value; } @UsedByGeneratedCode public static Block checkBlockIsNotNull(Block value) { if (value == null) { throw new TrinoException(INVALID_CAST_ARGUMENT, "map key is null"); } return value; } @UsedByGeneratedCode public static SqlMap mapCast( MethodHandle keyProcessFunction, MethodHandle valueProcessFunction, MapType toType, BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, ConnectorSession session, SqlMap fromMap) { int size = fromMap.getSize(); int rawOffset = fromMap.getRawOffset(); Block rawKeyBlock = fromMap.getRawKeyBlock(); Block rawValueBlock = fromMap.getRawValueBlock(); // Cast the keys into a new block Type toKeyType = toType.getKeyType(); BlockBuilder keyBlockBuilder = toKeyType.createBlockBuilder(null, size); for (int i = 0; i < size; i++) { try { 
keyProcessFunction.invokeExact(rawKeyBlock, rawOffset + i, session, keyBlockBuilder); } catch (Throwable t) { throw internalError(t); } } Block keyBlock = keyBlockBuilder.build(); // Cast the values into a new block Type toValueType = toType.getValueType(); BlockBuilder valueBlockBuilder = toValueType.createBlockBuilder(null, size); for (int i = 0; i < size; i++) { if (rawValueBlock.isNull(rawOffset + i)) { valueBlockBuilder.appendNull(); continue; } try { valueProcessFunction.invokeExact(rawValueBlock, rawOffset + i, session, valueBlockBuilder); } catch (Throwable t) { throw internalError(t); } } Block valueBlock = valueBlockBuilder.build(); try { return new SqlMap(toType, STRICT_NOT_DISTINCT_FROM, keyBlock, valueBlock); } catch (DuplicateMapKeyException e) { throw new TrinoException(INVALID_CAST_ARGUMENT, "duplicate keys"); } } private static MethodHandle nativeValueWriter(Type type) { Class javaType = type.getJavaType(); MethodHandle methodHandle; if (javaType == long.class) { methodHandle = WRITE_LONG; } else if (javaType == double.class) { methodHandle = WRITE_DOUBLE; } else if (javaType == boolean.class) { methodHandle = WRITE_BOOLEAN; } else if (!javaType.isPrimitive()) { methodHandle = WRITE_OBJECT; } else { throw new IllegalArgumentException("Unknown java type " + javaType + " from type " + type); } return methodHandle.bindTo(type); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy