Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
io.trino.sql.gen.BytecodeUtils Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.gen;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.instruction.LabelNode;
import io.airlift.slice.Slice;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.type.Type;
import io.trino.sql.gen.InputReferenceCompiler.InputReferenceNode;
import io.trino.type.FunctionType;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.bytecode.OpCode.NOP;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
import static java.lang.String.format;
public final class BytecodeUtils
{
private BytecodeUtils() {}
public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class> returnType, Class>... stackArgsToPop)
{
return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false);
}
public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class> returnType, Iterable extends Class>> stackArgsToPop)
{
return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false);
}
public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class> returnType, Class>... stackArgsToPop)
{
return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), true);
}
public static BytecodeNode handleNullValue(
Scope scope,
LabelNode label,
Class> returnType,
List> stackArgsToPop,
boolean clearNullFlag)
{
Variable wasNull = scope.getVariable("wasNull");
BytecodeBlock nullCheck = new BytecodeBlock()
.setDescription("ifWasNullGoto")
.append(wasNull);
String clearComment = null;
if (clearNullFlag) {
nullCheck.append(wasNull.set(constantFalse()));
clearComment = "clear wasNull";
}
BytecodeBlock isNull = new BytecodeBlock();
for (Class> parameterType : stackArgsToPop) {
isNull.pop(parameterType);
}
isNull.pushJavaDefault(returnType);
String loadDefaultComment;
loadDefaultComment = "loadJavaDefault(" + returnType.getName() + ")";
isNull.gotoLabel(label);
String popComment = null;
if (!stackArgsToPop.isEmpty()) {
popComment = format("pop(%s)", Joiner.on(", ").join(stackArgsToPop));
}
return new IfStatement("if wasNull then %s", Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultComment, "goto " + label.getLabel()))
.condition(nullCheck)
.ifTrue(isNull);
}
public static BytecodeNode unboxPrimitive(Class> unboxedType)
{
BytecodeBlock block = new BytecodeBlock().comment("unbox primitive");
if (unboxedType == long.class) {
return block.invokeVirtual(Long.class, "longValue", long.class);
}
if (unboxedType == double.class) {
return block.invokeVirtual(Double.class, "doubleValue", double.class);
}
if (unboxedType == boolean.class) {
return block.invokeVirtual(Boolean.class, "booleanValue", boolean.class);
}
throw new UnsupportedOperationException("not yet implemented: " + unboxedType);
}
public static BytecodeExpression loadConstant(CallSiteBinder callSiteBinder, Object constant, Class> type)
{
Binding binding = callSiteBinder.bind(MethodHandles.constant(type, constant));
return loadConstant(binding);
}
public static BytecodeExpression loadConstant(Binding binding)
{
return invokeDynamic(
BOOTSTRAP_METHOD,
ImmutableList.of(binding.getBindingId()),
"constant_" + binding.getBindingId(),
binding.getType().returnType());
}
public static BytecodeNode generateInvocation(
Scope scope,
ResolvedFunction resolvedFunction,
FunctionManager functionManager,
List arguments,
CallSiteBinder binder)
{
return generateInvocation(
scope,
resolvedFunction.getSignature().getName().getFunctionName(),
resolvedFunction.getFunctionNullability(),
invocationConvention -> functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention),
arguments,
binder);
}
public static BytecodeNode generateInvocation(
Scope scope,
String functionName,
FunctionNullability functionNullability,
Function functionImplementationProvider,
List arguments,
CallSiteBinder binder)
{
return generateFullInvocation(
scope,
functionName,
functionNullability,
Collections.nCopies(arguments.size(), false),
functionImplementationProvider,
instanceFactory -> {
throw new IllegalArgumentException("Simple method invocation can not be used with functions that require an instance factory");
},
arguments.stream()
.map(BytecodeUtils::simpleArgument)
.collect(toImmutableList()),
binder);
}
private static Function>, BytecodeNode> simpleArgument(BytecodeNode argument)
{
return lambdaInterface -> {
checkArgument(lambdaInterface.isEmpty(), "Simple method invocation can not be used with functions that have lambda arguments");
return argument;
};
}
public static BytecodeNode generateFullInvocation(
Scope scope,
ResolvedFunction resolvedFunction,
FunctionManager functionManager,
Function instanceFactory,
List>, BytecodeNode>> argumentCompilers,
CallSiteBinder binder)
{
return generateFullInvocation(
scope,
resolvedFunction.getSignature().getName().getFunctionName(),
resolvedFunction.getFunctionNullability(),
resolvedFunction.getSignature().getArgumentTypes().stream()
.map(FunctionType.class::isInstance)
.collect(toImmutableList()),
invocationConvention -> functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention),
instanceFactory,
argumentCompilers,
binder);
}
private static BytecodeNode generateFullInvocation(
Scope scope,
String functionName,
FunctionNullability functionNullability,
List argumentIsFunctionType,
Function functionImplementationProvider,
Function instanceFactory,
List>, BytecodeNode>> argumentCompilers,
CallSiteBinder binder)
{
verify(argumentIsFunctionType.size() == argumentCompilers.size());
List argumentConventions = new ArrayList<>();
List arguments = new ArrayList<>();
for (int i = 0; i < argumentIsFunctionType.size(); i++) {
if (argumentIsFunctionType.get(i)) {
argumentConventions.add(FUNCTION);
arguments.add(null);
}
else {
BytecodeNode argument = argumentCompilers.get(i).apply(Optional.empty());
argumentConventions.add(getPreferredArgumentConvention(argument, argumentCompilers.size(), functionNullability.isArgumentNullable(i)));
arguments.add(argument);
}
}
InvocationConvention invocationConvention = new InvocationConvention(
argumentConventions,
functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL,
true,
true);
ScalarFunctionImplementation implementation = functionImplementationProvider.apply(invocationConvention);
Binding binding = binder.bind(implementation.getMethodHandle());
LabelNode end = new LabelNode("end");
BytecodeBlock block = new BytecodeBlock()
.setDescription("invoke " + functionName);
Optional instance = implementation.getInstanceFactory()
.map(instanceFactory);
// Index of current parameter in the MethodHandle
int currentParameterIndex = 0;
// Index of parameter (without @IsNull) in Trino function
int realParameterIndex = 0;
// Index of function argument types
int lambdaArgumentIndex = 0;
MethodType methodType = binding.getType();
Class> returnType = methodType.returnType();
Class> unboxedReturnType = Primitives.unwrap(returnType);
List> stackTypes = new ArrayList<>();
boolean instanceIsBound = false;
while (currentParameterIndex < methodType.parameterArray().length) {
Class> type = methodType.parameterArray()[currentParameterIndex];
stackTypes.add(type);
if (instance.isPresent() && !instanceIsBound) {
checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
block.append(instance.get());
instanceIsBound = true;
}
else if (type == ConnectorSession.class) {
block.append(scope.getVariable("session"));
}
else {
switch (invocationConvention.getArgumentConvention(realParameterIndex)) {
case NEVER_NULL:
block.append(arguments.get(realParameterIndex));
checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type");
block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
break;
case NULL_FLAG:
block.append(arguments.get(realParameterIndex));
block.append(scope.getVariable("wasNull"));
block.append(scope.getVariable("wasNull").set(constantFalse()));
stackTypes.add(boolean.class);
currentParameterIndex++;
break;
case BOXED_NULLABLE:
block.append(arguments.get(realParameterIndex));
block.append(boxPrimitiveIfNecessary(scope, type));
block.append(scope.getVariable("wasNull").set(constantFalse()));
break;
case BLOCK_POSITION:
InputReferenceNode inputReferenceNode = (InputReferenceNode) arguments.get(realParameterIndex);
block.append(inputReferenceNode.produceBlockAndPosition());
stackTypes.add(int.class);
if (!functionNullability.isArgumentNullable(realParameterIndex)) {
block.append(scope.getVariable("wasNull").set(inputReferenceNode.blockAndPositionIsNull()));
block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
}
currentParameterIndex++;
break;
case IN_OUT:
block.append(arguments.get(realParameterIndex));
if (!functionNullability.isArgumentNullable(realParameterIndex)) {
block.append(arguments.get(realParameterIndex));
block.invokeVirtual(InOut.class, "isNull", boolean.class);
block.putVariable(scope.getVariable("wasNull"));
block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
}
currentParameterIndex++;
break;
case FUNCTION:
Class> lambdaInterface = implementation.getLambdaInterfaces().get(lambdaArgumentIndex);
block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface)));
lambdaArgumentIndex++;
break;
default:
throw new UnsupportedOperationException(format("Unsupported argument convention type: %s", invocationConvention.getArgumentConvention(realParameterIndex)));
}
realParameterIndex++;
}
currentParameterIndex++;
}
block.append(invoke(binding, functionName));
if (functionNullability.isReturnNullable()) {
block.append(unboxPrimitiveIfNecessary(scope, returnType));
}
block.visitLabel(end);
return block;
}
private static InvocationArgumentConvention getPreferredArgumentConvention(BytecodeNode argument, int argumentCount, boolean nullable)
{
// a Java function can only have 255 arguments, so if the count is low use block position or boxed nullable as they are more efficient
if (argumentCount <= 64) {
if (argument instanceof InputReferenceNode) {
return BLOCK_POSITION;
}
if (nullable) {
return NULL_FLAG;
}
}
return nullable ? BOXED_NULLABLE : NEVER_NULL;
}
public static BytecodeBlock unboxPrimitiveIfNecessary(Scope scope, Class> boxedType)
{
BytecodeBlock block = new BytecodeBlock();
LabelNode end = new LabelNode("end");
Class> unboxedType = Primitives.unwrap(boxedType);
Variable wasNull = scope.getVariable("wasNull");
if (unboxedType.isPrimitive()) {
LabelNode notNull = new LabelNode("notNull");
block.dup(boxedType)
.ifNotNullGoto(notNull)
.append(wasNull.set(constantTrue()))
.comment("swap boxed null with unboxed default")
.pop(boxedType)
.pushJavaDefault(unboxedType)
.gotoLabel(end)
.visitLabel(notNull)
.append(unboxPrimitive(unboxedType));
}
else {
block.dup(boxedType)
.ifNotNullGoto(end)
.append(wasNull.set(constantTrue()));
}
block.visitLabel(end);
return block;
}
public static BytecodeNode boxPrimitiveIfNecessary(Scope scope, Class> type)
{
checkArgument(!type.isPrimitive(), "cannot box into primitive type");
if (!Primitives.isWrapperType(type)) {
return NOP;
}
BytecodeBlock notNull = new BytecodeBlock().comment("box primitive");
Class> expectedCurrentStackType;
if (type == Long.class) {
notNull.invokeStatic(Long.class, "valueOf", Long.class, long.class);
expectedCurrentStackType = long.class;
}
else if (type == Double.class) {
notNull.invokeStatic(Double.class, "valueOf", Double.class, double.class);
expectedCurrentStackType = double.class;
}
else if (type == Boolean.class) {
notNull.invokeStatic(Boolean.class, "valueOf", Boolean.class, boolean.class);
expectedCurrentStackType = boolean.class;
}
else {
throw new UnsupportedOperationException("not yet implemented: " + type);
}
BytecodeBlock condition = new BytecodeBlock().append(scope.getVariable("wasNull"));
BytecodeBlock wasNull = new BytecodeBlock()
.pop(expectedCurrentStackType)
.pushNull()
.checkCast(type);
return new IfStatement()
.condition(condition)
.ifTrue(wasNull)
.ifFalse(notNull);
}
public static BytecodeExpression invoke(Binding binding, String name, BytecodeExpression... parameters)
{
return invoke(binding, name, ImmutableList.copyOf(parameters));
}
public static BytecodeExpression invoke(Binding binding, String name, List parameters)
{
// ensure that name doesn't have a special characters
return invokeDynamic(BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), sanitizeName(name), binding.getType(), parameters);
}
public static BytecodeExpression invoke(Binding binding, BoundSignature signature)
{
return invoke(binding, signature.getName().getFunctionName());
}
/**
* Replace characters that are not safe to use in a JVM identifier.
*/
public static String sanitizeName(String name)
{
return name.replaceAll("[^A-Za-z0-9_$]", "_");
}
public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope scope, Variable wasNullVariable, Type type)
{
Class> valueJavaType = type.getJavaType();
if (!valueJavaType.isPrimitive() && valueJavaType != Slice.class) {
valueJavaType = Object.class;
}
String methodName = "write" + Primitives.wrap(valueJavaType).getSimpleName();
// the stack contains [output, value]
// We should be able to insert the code to get the output variable and compute the value
// at the right place instead of assuming they are in the stack. We should also not need to
// use temp variables to re-shuffle the stack to the right shape before Type.writeXXX is called
// Unfortunately, because of the assumptions made by try_cast, we can't get around it yet.
// TODO: clean up once try_cast is fixed
Variable tempValue = scope.createTempVariable(valueJavaType);
Variable tempOutput = scope.createTempVariable(BlockBuilder.class);
return new BytecodeBlock()
.comment("if (wasNull)")
.append(new IfStatement()
.condition(wasNullVariable)
.ifTrue(new BytecodeBlock()
.comment("output.appendNull();")
.pop(valueJavaType)
.invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class)
.pop())
.ifFalse(new BytecodeBlock()
.comment("%s.%s(output, %s)", type.getTypeSignature(), methodName, valueJavaType.getSimpleName())
.putVariable(tempValue)
.putVariable(tempOutput)
.append(loadConstant(callSiteBinder.bind(type, Type.class)))
.getVariable(tempOutput)
.getVariable(tempValue)
.invokeInterface(Type.class, methodName, void.class, BlockBuilder.class, valueJavaType)));
}
}