// io.prestosql.sql.analyzer.ExpressionTreeUtils Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.analyzer;
import com.google.common.collect.ImmutableList;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.Location;
import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor;
import io.prestosql.sql.tree.DereferenceExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.Identifier;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.QualifiedName;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Streams.stream;
import static java.util.Objects.requireNonNull;
/**
 * Static helpers for picking specific kinds of nodes out of expression trees.
 *
 * <p>The class is not instantiable; all members are static.
 */
public final class ExpressionTreeUtils
{
    private ExpressionTreeUtils() {}

    /**
     * Collects every {@link FunctionCall} found under the given nodes that
     * {@link #isAggregation(FunctionCall, Metadata)} accepts.
     *
     * @param nodes roots to search
     * @param metadata used to decide whether a function name denotes an aggregation
     * @return an immutable list of the matching function calls
     */
    static List<FunctionCall> extractAggregateFunctions(Iterable<? extends Node> nodes, Metadata metadata)
    {
        return extractExpressions(nodes, FunctionCall.class, function -> isAggregation(function, metadata));
    }

    /**
     * Collects every {@link FunctionCall} found under the given nodes whose
     * window is present.
     *
     * @param nodes roots to search
     * @return an immutable list of the matching function calls
     */
    static List<FunctionCall> extractWindowFunctions(Iterable<? extends Node> nodes)
    {
        return extractExpressions(nodes, FunctionCall.class, ExpressionTreeUtils::isWindowFunction);
    }

    /**
     * Collects every node found under the given roots that is an instance of
     * {@code clazz}, without any further filtering.
     *
     * @param nodes roots to search
     * @param clazz expression type to look for
     * @param <T> expression type to look for
     * @return an immutable list of the matching expressions
     * @throws NullPointerException if {@code nodes} or {@code clazz} is null
     */
    public static <T extends Expression> List<T> extractExpressions(
            Iterable<? extends Node> nodes,
            Class<T> clazz)
    {
        return extractExpressions(nodes, clazz, alwaysTrue());
    }

    /**
     * Decides whether a function call is treated as an aggregation.
     *
     * <p>A call qualifies when it has no window and either the metadata reports
     * its name as an aggregation function or it carries a filter. Independently
     * of those conditions, any call whose {@code getOrderBy()} is present also
     * qualifies.
     */
    private static boolean isAggregation(FunctionCall functionCall, Metadata metadata)
    {
        return ((metadata.isAggregationFunction(functionCall.getName()) || functionCall.getFilter().isPresent())
                && functionCall.getWindow().isEmpty())
                || functionCall.getOrderBy().isPresent();
    }

    /**
     * Returns true when the function call has a window.
     */
    private static boolean isWindowFunction(FunctionCall functionCall)
    {
        return functionCall.getWindow().isPresent();
    }

    /**
     * Flattens each root with {@link #linearizeNodes(Node)}, keeps the nodes that
     * are instances of {@code clazz}, and then applies {@code predicate}.
     *
     * <p>Results keep the iteration order of {@code nodes} and, within one root,
     * the order produced by {@link #linearizeNodes(Node)}.
     *
     * @throws NullPointerException if any argument is null
     */
    private static <T extends Expression> List<T> extractExpressions(
            Iterable<? extends Node> nodes,
            Class<T> clazz,
            Predicate<T> predicate)
    {
        requireNonNull(nodes, "nodes is null");
        requireNonNull(clazz, "clazz is null");
        requireNonNull(predicate, "predicate is null");

        return stream(nodes)
                .flatMap(node -> linearizeNodes(node).stream())
                .filter(clazz::isInstance)
                .map(clazz::cast)
                .filter(predicate)
                .collect(toImmutableList());
    }

    /**
     * Records every node the traversal visitor processes when started at
     * {@code node}, including {@code node} itself.
     *
     * <p>A node is recorded only after {@code super.process} has returned for it,
     * so, assuming the visitor reaches child nodes through {@code process},
     * descendants are listed before their ancestors.
     */
    private static List<Node> linearizeNodes(Node node)
    {
        ImmutableList.Builder<Node> nodes = ImmutableList.builder();
        // NOTE(review): the single <Void> type argument is derived from the
        // override signature below (Void return, Void context) - confirm against
        // the declaration of DefaultExpressionTraversalVisitor.
        new DefaultExpressionTraversalVisitor<Void>()
        {
            @Override
            public Void process(Node node, Void context)
            {
                super.process(node, context);
                nodes.add(node);
                return null;
            }
        }.process(node, null);
        return nodes.build();
    }

    /**
     * Converts the location of an AST node, if it has one, into a
     * {@link Location} carrying the same line and column numbers.
     *
     * @param node node whose location is wanted
     * @return the converted location, or empty when the node has no location
     */
    public static Optional<Location> extractLocation(Node node)
    {
        return node.getLocation()
                .map(location -> new Location(location.getLineNumber(), location.getColumnNumber()));
    }

    /**
     * Interprets an expression as a qualified name.
     *
     * @param expression expression to convert
     * @return for an {@link Identifier}, a single-part name built from its value;
     *         for a {@link DereferenceExpression}, the result of
     *         {@link DereferenceExpression#getQualifiedName}; {@code null} for any
     *         other kind of expression
     */
    public static QualifiedName asQualifiedName(Expression expression)
    {
        if (expression instanceof Identifier) {
            return QualifiedName.of(((Identifier) expression).getValue());
        }
        if (expression instanceof DereferenceExpression) {
            return DereferenceExpression.getQualifiedName((DereferenceExpression) expression);
        }
        // Callers receive null (not an exception) for expressions that are not names.
        return null;
    }
}