All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.jet.sql.impl.opt.OptUtils Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Copyright 2021 Hazelcast Inc.
 *
 * Licensed under the Hazelcast Community License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://hazelcast.com/hazelcast-community-license
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.hazelcast.jet.sql.impl.opt;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.jet.datamodel.Tuple2;
import com.hazelcast.jet.impl.util.Util;
import com.hazelcast.jet.sql.impl.connector.SqlConnectorUtil;
import com.hazelcast.jet.sql.impl.opt.distribution.DistributionTrait;
import com.hazelcast.jet.sql.impl.opt.metadata.Boundedness;
import com.hazelcast.jet.sql.impl.opt.metadata.HazelcastRelMetadataQuery;
import com.hazelcast.jet.sql.impl.opt.physical.visitor.RexToExpressionVisitor;
import com.hazelcast.jet.sql.impl.schema.HazelcastRelOptTable;
import com.hazelcast.jet.sql.impl.schema.HazelcastTable;
import com.hazelcast.jet.sql.impl.schema.JetTable;
import com.hazelcast.jet.sql.impl.validate.types.HazelcastJsonType;
import com.hazelcast.jet.sql.impl.validate.types.HazelcastObjectType;
import com.hazelcast.jet.sql.impl.validate.types.HazelcastObjectTypeReference;
import com.hazelcast.jet.sql.impl.validate.types.HazelcastTypeUtils;
import com.hazelcast.sql.impl.QueryParameterMetadata;
import com.hazelcast.sql.impl.expression.Expression;
import com.hazelcast.sql.impl.plan.node.PlanNodeFieldTypeProvider;
import com.hazelcast.sql.impl.plan.node.PlanNodeSchema;
import com.hazelcast.sql.impl.schema.Table;
import com.hazelcast.sql.impl.schema.TableField;
import com.hazelcast.sql.impl.type.QueryDataType;
import com.hazelcast.org.apache.calcite.plan.ConventionTraitDef;
import com.hazelcast.org.apache.calcite.plan.HazelcastRelOptCluster;
import com.hazelcast.org.apache.calcite.plan.RelOptCluster;
import com.hazelcast.org.apache.calcite.plan.RelOptRule;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleOperand;
import com.hazelcast.org.apache.calcite.plan.RelOptTable;
import com.hazelcast.org.apache.calcite.plan.RelTrait;
import com.hazelcast.org.apache.calcite.plan.RelTraitSet;
import com.hazelcast.org.apache.calcite.plan.volcano.HazelcastRelSubsetUtil;
import com.hazelcast.org.apache.calcite.plan.volcano.RelSubset;
import com.hazelcast.org.apache.calcite.prepare.RelOptTableImpl;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Calc;
import com.hazelcast.org.apache.calcite.rel.core.TableScan;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalTableScan;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFactory;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexProgram;
import com.hazelcast.org.apache.calcite.rex.RexShuttle;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.rex.RexVisitor;
import com.hazelcast.org.apache.calcite.rex.RexVisitorImpl;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeName;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

import static com.hazelcast.jet.impl.util.Util.arrayIndexOf;
import static com.hazelcast.jet.sql.impl.opt.Conventions.LOGICAL;
import static com.hazelcast.jet.sql.impl.opt.Conventions.PHYSICAL;

/**
 * Static utility classes for rules.
 */
public final class OptUtils {

    private OptUtils() {
    }

    /**
     * Convert the given trait set to logical convention.
     *
     * @param traitSet Original trait set.
     * @return New trait set with logical convention.
     */
    public static RelTraitSet toLogicalConvention(RelTraitSet traitSet) {
        return traitPlus(traitSet, LOGICAL);
    }

    /**
     * Convert the given input into logical input.
     *
     * @param rel Original input.
     * @return Logical input.
     */
    public static RelNode toLogicalInput(RelNode rel) {
        return RelOptRule.convert(rel, toLogicalConvention(rel.getTraitSet()));
    }

    /**
     * Convert the given trait set to physical convention.
     *
     * @param traitSet Original trait set.
     * @return New trait set with physical convention and provided distribution.
     */
    public static RelTraitSet toPhysicalConvention(RelTraitSet traitSet) {
        return traitPlus(traitSet, PHYSICAL);
    }

    /**
     * Convert the given input into physical input.
     *
     * @param rel Original input.
     * @return Logical input.
     */
    public static RelNode toPhysicalInput(RelNode rel) {
        return RelOptRule.convert(rel, toPhysicalConvention(rel.getTraitSet()));
    }

    /**
     * Add a single trait to the trait set.
     *
     * @param traitSet Original trait set.
     * @param trait    Trait to add.
     * @return Resulting trait set.
     */
    public static RelTraitSet traitPlus(RelTraitSet traitSet, RelTrait trait) {
        return traitSet.plus(trait).simplify();
    }

    public static LogicalTableScan createLogicalScan(
            RelOptCluster cluster,
            HazelcastTable hazelcastTable
    ) {
        JetTable table = hazelcastTable.getTarget();

        HazelcastRelOptTable relTable = createRelTable(
                table.getQualifiedName(),
                hazelcastTable,
                cluster.getTypeFactory()
        );
        return LogicalTableScan.create(cluster, relTable, ImmutableList.of());
    }

    public static LogicalTableScan createLogicalScan(
            TableScan originalScan,
            HazelcastTable newHazelcastTable
    ) {
        HazelcastRelOptTable originalRelTable = (HazelcastRelOptTable) originalScan.getTable();

        HazelcastRelOptTable newTable = createRelTable(
                originalRelTable,
                newHazelcastTable,
                originalScan.getCluster().getTypeFactory()
        );

        return LogicalTableScan.create(
                originalScan.getCluster(),
                newTable,
                originalScan.getHints()
        );
    }

    public static HazelcastRelOptTable createRelTable(
            List names,
            HazelcastTable hazelcastTable,
            RelDataTypeFactory typeFactory
    ) {
        RelDataType rowType = hazelcastTable.getRowType(typeFactory);

        RelOptTableImpl relTable = RelOptTableImpl.create(
                null,
                rowType,
                names,
                hazelcastTable,
                (com.hazelcast.org.apache.calcite.linq4j.tree.Expression) null
        );
        return new HazelcastRelOptTable(relTable);
    }

    public static HazelcastRelOptTable createRelTable(
            HazelcastRelOptTable originalRelTable,
            HazelcastTable newHazelcastTable,
            RelDataTypeFactory typeFactory
    ) {
        RelOptTableImpl newTable = RelOptTableImpl.create(
                originalRelTable.getRelOptSchema(),
                newHazelcastTable.getRowType(typeFactory),
                originalRelTable.getDelegate().getQualifiedName(),
                newHazelcastTable,
                (com.hazelcast.org.apache.calcite.linq4j.tree.Expression) null
        );

        return new HazelcastRelOptTable(newTable);
    }

    /**
     * Finds a set for the given RelNode, and return subsets that have the
     * physical trait. Every returned input is guaranteed to have a unique trait
     * set.
     *
     * @return Physical rels.
     */
    public static Collection extractPhysicalRelsFromSubset(RelNode input) {
        return extractRelsFromSubset(input, OptUtils::isPhysical);
    }

    private static boolean isPhysical(RelNode rel) {
        return rel.getTraitSet().getTrait(ConventionTraitDef.INSTANCE).equals(Conventions.PHYSICAL);
    }

    /**
     * Finds a set for the given RelNode, and return subsets that have the
     * logical trait. Every returned input is guaranteed to have a unique trait
     * set.
     *
     * @return Logical rels.
     */
    public static Collection extractLogicalRelsFromSubset(RelNode input) {
        return extractRelsFromSubset(input, OptUtils::isLogical);
    }

    private static boolean isLogical(RelNode rel) {
        return rel.getTraitSet().getTrait(ConventionTraitDef.INSTANCE).equals(Conventions.LOGICAL);
    }

    private static Collection extractRelsFromSubset(RelNode input, Predicate predicate) {
        Set traitSets = new HashSet<>();

        Set res = Collections.newSetFromMap(new IdentityHashMap<>());

        for (RelNode rel : HazelcastRelSubsetUtil.getSubsets(input)) {
            if (!predicate.test(rel)) {
                continue;
            }

            if (traitSets.add(rel.getTraitSet())) {
                res.add(rel);
            }
        }

        return res;
    }

    public static boolean isBounded(RelNode rel) {
        return metadataQuery(rel).extractBoundedness(rel) == Boundedness.BOUNDED;
    }

    public static boolean isUnbounded(RelNode rel) {
        return metadataQuery(rel).extractBoundedness(rel) == Boundedness.UNBOUNDED;
    }

    public static HazelcastRelMetadataQuery metadataQuery(RelNode rel) {
        return HazelcastRelMetadataQuery.reuseOrCreate(rel.getCluster().getMetadataQuery());
    }

    public static HazelcastRelOptCluster getCluster(RelNode rel) {
        assert rel.getCluster() instanceof HazelcastRelOptCluster;

        return (HazelcastRelOptCluster) rel.getCluster();
    }

    public static DistributionTrait getDistribution(RelNode rel) {
        return rel.getTraitSet().getTrait(getCluster(rel).getDistributionTraitDef());
    }

    /**
     * If the {@code node} is a {@link RelSubset}, finds the subset matching
     * the {@code operandPredicate}.
     * If multiple matches are found, it throws. If no match is found, it returns null.
     * 

* If the {@code node} isn't a {@code RelSubset} and it matches the * predicate, the {@code node} is returned. If it doesn't match, * {@code null} is returned. */ @SuppressWarnings("unchecked") @Nullable public static T findMatchingRel(RelNode node, RelOptRuleOperand operandPredicate) { if (node instanceof RelSubset) { RelNode res = null; for (RelNode rel : ((RelSubset) node).getRels()) { if (operandPredicate.matches(rel)) { if (res != null) { throw new RuntimeException("multiple matches found"); } res = rel; } } return (T) res; } else if (operandPredicate.matches(node)) { return (T) node; } return null; } public static PlanNodeSchema schema(RelDataType rowType) { return new PlanNodeSchema(extractFieldTypes(rowType)); } public static PlanNodeSchema schema(RelOptTable relTable) { Table table = relTable.unwrap(HazelcastTable.class).getTarget(); return schema(table); } public static PlanNodeSchema schema(Table table) { List fieldTypes = new ArrayList<>(); for (TableField field : table.getFields()) { fieldTypes.add(field.getType()); } return new PlanNodeSchema(fieldTypes); } public static RexVisitor> createRexToExpressionVisitor( PlanNodeFieldTypeProvider schema, QueryParameterMetadata parameterMetadata ) { return new RexToExpressionVisitor(schema, parameterMetadata); } /** * Converts a {@link TableField} to {@link RelDataType}. */ public static RelDataType convert(TableField field, RelDataTypeFactory typeFactory) { QueryDataType fieldType = field.getType(); SqlTypeName sqlTypeName = HazelcastTypeUtils.toCalciteType(fieldType); if (sqlTypeName == null) { throw new IllegalStateException("Unsupported type family: " + fieldType + ", getSqlTypeName should never return null."); } if (sqlTypeName == SqlTypeName.OTHER) { return convertOtherType(fieldType); } else if (fieldType.isCustomType()) { return convertCustomType(fieldType, typeFactory); } else { RelDataType relType = typeFactory.createSqlType(sqlTypeName); return typeFactory.createTypeWithNullability(relType, true); } } private static RelDataType convertCustomType(QueryDataType fieldType, RelDataTypeFactory typeFactory) { final Map dataTypeMap = new HashMap<>(); convertCustomTypeRecursively(fieldType, typeFactory, dataTypeMap); return dataTypeMap.get(fieldType.getObjectTypeName()); } private static void convertCustomTypeRecursively( QueryDataType type, RelDataTypeFactory typeFactory, Map typeMap ) { if (typeMap.get(type.getObjectTypeName()) != null) { return; } final List fields = new ArrayList<>(); final HazelcastObjectTypeReference typeRef = new HazelcastObjectTypeReference(); typeMap.put(type.getObjectTypeName(), typeRef); for (int i = 0; i < type.getObjectFields().size(); i++) { final String fieldName = type.getObjectFields().get(i).getName(); final QueryDataType fieldType = type.getObjectFields().get(i).getDataType(); RelDataType fieldRelDataType; if (fieldType.isCustomType()) { fieldRelDataType = typeMap.get(fieldType.getObjectTypeName()); if (fieldRelDataType == null) { convertCustomTypeRecursively(fieldType, typeFactory, typeMap); fieldRelDataType = typeMap.get(fieldType.getObjectTypeName()); } } else { fieldRelDataType = typeFactory.createTypeWithNullability( typeFactory.createSqlType(HazelcastTypeUtils.toCalciteType(fieldType)), true ); } fields.add(new HazelcastObjectType.Field(fieldName, i, fieldRelDataType)); } typeRef.setOriginal(new HazelcastObjectType(type.getObjectTypeName(), fields)); } private static RelDataType convertOtherType(QueryDataType fieldType) { switch (fieldType.getTypeFamily()) { case JSON: return HazelcastJsonType.create(true); default: throw new IllegalStateException("Unexpected type family: " + fieldType); } } private static List extractFieldTypes(RelDataType rowType) { return Util.toList(rowType.getFieldList(), f -> HazelcastTypeUtils.toHazelcastType(f.getType())); } public static boolean requiresJob(RelNode rel) { return ((HazelcastRelOptCluster) rel.getCluster()).requiresJob(); } public static boolean hasTableType(RelNode rel, Class tableClass) { if (rel.getTable() == null) { return false; } HazelcastTable table = rel.getTable().unwrap(HazelcastTable.class); return table != null && tableClass.isAssignableFrom(table.getTarget().getClass()); } public static HazelcastTable extractHazelcastTable(RelNode rel) { HazelcastTable table = rel.getTable().unwrap(HazelcastTable.class); assert table != null; return table; } @SuppressWarnings("checkstyle:AvoidNestedBlocks") public static RexNode extractKeyConstantExpression(RelOptTable relTable, RexBuilder rexBuilder) { HazelcastTable table = relTable.unwrap(HazelcastTable.class); RexNode filter = table.getFilter(); if (filter == null) { return null; } int keyIndex = findKeyIndex(table.getTarget()); switch (filter.getKind()) { // WHERE __key = true, calcite simplifies to just `WHERE __key` case INPUT_REF: { return ((RexInputRef) filter).getIndex() == keyIndex ? rexBuilder.makeLiteral(true) : null; } // WHERE __key = false, calcite simplifies to `WHERE NOT __key` case NOT: { RexNode operand = ((RexCall) filter).getOperands().get(0); return operand.getKind() == SqlKind.INPUT_REF && ((RexInputRef) operand).getIndex() == keyIndex ? rexBuilder.makeLiteral(false) : null; } // __key = ... case EQUALS: { Tuple2 constantExpressionByIndex = extractConstantExpression((RexCall) filter); //noinspection ConstantConditions return constantExpressionByIndex != null && constantExpressionByIndex.getKey() == keyIndex ? constantExpressionByIndex.getValue() : null; } default: return null; } } private static int findKeyIndex(Table table) { List primaryKey = SqlConnectorUtil.getJetSqlConnector(table).getPrimaryKey(table); // just single field keys supported at the moment assert primaryKey.size() == 1; int keyIndex = table.getFieldIndex(primaryKey.get(0)); assert keyIndex > -1; return keyIndex; } private static Tuple2 extractConstantExpression(RexCall condition) { Tuple2 constantExpression = extractConstantExpression(condition, 0); return constantExpression != null ? constantExpression : extractConstantExpression(condition, 1); } private static Tuple2 extractConstantExpression(RexCall condition, int i) { RexNode firstOperand = condition.getOperands().get(i); if (firstOperand.getKind() == SqlKind.INPUT_REF) { int index = ((RexInputRef) firstOperand).getIndex(); RexNode secondOperand = condition.getOperands().get(1 - i); if (RexUtil.isConstant(secondOperand)) { return Tuple2.tuple2(index, secondOperand); } } return null; } /** * Return true if the `expression` contains any input reference to a field * with index in `indexes`. */ public static boolean hasInputRef(RexNode expression, int... indexes) { boolean[] res = {false}; expression.accept(new RexVisitorImpl(true) { @Override public Void visitInputRef(RexInputRef inputRef) { if (arrayIndexOf(inputRef.getIndex(), indexes) >= 0) { res[0] = true; } return null; } }); return res[0]; } /** * Inlines `inlinedExpressions` into `expr` and returns the modified expression. *

* Example: * {@code * inlinedExpressions: [UPPER($1), LOWER($0)] * expr: $1 || $0 * result: LOWER($0) || UPPER($1) * } */ @SuppressWarnings("checkstyle:AnonInnerLength") public static RexNode inlineExpression(List inlinedExpressions, RexNode expr) { return expr.accept(new RexShuttle() { @Override public RexNode visitInputRef(RexInputRef inputRef) { return inlinedExpressions.get(inputRef.getIndex()); } }); } /** * Same as {@link #inlineExpression(List, RexNode)}, but applied to all * expressions in {@code exprs}. */ public static List inlineExpressions(List inlinedExpressions, List exprs) { List res = new ArrayList<>(exprs.size()); for (RexNode expr : exprs) { res.add(inlineExpression(inlinedExpressions, expr)); } return res; } /** * Return the index at which a {@link Calc} program projects the input's * field at index {@code inputFieldIndex}. If the program doesn't project * that field directly, returns -1. If it projects it multiple times, * returns the first occurrence. *

* For example, if the input has fields `a, b, c` and the projection is `c, * a, b+1`, then: *

    *
  • getTargetField(0) = 1 *
  • getTargetField(1) = -1 *
  • getTargetField(2) = 0 *
* * The method is named analogously to {@link * RexProgram#getSourceField(int)}, which finds the opposite mapping. * * @param calcProgram The calc program (the projection) * @param inputFieldIndex The index of the input field * @return The position of the input field in the output, or -1, if it's not in the output. */ public static int getTargetField(RexProgram calcProgram, int inputFieldIndex) { for (int i = 0; i < calcProgram.getProjectList().size(); i++) { int expressionIndex = calcProgram.getProjectList().get(i).getIndex(); RexNode expr = calcProgram.getExprList().get(expressionIndex); if (expr instanceof RexInputRef && ((RexInputRef) expr).getIndex() == inputFieldIndex) { return i; } } return -1; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy