Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
io.trino.plugin.pinot.query.PinotExpressionRewriter Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.pinot.query;
import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Match;
import io.trino.matching.Pattern;
import io.trino.plugin.pinot.PinotException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.SchemaTableName;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.core.operator.transform.transformer.datetime.BaseDateTimeTransformer;
import org.apache.pinot.core.operator.transform.transformer.datetime.DateTimeTransformerFactory;
import org.apache.pinot.core.operator.transform.transformer.datetime.EpochToEpochTransformer;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Maps.immutableEnumMap;
import static io.trino.plugin.pinot.PinotErrorCode.PINOT_EXCEPTION;
import static io.trino.plugin.pinot.query.PinotPatterns.WILDCARD;
import static io.trino.plugin.pinot.query.PinotPatterns.aggregationFunction;
import static io.trino.plugin.pinot.query.PinotPatterns.aggregationFunctionType;
import static io.trino.plugin.pinot.query.PinotPatterns.expression;
import static io.trino.plugin.pinot.query.PinotPatterns.expressionType;
import static io.trino.plugin.pinot.query.PinotPatterns.function;
import static io.trino.plugin.pinot.query.PinotPatterns.identifier;
import static io.trino.plugin.pinot.query.PinotPatterns.singleInput;
import static io.trino.plugin.pinot.query.PinotPatterns.transformFunction;
import static io.trino.plugin.pinot.query.PinotPatterns.transformFunctionType;
import static io.trino.plugin.pinot.query.PinotSqlFormatter.getColumnHandle;
import static io.trino.plugin.pinot.query.PinotTransformFunctionTypeResolver.getTransformFunctionType;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static org.apache.pinot.common.function.TransformFunctionType.DATE_TIME_CONVERT;
import static org.apache.pinot.common.function.TransformFunctionType.DATE_TRUNC;
import static org.apache.pinot.common.function.TransformFunctionType.TIME_CONVERT;
import static org.apache.pinot.common.request.Literal.stringValue;
import static org.apache.pinot.common.request.context.ExpressionContext.Type.FUNCTION;
import static org.apache.pinot.common.request.context.ExpressionContext.Type.IDENTIFIER;
import static org.apache.pinot.common.request.context.ExpressionContext.Type.LITERAL;
import static org.apache.pinot.common.request.context.ExpressionContext.forFunction;
import static org.apache.pinot.common.request.context.ExpressionContext.forIdentifier;
import static org.apache.pinot.common.request.context.ExpressionContext.forLiteralContext;
import static org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction.EXAMPLE_INVOCATION;
import static org.apache.pinot.core.operator.transform.transformer.timeunit.TimeUnitTransformerFactory.getTimeUnitTransformer;
import static org.apache.pinot.segment.spi.AggregationFunctionType.COUNT;
import static org.apache.pinot.segment.spi.AggregationFunctionType.getAggregationFunctionType;
public class PinotExpressionRewriter
{
// Type-specific rewrite rules for transform functions, keyed by Pinot transform function type.
private static final Map<TransformFunctionType, RewriteRule<FunctionContext>> FUNCTION_RULE_MAP;
// Type-specific rewrite rules for aggregation functions, keyed by Pinot aggregation function type.
private static final Map<AggregationFunctionType, RewriteRule<FunctionContext>> AGGREGATION_FUNCTION_RULE_MAP;
// Generic rule used by rewriteFunction when no type-specific rule produced a result.
private static final RewriteRule<FunctionContext> DEFAULT_REWRITE_RULE = new DefaultRewriteRule();

// Static utility class: not meant to be instantiated.
private PinotExpressionRewriter() {}

static {
    // Build the lookup tables in mutable maps first, then freeze them as immutable enum maps.
    Map<TransformFunctionType, RewriteRule<FunctionContext>> functionMap = new HashMap<>();
    functionMap.put(DATE_TIME_CONVERT, new DateTimeConvertRewriteRule());
    functionMap.put(TIME_CONVERT, new TimeConvertRewriteRule());
    functionMap.put(DATE_TRUNC, new DateTruncRewriteRule());
    FUNCTION_RULE_MAP = immutableEnumMap(functionMap);

    Map<AggregationFunctionType, RewriteRule<FunctionContext>> aggregationFunctionMap = new HashMap<>();
    aggregationFunctionMap.put(COUNT, new CountStarRewriteRule());
    AGGREGATION_FUNCTION_RULE_MAP = immutableEnumMap(aggregationFunctionMap);
}
/**
 * Rewrites a Pinot expression tree: literals are kept as they are, identifiers are replaced
 * by the column name of the column handle they resolve to, and functions are rewritten
 * through the registered rewrite rules.
 *
 * @param schemaTableName table passed to the column handle lookup when resolving identifiers
 * @param expressionContext expression to rewrite; must not be null
 * @param columnHandles column handles passed to the column handle lookup when resolving identifiers
 * @return the rewritten expression
 * @throws NullPointerException if {@code expressionContext} is null
 */
public static ExpressionContext rewriteExpression(SchemaTableName schemaTableName, ExpressionContext expressionContext, Map<String, ColumnHandle> columnHandles)
{
    requireNonNull(expressionContext, "expressionContext is null");
    // Bundle the lookup inputs so the recursive rewrite only has to carry one object around.
    Context context = new Context() {
        @Override
        public SchemaTableName getSchemaTableName()
        {
            return schemaTableName;
        }

        @Override
        public Map<String, ColumnHandle> getColumnHandles()
        {
            return columnHandles;
        }
    };
    return rewriteExpression(expressionContext, context);
}
/**
 * Recursive worker behind the public entry point: dispatches on the expression type.
 * Literals are returned unchanged, identifiers are swapped for the column name of the
 * resolved column handle, and functions are delegated to {@code rewriteFunction}.
 */
private static ExpressionContext rewriteExpression(ExpressionContext expressionContext, Context context)
{
    return switch (expressionContext.getType()) {
        case LITERAL -> expressionContext;
        case IDENTIFIER -> {
            String resolvedColumnName = getColumnHandle(
                    expressionContext.getIdentifier(),
                    context.getSchemaTableName(),
                    context.getColumnHandles())
                    .getColumnName();
            yield forIdentifier(resolvedColumnName);
        }
        case FUNCTION -> {
            FunctionContext rewrittenFunction = rewriteFunction(expressionContext.getFunction(), context);
            yield forFunction(rewrittenFunction);
        }
    };
}
/**
 * Rewrites a function call. A rule registered for the specific transform or aggregation
 * function type is tried first; if there is none, or its pattern does not match, the
 * default rule is tried.
 *
 * @param functionContext function to rewrite
 * @param context lookup context used when rewriting nested expressions
 * @return the rewritten function
 * @throws java.util.NoSuchElementException if a transform function's type cannot be resolved
 * @throws IllegalStateException if the function is neither a transform nor an aggregation
 * @throws PinotException if no rule, including the default one, matches the function
 */
private static FunctionContext rewriteFunction(FunctionContext functionContext, Context context)
{
    // Look up the type-specific rule; a null result means no rule is registered for this function type.
    RewriteRule<FunctionContext> specificRule;
    if (functionContext.getType() == FunctionContext.Type.TRANSFORM) {
        specificRule = FUNCTION_RULE_MAP.get(getTransformFunctionType(functionContext).orElseThrow());
    }
    else {
        checkState(functionContext.getType() == FunctionContext.Type.AGGREGATION, "Unexpected function type for '%s'", functionContext);
        specificRule = AGGREGATION_FUNCTION_RULE_MAP.get(getAggregationFunctionType(functionContext.getFunctionName()));
    }

    Optional<FunctionContext> rewritten = Optional.empty();
    if (specificRule != null) {
        rewritten = applyRule(specificRule, functionContext, context);
    }
    // Fall back to the default rule only when the specific rule was absent or did not match.
    if (rewritten.isEmpty()) {
        rewritten = applyRule(DEFAULT_REWRITE_RULE, functionContext, context);
    }
    return rewritten.orElseThrow(() ->
            new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported function expression '%s'", functionContext)));
}
/**
 * Applies a rewrite rule to an object if the rule's pattern matches it.
 *
 * @param rule rule to apply
 * @param object object to match and rewrite
 * @param context lookup context handed to the rule
 * @return the rewritten object built from the first match, or empty when the pattern does not match
 */
private static <T> Optional<T> applyRule(RewriteRule<T> rule, T object, Context context)
{
    Iterator<Match> matches = rule.getPattern().match(object).iterator();
    if (!matches.hasNext()) {
        return Optional.empty();
    }
    // Only the first match is used; any further matches are ignored.
    Match firstMatch = matches.next();
    return Optional.of(rule.rewrite(object, firstMatch.captures(), context));
}
private static class DateTimeConvertRewriteRule
implements RewriteRule
{
@Override
public Pattern getPattern()
{
return transformFunction().with(transformFunctionType().equalTo(DATE_TIME_CONVERT));
}
@Override
public FunctionContext rewrite(FunctionContext object, Captures captures, Context context)
{
// Extracted from org.apache.pinot.core.operator.transform.function.DateTimeConversionTransformFunction
// The first argument must be an identifier or function and the 2nd, 3rd and 4th arguments must be literals
verify(object.getArguments().size() == 4);
verifyIsIdentifierOrFunction(object.getArguments().get(0));
verifyTailArgumentsAllLiteral(object.getArguments());
ImmutableList.Builder argumentsBuilder = ImmutableList.builder();
argumentsBuilder.add(rewriteExpression(object.getArguments().get(0), context));
String inputFormat = object.getArguments().get(1).getLiteral().getValue().toString().toUpperCase(ENGLISH);
argumentsBuilder.add(forLiteralContext(stringValue(inputFormat)));
String outputFormat = object.getArguments().get(2).getLiteral().getValue().toString().toUpperCase(ENGLISH);
argumentsBuilder.add(forLiteralContext(stringValue(outputFormat)));
String granularity = object.getArguments().get(3).getLiteral().getValue().toString().toUpperCase(ENGLISH);
BaseDateTimeTransformer, ?> dateTimeTransformer = DateTimeTransformerFactory.getDateTimeTransformer(inputFormat, outputFormat, granularity);
// Even if the format is valid, make sure it is not a simple date format: format characters can be ambiguous due to lower casing
checkState(dateTimeTransformer instanceof EpochToEpochTransformer, "Unsupported date format: simple date format not supported");
argumentsBuilder.add(forLiteralContext(stringValue(granularity)));
return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build());
}
}
/**
 * Rewrite rule for {@code TIME_CONVERT}. The value argument is rewritten recursively and
 * the two literal time unit arguments are upper-cased and validated.
 */
private static class TimeConvertRewriteRule
        implements RewriteRule<FunctionContext>
{
    @Override
    public Pattern<FunctionContext> getPattern()
    {
        return transformFunction().with(transformFunctionType().equalTo(TIME_CONVERT));
    }

    @Override
    public FunctionContext rewrite(FunctionContext object, Captures captures, Context context)
    {
        // Extracted from org.apache.pinot.core.operator.transform.function.DateTimeConversionTransformFunction
        // NOTE(review): this reference looks copied from the DATE_TIME_CONVERT rule; the checks below
        // presumably mirror Pinot's TimeConversionTransformFunction instead - confirm
        // The first argument must be an identifier or function and the 2nd, and 3rd arguments must be literals
        List<ExpressionContext> arguments = object.getArguments();
        verify(arguments.size() == 3);
        verifyIsIdentifierOrFunction(arguments.get(0));
        verifyTailArgumentsAllLiteral(arguments);

        ImmutableList.Builder<ExpressionContext> argumentsBuilder = ImmutableList.builder();
        argumentsBuilder.add(rewriteExpression(arguments.get(0), context));

        String inputTimeUnitArgument = arguments.get(1).getLiteral().getValue().toString().toUpperCase(ENGLISH);
        // valueOf throws IllegalArgumentException when the input unit is not a java.util.concurrent.TimeUnit name
        TimeUnit inputTimeUnit = TimeUnit.valueOf(inputTimeUnitArgument);
        String outputTimeUnitArgument = arguments.get(2).getLiteral().getValue().toString().toUpperCase(ENGLISH);
        // Check that this is a valid time unit transform
        getTimeUnitTransformer(inputTimeUnit, outputTimeUnitArgument);

        argumentsBuilder.add(forLiteralContext(stringValue(inputTimeUnitArgument)));
        argumentsBuilder.add(forLiteralContext(stringValue(outputTimeUnitArgument)));
        return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build());
    }
}
/**
 * Rewrite rule for {@code DATE_TRUNC}. Accepts two to five arguments:
 * unit (literal, lower-cased), value (identifier or function, rewritten recursively),
 * and optionally input time unit (literal, upper-cased), time zone (literal, passed through)
 * and output time unit (literal, upper-cased).
 */
private static class DateTruncRewriteRule
        implements RewriteRule<FunctionContext>
{
    @Override
    public Pattern<FunctionContext> getPattern()
    {
        return transformFunction().with(transformFunctionType().equalTo(DATE_TRUNC));
    }

    @Override
    public FunctionContext rewrite(FunctionContext object, Captures captures, Context context)
    {
        // Extracted from org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction
        List<ExpressionContext> arguments = object.getArguments();
        checkState(arguments.size() >= 2 && arguments.size() <= 5,
                "Between two to five arguments are required, example: %s", EXAMPLE_INVOCATION);
        ImmutableList.Builder<ExpressionContext> argumentsBuilder = ImmutableList.builder();

        checkState(arguments.get(0).getType() == LITERAL, "First argument must be a literal");
        String unit = arguments.get(0).getLiteral().getValue().toString().toLowerCase(ENGLISH);
        argumentsBuilder.add(forLiteralContext(stringValue(unit)));

        verifyIsIdentifierOrFunction(arguments.get(1));
        argumentsBuilder.add(rewriteExpression(arguments.get(1), context));

        // The optional arguments are positional, so each size check below implies the previous ones.
        if (arguments.size() >= 3) {
            checkState(arguments.get(2).getType() == LITERAL, "Unexpected 3rd argument: '%s'", arguments.get(2));
            String inputTimeUnitArgument = arguments.get(2).getLiteral().getValue().toString().toUpperCase(ENGLISH);
            // Ensure this is a valid TimeUnit
            TimeUnit inputTimeUnit = TimeUnit.valueOf(inputTimeUnitArgument);
            argumentsBuilder.add(forLiteralContext(stringValue(inputTimeUnit.name())));
        }
        if (arguments.size() >= 4) {
            checkState(arguments.get(3).getType() == LITERAL, "Unexpected 4th argument '%s'", arguments.get(3));
            // Time zone is lower cased inside Pinot
            argumentsBuilder.add(arguments.get(3));
        }
        if (arguments.size() >= 5) {
            checkState(arguments.get(4).getType() == LITERAL, "Unexpected 5th argument: '%s'", arguments.get(4));
            String outputTimeUnitArgument = arguments.get(4).getLiteral().getValue().toString().toUpperCase(ENGLISH);
            // Ensure this is a valid TimeUnit
            TimeUnit outputTimeUnit = TimeUnit.valueOf(outputTimeUnitArgument);
            argumentsBuilder.add(forLiteralContext(stringValue(outputTimeUnit.name())));
        }
        return new FunctionContext(object.getType(), object.getFunctionName(), argumentsBuilder.build());
    }
}
/**
 * Rewrite rule for {@code COUNT} applied to the wildcard identifier. The function is
 * returned unchanged, so its argument never goes through identifier resolution
 * (presumably because the wildcard does not correspond to a column handle).
 */
private static class CountStarRewriteRule
        implements RewriteRule<FunctionContext>
{
    @Override
    public Pattern<FunctionContext> getPattern()
    {
        // COUNT aggregation whose single input is an identifier equal to WILDCARD
        return aggregationFunction()
                .with(aggregationFunctionType().equalTo(COUNT))
                .with(singleInput().matching(expression()
                        .with(expressionType().equalTo(IDENTIFIER))
                        .with(identifier().equalTo(WILDCARD))));
    }

    @Override
    public FunctionContext rewrite(FunctionContext object, Captures captures, Context context)
    {
        return object;
    }
}
/**
 * Fallback rewrite rule: keeps the function type and name and rewrites every argument
 * recursively.
 */
private static class DefaultRewriteRule
        implements RewriteRule<FunctionContext>
{
    @Override
    public Pattern<FunctionContext> getPattern()
    {
        return function();
    }

    @Override
    public FunctionContext rewrite(FunctionContext object, Captures captures, Context context)
    {
        List<ExpressionContext> arguments = object.getArguments().stream()
                .map(argument -> rewriteExpression(argument, context))
                .collect(toImmutableList());
        return new FunctionContext(object.getType(), object.getFunctionName(), arguments);
    }
}
/**
 * Verifies that the expression is an identifier or a function; fails the
 * {@code verify} check for any other expression type.
 */
private static void verifyIsIdentifierOrFunction(ExpressionContext expressionContext)
{
    ExpressionContext.Type actualType = expressionContext.getType();
    boolean identifierOrFunction = actualType == IDENTIFIER || actualType == FUNCTION;
    verify(identifierOrFunction);
}
/**
 * Verifies that every argument after the first one is a literal. Lists with zero or one
 * element pass trivially.
 *
 * @param arguments function arguments; the first element is not checked
 */
private static void verifyTailArgumentsAllLiteral(List<ExpressionContext> arguments)
{
    for (int index = 1; index < arguments.size(); index++) {
        verify(arguments.get(index).getType() == LITERAL);
    }
}
/**
 * Lookup inputs carried through the recursive rewrite; both values are handed to the
 * column handle lookup when an identifier is resolved.
 */
private interface Context
{
    SchemaTableName getSchemaTableName();

    Map<String, ColumnHandle> getColumnHandles();
}
/**
 * A rewrite step guarded by a pattern: {@code rewrite} is invoked only for objects that
 * {@code getPattern()} matches.
 *
 * @param <T> type of object the rule matches and produces
 */
private interface RewriteRule<T>
{
    Pattern<T> getPattern();

    T rewrite(T object, Captures captures, Context context);
}
}