All downloads are free. Search and download functionality uses the official Maven repository.

com.hazelcast.jet.sql.impl.schema.JetDynamicTableFunction Maven / Gradle / Ivy

/*
 * Copyright 2021 Hazelcast Inc.
 *
 * Licensed under the Hazelcast Community License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://hazelcast.com/hazelcast-community-license
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.hazelcast.jet.sql.impl.schema;

import com.hazelcast.jet.sql.impl.connector.SqlConnector;
import com.hazelcast.jet.sql.impl.validate.ValidationUtil;
import com.hazelcast.sql.impl.QueryException;
import com.hazelcast.jet.sql.impl.validate.HazelcastSqlValidator;
import com.hazelcast.sql.impl.schema.Table;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeComparability;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeFamily;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeField;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypePrecedenceList;
import com.hazelcast.org.apache.calcite.rel.type.StructKind;
import com.hazelcast.org.apache.calcite.sql.SqlCall;
import com.hazelcast.org.apache.calcite.sql.SqlCallBinding;
import com.hazelcast.org.apache.calcite.sql.SqlCollation;
import com.hazelcast.org.apache.calcite.sql.SqlDynamicParam;
import com.hazelcast.org.apache.calcite.sql.SqlIdentifier;
import com.hazelcast.org.apache.calcite.sql.SqlIntervalQualifier;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.SqlLiteral;
import com.hazelcast.org.apache.calcite.sql.SqlNode;
import com.hazelcast.org.apache.calcite.sql.SqlOperatorBinding;
import com.hazelcast.org.apache.calcite.sql.SqlUtil;
import com.hazelcast.org.apache.calcite.sql.type.SqlOperandTypeInference;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeName;
import com.hazelcast.org.apache.calcite.util.NlsString;

import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static com.hazelcast.org.apache.calcite.sql.type.SqlTypeName.MAP;
import static com.hazelcast.org.apache.calcite.sql.type.SqlTypeName.VARCHAR;

/**
 * A table function return type of which is NOT known upfront and is determined during validation phase.
 */
public abstract class JetDynamicTableFunction extends JetTableFunction {

    protected JetDynamicTableFunction(
            String name,
            List parameters,
            Function, Table> tableFn,
            SqlOperandTypeInference operandTypeInference,
            SqlConnector connector
    ) {
        super(
                name,
                parameters,
                binding -> inferReturnType(name, parameters, tableFn, binding),
                operandTypeInference,
                connector
        );

        assert parameters.stream()
                .map(JetTableFunctionParameter::type)
                .allMatch(type -> type == VARCHAR || type == MAP);
    }

    public final HazelcastTable toTable(RelDataType rowType) {
        return ((JetFunctionRelDataType) rowType).table();
    }

    private static RelDataType inferReturnType(
            String name,
            List parameters,
            Function, Table> tableFn,
            SqlOperatorBinding callBinding
    ) {
        List arguments = toArguments(name, parameters, callBinding);
        HazelcastTable table = new HazelcastTable(tableFn.apply(arguments), UnknownStatistic.INSTANCE);
        RelDataType rowType = table.getRowType(callBinding.getTypeFactory());
        return new JetFunctionRelDataType(table, rowType);
    }

    private static List toArguments(
            String functionName,
            List parameters,
            SqlOperatorBinding callBinding
    ) {
        SqlCallBinding binding = (SqlCallBinding) callBinding;
        SqlCall call = binding.getCall();
        HazelcastSqlValidator validator = (HazelcastSqlValidator) binding.getValidator();

        return ValidationUtil.hasAssignment(call)
                ? fromNamedArguments(functionName, parameters, call, validator)
                : fromPositionalArguments(functionName, parameters, call, validator);
    }

    private static List fromNamedArguments(
            String functionName,
            List parameters,
            SqlCall call,
            HazelcastSqlValidator validator
    ) {
        List arguments = new ArrayList<>(parameters.size());
        for (JetTableFunctionParameter parameter : parameters) {
            SqlNode operand = findOperandByName(parameter.name(), call);
            Object value = operand == null ? null : extractValue(functionName, parameter, operand, validator);
            arguments.add(value);
        }
        return arguments;
    }

    private static SqlNode findOperandByName(String name, SqlCall call) {
        for (int i = 0; i < call.operandCount(); i++) {
            SqlCall assignment = call.operand(i);
            SqlIdentifier id = assignment.operand(1);
            if (name.equals(id.getSimple())) {
                return assignment.operand(0);
            }
        }
        return null;
    }

    private static List fromPositionalArguments(
            String functionName,
            List parameters,
            SqlCall call,
            HazelcastSqlValidator validator
    ) {
        List arguments = new ArrayList<>(parameters.size());
        for (int i = 0; i < call.operandCount(); i++) {
            Object value = extractValue(functionName, parameters.get(i), call.operand(i), validator);
            arguments.add(value);
        }
        for (int i = call.operandCount(); i < parameters.size(); i++) {
            arguments.add(null);
        }
        return arguments;
    }

    private static Object extractValue(
            String functionName,
            JetTableFunctionParameter parameter,
            SqlNode operand,
            HazelcastSqlValidator validator
    ) {
        if (operand.getKind() == SqlKind.DEFAULT) {
            return null;
        }
        if (SqlUtil.isNullLiteral(operand, true)) {
            return null;
        }
        if (operand.getKind() == SqlKind.DYNAMIC_PARAM) {
            return validator.getArgumentAt(((SqlDynamicParam) operand).getIndex());
        }

        SqlTypeName parameterType = parameter.type();
        if (SqlUtil.isLiteral(operand) && parameterType == SqlTypeName.VARCHAR) {
            String value = extractStringValue(((SqlLiteral) operand));
            if (value != null) {
                return value;
            }
        } else if (operand.getKind() == SqlKind.MAP_VALUE_CONSTRUCTOR && parameterType == SqlTypeName.MAP) {
            return extractMapValue(functionName, parameter, (SqlCall) operand, validator);
        }
        throw QueryException.error("Invalid argument of a call to function " + functionName + " - #"
                + parameter.ordinal() + " (" + parameter.name() + "). Expected: " + parameterType
                + ", actual: "
                + (SqlUtil.isLiteral(operand) ? ((SqlLiteral) operand).getTypeName() : operand.getKind()));
    }

    private static String extractStringValue(SqlLiteral literal) {
        Object value = literal.getValue();
        return value instanceof NlsString ? ((NlsString) value).getValue() : null;
    }

    private static Map extractMapValue(
            String functionName,
            JetTableFunctionParameter parameter,
            SqlCall call,
            HazelcastSqlValidator validator
    ) {
        List operands = call.getOperandList();
        Map entries = new HashMap<>();
        for (int i = 0; i < operands.size(); i += 2) {
            String key = extractMapStringValue(functionName, parameter, operands.get(i), validator);
            String value = extractMapStringValue(functionName, parameter, operands.get(i + 1), validator);
            if (entries.putIfAbsent(key, value) != null) {
                throw QueryException.error(
                        "Duplicate entry in the MAP constructor in the call to function " + functionName + " - " +
                                "argument #" + parameter.ordinal() + " (" + parameter.name() + ")");
            }
        }
        return entries;
    }

    private static String extractMapStringValue(
            String functionName,
            JetTableFunctionParameter parameter,
            SqlNode node,
            HazelcastSqlValidator validator
    ) {
        if (node.getKind() == SqlKind.DYNAMIC_PARAM) {
            Object value = validator.getArgumentAt(((SqlDynamicParam) node).getIndex());
            if (value instanceof String) {
                return (String) value;
            }
        }
        if (SqlUtil.isLiteral(node)) {
            SqlLiteral literal = (SqlLiteral) node;
            Object value = literal.getValue();
            if (value instanceof NlsString) {
                return ((NlsString) value).getValue();
            }
        }
        throw QueryException.error(
                "All values in the MAP constructor of the call to function " + functionName + ", argument #"
                        + parameter.ordinal() + " (" + parameter.name() + ") must be VARCHAR literals. "
                        + "Actual argument is: "
                        + (SqlUtil.isLiteral(node) ? ((SqlLiteral) node).getTypeName() : node.getKind()));
    }


    /**
     * The only purpose of this class is to be able to pass the {@code
     * HazelcastTable} object to place where the function is used.
     */
    private static final class JetFunctionRelDataType implements RelDataType {

        private final HazelcastTable table;
        private final RelDataType delegate;

        private JetFunctionRelDataType(HazelcastTable table, RelDataType delegate) {
            this.delegate = delegate;
            this.table = table;
        }

        private HazelcastTable table() {
            return table;
        }

        @Override
        public boolean isStruct() {
            return delegate.isStruct();
        }

        @Override
        public List getFieldList() {
            return delegate.getFieldList();
        }

        @Override
        public List getFieldNames() {
            return delegate.getFieldNames();
        }

        @Override
        public int getFieldCount() {
            return delegate.getFieldCount();
        }

        @Override
        public StructKind getStructKind() {
            return delegate.getStructKind();
        }

        @Override
        public RelDataTypeField getField(String fieldName, boolean caseSensitive, boolean elideRecord) {
            return delegate.getField(fieldName, caseSensitive, elideRecord);
        }

        @Override
        public boolean isNullable() {
            return delegate.isNullable();
        }

        @Override
        public RelDataType getComponentType() {
            return delegate.getComponentType();
        }

        @Override
        public RelDataType getKeyType() {
            return delegate.getKeyType();
        }

        @Override
        public RelDataType getValueType() {
            return delegate.getValueType();
        }

        @Override
        public Charset getCharset() {
            return delegate.getCharset();
        }

        @Override
        public SqlCollation getCollation() {
            return delegate.getCollation();
        }

        @Override
        public SqlIntervalQualifier getIntervalQualifier() {
            return delegate.getIntervalQualifier();
        }

        @Override
        public int getPrecision() {
            return delegate.getPrecision();
        }

        @Override
        public int getScale() {
            return delegate.getScale();
        }

        @Override
        public SqlTypeName getSqlTypeName() {
            return delegate.getSqlTypeName();
        }

        @Override
        public SqlIdentifier getSqlIdentifier() {
            return delegate.getSqlIdentifier();
        }

        @Override
        public String toString() {
            return delegate.toString();
        }

        @Override
        public String getFullTypeString() {
            return delegate.getFullTypeString();
        }

        @Override
        public RelDataTypeFamily getFamily() {
            return delegate.getFamily();
        }

        @Override
        public RelDataTypePrecedenceList getPrecedenceList() {
            return delegate.getPrecedenceList();
        }

        @Override
        public RelDataTypeComparability getComparability() {
            return delegate.getComparability();
        }

        @Override
        public boolean isDynamicStruct() {
            return delegate.isDynamicStruct();
        }
    }
}