com.facebook.presto.verifier.rewrite.QueryRewriter Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.verifier.rewrite;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.AllColumns;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CreateTable;
import com.facebook.presto.sql.tree.CreateTableAsSelect;
import com.facebook.presto.sql.tree.CreateView;
import com.facebook.presto.sql.tree.DropTable;
import com.facebook.presto.sql.tree.DropView;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.Insert;
import com.facebook.presto.sql.tree.LikeClause;
import com.facebook.presto.sql.tree.Property;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.Select;
import com.facebook.presto.sql.tree.SelectItem;
import com.facebook.presto.sql.tree.ShowCreate;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.verifier.framework.ClusterType;
import com.facebook.presto.verifier.framework.QueryException;
import com.facebook.presto.verifier.framework.QueryObjectBundle;
import com.facebook.presto.verifier.prestoaction.PrestoAction;
import com.facebook.presto.verifier.prestoaction.PrestoAction.ResultSetConverter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.intellij.lang.annotations.Language;

import java.sql.ResultSetMetaData;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.RowType.Field;
import static com.facebook.presto.common.type.StandardTypes.MAP;
import static com.facebook.presto.common.type.TimeType.TIME;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.tree.LikeClause.PropertiesOption.INCLUDING;
import static com.facebook.presto.sql.tree.ShowCreate.Type.VIEW;
import static com.facebook.presto.verifier.framework.CreateViewVerification.SHOW_CREATE_VIEW_CONVERTER;
import static com.facebook.presto.verifier.framework.QueryStage.REWRITE;
import static com.facebook.presto.verifier.framework.VerifierUtil.PARSING_OPTIONS;
import static com.facebook.presto.verifier.framework.VerifierUtil.getColumnNames;
import static com.facebook.presto.verifier.framework.VerifierUtil.getColumnTypes;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Map.Entry;
import static java.util.Objects.requireNonNull;
import static java.util.UUID.randomUUID;

public class QueryRewriter
{
private final SqlParser sqlParser;
private final TypeManager typeManager;
private final PrestoAction prestoAction;
private final Map prefixes;
private final Map> tableProperties;
private final Optional functionCallRewriter;
public QueryRewriter(
SqlParser sqlParser,
TypeManager typeManager,
PrestoAction prestoAction,
Map tablePrefixes,
Map> tableProperties)
{
this(sqlParser, typeManager, prestoAction, tablePrefixes, tableProperties, Optional.empty());
}
public QueryRewriter(
SqlParser sqlParser,
TypeManager typeManager,
PrestoAction prestoAction,
Map tablePrefixes,
Map> tableProperties,
Optional nonDeterministicFunctionSubstitutes)
{
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
this.prestoAction = requireNonNull(prestoAction, "prestoAction is null");
this.prefixes = ImmutableMap.copyOf(tablePrefixes);
this.tableProperties = ImmutableMap.copyOf(tableProperties);
this.functionCallRewriter =
requireNonNull(nonDeterministicFunctionSubstitutes, "nonDeterministicFunctionSubstitutes is null").map(functionSubstitutes -> FunctionCallRewriter.getInstance(functionSubstitutes));
}
public QueryObjectBundle rewriteQuery(@Language("SQL") String query, ClusterType clusterType)
{
checkState(prefixes.containsKey(clusterType), "Unsupported cluster type: %s", clusterType);
Statement statement = sqlParser.createStatement(query, PARSING_OPTIONS);
QualifiedName prefix = prefixes.get(clusterType);
List properties = tableProperties.get(clusterType);
if (statement instanceof CreateTableAsSelect) {
CreateTableAsSelect createTableAsSelect = (CreateTableAsSelect) statement;
QualifiedName temporaryTableName = generateTemporaryName(Optional.of(createTableAsSelect.getName()), prefix);
Query createQuery = createTableAsSelect.getQuery();
if (functionCallRewriter.isPresent()) {
createQuery = (Query) functionCallRewriter.get().rewrite(createQuery);
}
return new QueryObjectBundle(
temporaryTableName,
ImmutableList.of(),
new CreateTableAsSelect(
temporaryTableName,
createQuery,
createTableAsSelect.isNotExists(),
applyPropertyOverride(createTableAsSelect.getProperties(), properties),
createTableAsSelect.isWithData(),
createTableAsSelect.getColumnAliases(),
createTableAsSelect.getComment()),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
}
if (statement instanceof Insert) {
Insert insert = (Insert) statement;
QualifiedName originalTableName = insert.getTarget();
QualifiedName temporaryTableName = generateTemporaryName(Optional.of(originalTableName), prefix);
Query insertQuery = insert.getQuery();
if (functionCallRewriter.isPresent()) {
insertQuery = (Query) functionCallRewriter.get().rewrite(insertQuery);
}
return new QueryObjectBundle(
temporaryTableName,
ImmutableList.of(
new CreateTable(
temporaryTableName,
ImmutableList.of(new LikeClause(originalTableName, Optional.of(INCLUDING))),
false,
properties,
Optional.empty())),
new Insert(
temporaryTableName,
insert.getColumns(),
insertQuery),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
}
if (statement instanceof Query) {
QualifiedName temporaryTableName = generateTemporaryName(Optional.empty(), prefix);
Query queryBody = (Query) statement;
if (functionCallRewriter.isPresent()) {
queryBody = (Query) functionCallRewriter.get().rewrite(queryBody);
}
ResultSetMetaData metadata = getResultMetadata(queryBody);
List columnAliases = generateStorageColumnAliases(metadata);
queryBody = rewriteNonStorableColumns(queryBody, metadata);
return new QueryObjectBundle(
temporaryTableName,
ImmutableList.of(),
new CreateTableAsSelect(
temporaryTableName,
queryBody,
false,
properties,
true,
Optional.of(columnAliases),
Optional.empty()),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
}
if (statement instanceof CreateView) {
CreateView createView = (CreateView) statement;
QualifiedName temporaryViewName = generateTemporaryName(Optional.empty(), prefix);
ImmutableList.Builder setupQueries = ImmutableList.builder();
// Check to see if there is an existing view with the specified view name.
// If view exists, create a temporary view that are has the same definition as the existing view.
// Otherwise, do not pre-create temporary view.
try {
String createExistingViewQuery = getOnlyElement(prestoAction.execute(
new ShowCreate(VIEW, createView.getName()),
REWRITE,
SHOW_CREATE_VIEW_CONVERTER).getResults());
CreateView createExistingView = (CreateView) sqlParser.createStatement(createExistingViewQuery, PARSING_OPTIONS);
setupQueries.add(new CreateView(
temporaryViewName,
createExistingView.getQuery(),
false,
createExistingView.getSecurity()));
}
catch (QueryException e) {
// no-op
}
return new QueryObjectBundle(
temporaryViewName,
setupQueries.build(),
new CreateView(
temporaryViewName,
createView.getQuery(),
createView.isReplace(),
createView.getSecurity()),
ImmutableList.of(new DropView(temporaryViewName, true)),
clusterType);
}
if (statement instanceof CreateTable) {
CreateTable createTable = (CreateTable) statement;
QualifiedName temporaryTableName = generateTemporaryName(Optional.empty(), prefix);
return new QueryObjectBundle(
temporaryTableName,
ImmutableList.of(),
new CreateTable(
temporaryTableName,
createTable.getElements(),
createTable.isNotExists(),
applyPropertyOverride(createTable.getProperties(), properties),
createTable.getComment()),
ImmutableList.of(new DropTable(temporaryTableName, true)),
clusterType);
}
throw new IllegalStateException(format("Unsupported query type: %s", statement.getClass()));
}
private QualifiedName generateTemporaryName(Optional originalName, QualifiedName prefix)
{
List parts = new ArrayList<>();
int originalSize = originalName.map(QualifiedName::getOriginalParts).map(List::size).orElse(0);
int prefixSize = prefix.getOriginalParts().size();
if (originalName.isPresent() && originalSize > prefixSize) {
parts.addAll(originalName.get().getOriginalParts().subList(0, originalSize - prefixSize));
}
parts.addAll(prefix.getOriginalParts());
parts.set(parts.size() - 1, prefix.getSuffix() + "_" + randomUUID().toString().replace("-", ""));
return QualifiedName.of(parts);
}
private List generateStorageColumnAliases(ResultSetMetaData metadata)
{
ImmutableList.Builder aliases = ImmutableList.builder();
Set usedAliases = new HashSet<>();
for (String columnName : getColumnNames(metadata)) {
columnName = sanitizeColumnName(columnName);
String alias = columnName;
int postfix = 1;
while (usedAliases.contains(alias)) {
alias = format("%s__%s", columnName, postfix);
postfix++;
}
aliases.add(new Identifier(alias, true));
usedAliases.add(alias);
}
return aliases.build();
}
private ResultSetMetaData getResultMetadata(Query query)
{
Query zeroRowQuery;
if (query.getQueryBody() instanceof QuerySpecification) {
QuerySpecification querySpecification = (QuerySpecification) query.getQueryBody();
zeroRowQuery = new Query(
query.getWith(),
new QuerySpecification(
querySpecification.getSelect(),
querySpecification.getFrom(),
querySpecification.getWhere(),
querySpecification.getGroupBy(),
querySpecification.getHaving(),
querySpecification.getOrderBy(),
querySpecification.getOffset(),
Optional.of("0")),
Optional.empty(),
Optional.empty(),
Optional.empty());
}
else {
zeroRowQuery = new Query(query.getWith(), query.getQueryBody(), Optional.empty(), Optional.empty(), Optional.of("0"));
}
return prestoAction.execute(zeroRowQuery, REWRITE, ResultSetConverter.DEFAULT).getMetadata();
}
private Query rewriteNonStorableColumns(Query query, ResultSetMetaData metadata)
{
// Skip if all columns are storable
List columnTypes = getColumnTypes(typeManager, metadata);
if (columnTypes.stream().noneMatch(type -> getColumnTypeRewrite(type).isPresent())) {
return query;
}
// Cannot handle SELECT query with top-level SetOperation
if (!(query.getQueryBody() instanceof QuerySpecification)) {
return query;
}
QuerySpecification querySpecification = (QuerySpecification) query.getQueryBody();
List selectItems = querySpecification.getSelect().getSelectItems();
// Cannot handle SELECT *
if (selectItems.stream().anyMatch(AllColumns.class::isInstance)) {
return query;
}
List newItems = new ArrayList<>();
checkState(selectItems.size() == columnTypes.size(), "SelectItem count (%s) mismatches column count (%s)", selectItems.size(), columnTypes.size());
for (int i = 0; i < selectItems.size(); i++) {
SingleColumn singleColumn = (SingleColumn) selectItems.get(i);
Optional columnTypeRewrite = getColumnTypeRewrite(columnTypes.get(i));
if (columnTypeRewrite.isPresent()) {
newItems.add(new SingleColumn(new Cast(singleColumn.getExpression(), columnTypeRewrite.get().getTypeSignature().toString()), singleColumn.getAlias()));
}
else {
newItems.add(singleColumn);
}
}
return new Query(
query.getWith(),
new QuerySpecification(
new Select(querySpecification.getSelect().isDistinct(), newItems),
querySpecification.getFrom(),
querySpecification.getWhere(),
querySpecification.getGroupBy(),
querySpecification.getHaving(),
querySpecification.getOrderBy(),
Optional.empty(),
querySpecification.getLimit()),
query.getOrderBy(),
Optional.empty(),
query.getLimit());
}
private Optional getColumnTypeRewrite(Type type)
{
if (type.equals(DATE) || type.equals(TIME)) {
return Optional.of(TIMESTAMP);
}
if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
return Optional.of(VARCHAR);
}
if (type.equals(UNKNOWN)) {
return Optional.of(BIGINT);
}
if (type instanceof DecimalType) {
return Optional.of(DOUBLE);
}
if (type instanceof ArrayType) {
return getColumnTypeRewrite(((ArrayType) type).getElementType()).map(ArrayType::new);
}
if (type instanceof MapType) {
Type keyType = ((MapType) type).getKeyType();
Type valueType = ((MapType) type).getValueType();
Optional keyTypeRewrite = getColumnTypeRewrite(keyType);
Optional valueTypeRewrite = getColumnTypeRewrite(valueType);
if (keyTypeRewrite.isPresent() || valueTypeRewrite.isPresent()) {
return Optional.of(typeManager.getType(new TypeSignature(
MAP,
TypeSignatureParameter.of(keyTypeRewrite.orElse(keyType).getTypeSignature()),
TypeSignatureParameter.of(valueTypeRewrite.orElse(valueType).getTypeSignature()))));
}
return Optional.empty();
}
if (type instanceof RowType) {
List fields = ((RowType) type).getFields();
List fieldsRewrite = new ArrayList<>();
boolean rewrite = false;
for (Field field : fields) {
Optional fieldTypeRewrite = getColumnTypeRewrite(field.getType());
rewrite = rewrite || fieldTypeRewrite.isPresent();
fieldsRewrite.add(new Field(field.getName(), fieldTypeRewrite.orElse(field.getType())));
}
return rewrite ? Optional.of(RowType.from(fieldsRewrite)) : Optional.empty();
}
return Optional.empty();
}
private static String sanitizeColumnName(String columnName)
{
return columnName.replaceAll("[^a-zA-Z0-9_]", "_").toLowerCase(ENGLISH);
}
private static List applyPropertyOverride(List properties, List overrides)
{
Map propertyMap = properties.stream()
.collect(toImmutableMap(property -> property.getName().getValueLowerCase(), Property::getValue));
Map overrideMap = overrides.stream()
.collect(toImmutableMap(property -> property.getName().getValueLowerCase(), Property::getValue));
return Stream.concat(propertyMap.entrySet().stream(), overrideMap.entrySet().stream())
.collect(Collectors.toMap(Entry::getKey, Entry::getValue, (original, override) -> override))
.entrySet()
.stream()
.map(entry -> new Property(new Identifier(entry.getKey()), entry.getValue()))
.collect(toImmutableList());
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy