org.apache.drill.exec.planner.logical.DrillOptiq Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.planner.logical;
import java.math.BigDecimal;
import java.util.GregorianCalendar;
import java.util.LinkedList;
import java.util.List;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.FunctionCallFactory;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.IfExpression.IfCondition;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.NullExpression;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.expression.TypedNullConstant;
import org.apache.drill.common.expression.ValueExpressions;
import org.apache.drill.common.expression.ValueExpressions.QuotedString;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.common.types.TypeProtos.MajorType;
import org.apache.drill.common.types.TypeProtos.MinorType;
import org.apache.drill.common.types.Types;
import org.apache.drill.exec.planner.StarColumnHelper;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexRangeRef;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.NlsString;
import com.google.common.collect.Lists;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.exec.work.ExecErrorConstants;
/**
* Utilities for Drill's planner.
*/
public class DrillOptiq {
public static final String UNSUPPORTED_REX_NODE_ERROR = "Cannot convert RexNode to equivalent Drill expression. ";
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillOptiq.class);
/**
* Converts a tree of {@link RexNode} operators into a scalar expression in Drill syntax.
*/
public static LogicalExpression toDrill(DrillParseContext context, RelNode input, RexNode expr) {
final RexToDrill visitor = new RexToDrill(context, input);
return expr.accept(visitor);
}
private static class RexToDrill extends RexVisitorImpl {
private final RelNode input;
private final DrillParseContext context;
RexToDrill(DrillParseContext context, RelNode input) {
super(true);
this.context = context;
this.input = input;
}
@Override
public LogicalExpression visitInputRef(RexInputRef inputRef) {
final int index = inputRef.getIndex();
final RelDataTypeField field = input.getRowType().getFieldList().get(index);
return FieldReference.getWithQuotedRef(field.getName());
}
@Override
public LogicalExpression visitCall(RexCall call) {
// logger.debug("RexCall {}, {}", call);
final SqlSyntax syntax = call.getOperator().getSyntax();
switch (syntax) {
case BINARY:
logger.debug("Binary");
final String funcName = call.getOperator().getName().toLowerCase();
return doFunction(call, funcName);
case FUNCTION:
case FUNCTION_ID:
logger.debug("Function");
return getDrillFunctionFromOptiqCall(call);
case POSTFIX:
logger.debug("Postfix");
switch(call.getKind()){
case IS_NOT_NULL:
case IS_NOT_TRUE:
case IS_NOT_FALSE:
case IS_NULL:
case IS_TRUE:
case IS_FALSE:
case OTHER:
return FunctionCallFactory.createExpression(call.getOperator().getName().toLowerCase(),
ExpressionPosition.UNKNOWN, call.getOperands().get(0).accept(this));
}
throw new AssertionError("todo: implement syntax " + syntax + "(" + call + ")");
case PREFIX:
logger.debug("Prefix");
LogicalExpression arg = call.getOperands().get(0).accept(this);
switch(call.getKind()){
case NOT:
return FunctionCallFactory.createExpression(call.getOperator().getName().toLowerCase(),
ExpressionPosition.UNKNOWN, arg);
case MINUS_PREFIX:
final RexBuilder builder = input.getCluster().getRexBuilder();
final List operands = Lists.newArrayList();
operands.add(builder.makeExactLiteral(new BigDecimal(-1)));
operands.add(call.getOperands().get(0));
return visitCall((RexCall) builder.makeCall(
SqlStdOperatorTable.MULTIPLY,
operands));
}
throw new AssertionError("todo: implement syntax " + syntax + "(" + call + ")");
case SPECIAL:
logger.debug("Special");
switch(call.getKind()){
case CAST:
return getDrillCastFunctionFromOptiq(call);
case LIKE:
case SIMILAR:
return getDrillFunctionFromOptiqCall(call);
case CASE:
List caseArgs = Lists.newArrayList();
for(RexNode r : call.getOperands()){
caseArgs.add(r.accept(this));
}
caseArgs = Lists.reverse(caseArgs);
// number of arguements are always going to be odd, because
// Optiq adds "null" for the missing else expression at the end
assert caseArgs.size()%2 == 1;
LogicalExpression elseExpression = caseArgs.get(0);
for (int i=1; i args = Lists.newArrayList();
for(RexNode r : call.getOperands()){
args.add(r.accept(this));
}
if (FunctionCallFactory.isBooleanOperator(funcName)) {
LogicalExpression func = FunctionCallFactory.createBooleanOperator(funcName, args);
return func;
} else {
args = Lists.reverse(args);
LogicalExpression lastArg = args.get(0);
for(int i = 1; i < args.size(); i++){
lastArg = FunctionCallFactory.createExpression(funcName, Lists.newArrayList(args.get(i), lastArg));
}
return lastArg;
}
}
private LogicalExpression doUnknown(RexNode o){
// raise an error
throw UserException.planError().message(UNSUPPORTED_REX_NODE_ERROR +
"RexNode Class: %s, RexNode Digest: %s", o.getClass().getName(), o.toString()).build(logger);
}
@Override
public LogicalExpression visitLocalRef(RexLocalRef localRef) {
return doUnknown(localRef);
}
@Override
public LogicalExpression visitOver(RexOver over) {
return doUnknown(over);
}
@Override
public LogicalExpression visitCorrelVariable(RexCorrelVariable correlVariable) {
return doUnknown(correlVariable);
}
@Override
public LogicalExpression visitDynamicParam(RexDynamicParam dynamicParam) {
return doUnknown(dynamicParam);
}
@Override
public LogicalExpression visitRangeRef(RexRangeRef rangeRef) {
return doUnknown(rangeRef);
}
@Override
public LogicalExpression visitFieldAccess(RexFieldAccess fieldAccess) {
return super.visitFieldAccess(fieldAccess);
}
private LogicalExpression getDrillCastFunctionFromOptiq(RexCall call){
LogicalExpression arg = call.getOperands().get(0).accept(this);
MajorType castType = null;
switch(call.getType().getSqlTypeName().getName()){
case "VARCHAR":
case "CHAR":
castType = Types.required(MinorType.VARCHAR).toBuilder().setWidth(call.getType().getPrecision()).build();
break;
case "INTEGER": castType = Types.required(MinorType.INT); break;
case "FLOAT": castType = Types.required(MinorType.FLOAT4); break;
case "DOUBLE": castType = Types.required(MinorType.FLOAT8); break;
case "DECIMAL":
if (context.getPlannerSettings().getOptions().
getOption(PlannerSettings.ENABLE_DECIMAL_DATA_TYPE_KEY).bool_val == false ) {
throw UserException
.unsupportedError()
.message(ExecErrorConstants.DECIMAL_DISABLE_ERR_MSG)
.build(logger);
}
int precision = call.getType().getPrecision();
int scale = call.getType().getScale();
if (precision <= 9) {
castType = TypeProtos.MajorType.newBuilder().setMinorType(MinorType.DECIMAL9).setPrecision(precision).setScale(scale).build();
} else if (precision <= 18) {
castType = TypeProtos.MajorType.newBuilder().setMinorType(MinorType.DECIMAL18).setPrecision(precision).setScale(scale).build();
} else if (precision <= 28) {
// Inject a cast to SPARSE before casting to the dense type.
castType = TypeProtos.MajorType.newBuilder().setMinorType(MinorType.DECIMAL28SPARSE).setPrecision(precision).setScale(scale).build();
} else if (precision <= 38) {
castType = TypeProtos.MajorType.newBuilder().setMinorType(MinorType.DECIMAL38SPARSE).setPrecision(precision).setScale(scale).build();
} else {
throw new UnsupportedOperationException("Only Decimal types with precision range 0 - 38 is supported");
}
break;
case "INTERVAL_YEAR_MONTH": castType = Types.required(MinorType.INTERVALYEAR); break;
case "INTERVAL_DAY_TIME": castType = Types.required(MinorType.INTERVALDAY); break;
case "BOOLEAN": castType = Types.required(MinorType.BIT); break;
case "BINARY": castType = Types.required(MinorType.VARBINARY).toBuilder().setWidth(call.getType().getPrecision()).build(); break;
case "ANY": return arg; // Type will be same as argument.
default: castType = Types.required(MinorType.valueOf(call.getType().getSqlTypeName().getName()));
}
return FunctionCallFactory.createCast(castType, ExpressionPosition.UNKNOWN, arg);
}
private LogicalExpression getDrillFunctionFromOptiqCall(RexCall call) {
List args = Lists.newArrayList();
for(RexNode n : call.getOperands()){
args.add(n.accept(this));
}
int argsSize = args.size();
String functionName = call.getOperator().getName().toLowerCase();
// TODO: once we have more function rewrites and a patter emerges from different rewrites, factor this out in a better fashion
/* Rewrite extract functions in the following manner
* extract(year, date '2008-2-23') ---> extractYear(date '2008-2-23')
*/
if (functionName.equals("extract")) {
// Assert that the first argument to extract is a QuotedString
assert args.get(0) instanceof ValueExpressions.QuotedString;
// Get the unit of time to be extracted
String timeUnitStr = ((ValueExpressions.QuotedString)args.get(0)).value;
switch (timeUnitStr){
case ("YEAR"):
case ("MONTH"):
case ("DAY"):
case ("HOUR"):
case ("MINUTE"):
case ("SECOND"):
String functionPostfix = timeUnitStr.substring(0, 1).toUpperCase() + timeUnitStr.substring(1).toLowerCase();
functionName += functionPostfix;
return FunctionCallFactory.createExpression(functionName, args.subList(1, 2));
default:
throw new UnsupportedOperationException("extract function supports the following time units: YEAR, MONTH, DAY, HOUR, MINUTE, SECOND");
}
} else if (functionName.equals("trim")) {
String trimFunc = null;
List trimArgs = Lists.newArrayList();
assert args.get(0) instanceof ValueExpressions.QuotedString;
switch (((ValueExpressions.QuotedString)args.get(0)).value.toUpperCase()) {
case "LEADING":
trimFunc = "ltrim";
break;
case "TRAILING":
trimFunc = "rtrim";
break;
case "BOTH":
trimFunc = "btrim";
break;
default:
assert 1 == 0;
}
trimArgs.add(args.get(2));
trimArgs.add(args.get(1));
return FunctionCallFactory.createExpression(trimFunc, trimArgs);
} else if (functionName.equals("ltrim") || functionName.equals("rtrim") || functionName.equals("btrim")) {
if (argsSize == 1) {
args.add(ValueExpressions.getChar(" "));
}
return FunctionCallFactory.createExpression(functionName, args);
} else if (functionName.equals("date_part")) {
// Rewrite DATE_PART functions as extract functions
// assert that the function has exactly two arguments
assert argsSize == 2;
/* Based on the first input to the date_part function we rewrite the function as the
* appropriate extract function. For example
* date_part('year', date '2008-2-23') ------> extractYear(date '2008-2-23')
*/
assert args.get(0) instanceof QuotedString;
QuotedString extractString = (QuotedString) args.get(0);
String functionPostfix = extractString.value.substring(0, 1).toUpperCase() + extractString.value.substring(1).toLowerCase();
return FunctionCallFactory.createExpression("extract" + functionPostfix, args.subList(1, 2));
} else if (functionName.equals("concat")) {
if (argsSize == 1) {
/*
* We treat concat with one argument as a special case. Since we don't have a function
* implementation of concat that accepts one argument. We simply add another dummy argument
* (empty string literal) to the list of arguments.
*/
List concatArgs = new LinkedList<>(args);
concatArgs.add(new QuotedString("", ExpressionPosition.UNKNOWN));
return FunctionCallFactory.createExpression(functionName, concatArgs);
} else if (argsSize > 2) {
List concatArgs = Lists.newArrayList();
/* stack concat functions on top of each other if we have more than two arguments
* Eg: concat(col1, col2, col3) => concat(concat(col1, col2), col3)
*/
concatArgs.add(args.get(0));
concatArgs.add(args.get(1));
LogicalExpression first = FunctionCallFactory.createExpression(functionName, concatArgs);
for (int i = 2; i < argsSize; i++) {
concatArgs = Lists.newArrayList();
concatArgs.add(first);
concatArgs.add(args.get(i));
first = FunctionCallFactory.createExpression(functionName, concatArgs);
}
return first;
}
} else if (functionName.equals("length")) {
if (argsSize == 2) {
// Second argument should always be a literal specifying the encoding format
assert args.get(1) instanceof ValueExpressions.QuotedString;
String encodingType = ((ValueExpressions.QuotedString) args.get(1)).value;
functionName += encodingType.substring(0, 1).toUpperCase() + encodingType.substring(1).toLowerCase();
return FunctionCallFactory.createExpression(functionName, args.subList(0, 1));
}
} else if ((functionName.equals("convert_from") || functionName.equals("convert_to"))
&& args.get(1) instanceof QuotedString) {
return FunctionCallFactory.createConvert(functionName, ((QuotedString)args.get(1)).value, args.get(0), ExpressionPosition.UNKNOWN);
} else if ((functionName.equalsIgnoreCase("rpad")) || functionName.equalsIgnoreCase("lpad")) {
// If we have only two arguments for rpad/lpad append a default QuotedExpression as an argument which will be used to pad the string
if (argsSize == 2) {
String spaceFill = " ";
LogicalExpression fill = ValueExpressions.getChar(spaceFill);
args.add(fill);
}
}
return FunctionCallFactory.createExpression(functionName, args);
}
@Override
public LogicalExpression visitLiteral(RexLiteral literal) {
switch(literal.getType().getSqlTypeName()){
case BIGINT:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.BIGINT);
}
long l = (((BigDecimal) literal.getValue()).setScale(0, BigDecimal.ROUND_HALF_UP)).longValue();
return ValueExpressions.getBigInt(l);
case BOOLEAN:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.BIT);
}
return ValueExpressions.getBit(((Boolean) literal.getValue()));
case CHAR:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.VARCHAR);
}
return ValueExpressions.getChar(((NlsString)literal.getValue()).getValue());
case DOUBLE:
if (isLiteralNull(literal)){
return createNullExpr(MinorType.FLOAT8);
}
double d = ((BigDecimal) literal.getValue()).doubleValue();
return ValueExpressions.getFloat8(d);
case FLOAT:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.FLOAT4);
}
float f = ((BigDecimal) literal.getValue()).floatValue();
return ValueExpressions.getFloat4(f);
case INTEGER:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.INT);
}
int a = (((BigDecimal) literal.getValue()).setScale(0, BigDecimal.ROUND_HALF_UP)).intValue();
return ValueExpressions.getInt(a);
case DECIMAL:
/* TODO: Enable using Decimal literals once we have more functions implemented for Decimal
* For now continue using Double instead of decimals
int precision = ((BigDecimal) literal.getValue()).precision();
if (precision <= 9) {
return ValueExpressions.getDecimal9((BigDecimal)literal.getValue());
} else if (precision <= 18) {
return ValueExpressions.getDecimal18((BigDecimal)literal.getValue());
} else if (precision <= 28) {
return ValueExpressions.getDecimal28((BigDecimal)literal.getValue());
} else if (precision <= 38) {
return ValueExpressions.getDecimal38((BigDecimal)literal.getValue());
} */
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.FLOAT8);
}
double dbl = ((BigDecimal) literal.getValue()).doubleValue();
logger.warn("Converting exact decimal into approximate decimal. Should be fixed once decimal is implemented.");
return ValueExpressions.getFloat8(dbl);
case VARCHAR:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.VARCHAR);
}
return ValueExpressions.getChar(((NlsString)literal.getValue()).getValue());
case SYMBOL:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.VARCHAR);
}
return ValueExpressions.getChar(literal.getValue().toString());
case DATE:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.DATE);
}
return (ValueExpressions.getDate((GregorianCalendar)literal.getValue()));
case TIME:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.TIME);
}
return (ValueExpressions.getTime((GregorianCalendar)literal.getValue()));
case TIMESTAMP:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.TIMESTAMP);
}
return (ValueExpressions.getTimeStamp((GregorianCalendar) literal.getValue()));
case INTERVAL_YEAR_MONTH:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.INTERVALYEAR);
}
return (ValueExpressions.getIntervalYear(((BigDecimal) (literal.getValue())).intValue()));
case INTERVAL_DAY_TIME:
if (isLiteralNull(literal)) {
return createNullExpr(MinorType.INTERVALDAY);
}
return (ValueExpressions.getIntervalDay(((BigDecimal) (literal.getValue())).longValue()));
case NULL:
return NullExpression.INSTANCE;
case ANY:
if (isLiteralNull(literal)) {
return NullExpression.INSTANCE;
}
default:
throw new UnsupportedOperationException(String.format("Unable to convert the value of %s and type %s to a Drill constant expression.", literal, literal.getType().getSqlTypeName()));
}
}
}
private static final TypedNullConstant createNullExpr(MinorType type) {
return new TypedNullConstant(Types.optional(type));
}
public static boolean isLiteralNull(RexLiteral literal) {
return literal.getTypeName().getName().equals("NULL");
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy