All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.plugin.mongodb.MongoMetadata Maven / Gradle / Ivy

There is a newer version: 458
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.plugin.mongodb;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import com.google.common.io.Closer;
import com.mongodb.client.MongoCollection;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.plugin.base.projection.ApplyProjectionUtil;
import io.trino.plugin.mongodb.MongoIndex.MongodbIndexKey;
import io.trino.plugin.mongodb.ptf.Query.QueryFunctionHandle;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorInsertTableHandle;
import io.trino.spi.connector.ConnectorMetadata;
import io.trino.spi.connector.ConnectorOutputMetadata;
import io.trino.spi.connector.ConnectorOutputTableHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.ConnectorTableLayout;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.ConnectorTableProperties;
import io.trino.spi.connector.ConnectorTableVersion;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.LimitApplicationResult;
import io.trino.spi.connector.LocalProperty;
import io.trino.spi.connector.NotFoundException;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.RelationColumnsMetadata;
import io.trino.spi.connector.RetryMode;
import io.trino.spi.connector.SaveMode;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SchemaTablePrefix;
import io.trino.spi.connector.SortingProperty;
import io.trino.spi.connector.TableFunctionApplicationResult;
import io.trino.spi.connector.TableNotFoundException;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.Variable;
import io.trino.spi.function.table.ConnectorTableFunctionHandle;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.security.TrinoPrincipal;
import io.trino.spi.statistics.ComputedStatistics;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.RowType.Field;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import org.bson.Document;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.UnaryOperator;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static com.mongodb.client.model.Aggregates.lookup;
import static com.mongodb.client.model.Aggregates.match;
import static com.mongodb.client.model.Aggregates.merge;
import static com.mongodb.client.model.Aggregates.project;
import static com.mongodb.client.model.Filters.ne;
import static com.mongodb.client.model.Projections.exclude;
import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables;
import static io.trino.plugin.mongodb.MongoSession.COLLECTION_NAME;
import static io.trino.plugin.mongodb.MongoSession.DATABASE_NAME;
import static io.trino.plugin.mongodb.MongoSession.ID;
import static io.trino.plugin.mongodb.MongoSessionProperties.isProjectionPushdownEnabled;
import static io.trino.plugin.mongodb.TypeUtils.isPushdownSupportedType;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.connector.RelationColumnsMetadata.forTable;
import static io.trino.spi.connector.SaveMode.REPLACE;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;

public class MongoMetadata
        implements ConnectorMetadata
{
    private static final Logger log = Logger.get(MongoMetadata.class);
    // Type of the page sink ID column that is appended to the temporary tables used by retryable writes
    private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT;

    // Upper bound, in UTF-8 bytes, on the string form of the target name accepted by renameTable
    private static final int MAX_QUALIFIED_IDENTIFIER_BYTE_LENGTH = 120;

    // Every metadata operation delegates to this session
    private final MongoSession mongoSession;

    // Cleanup for the in-flight write: installed via setRollback by beginCreateTable/beginInsert and
    // cleared via clearRollback by the finish* methods (presumably run by a rollback method elsewhere in this class)
    private final AtomicReference<Runnable> rollbackAction = new AtomicReference<>();

    /**
     * Creates connector metadata backed by the given session.
     *
     * @param mongoSession session that every metadata operation delegates to; must not be null
     */
    public MongoMetadata(MongoSession mongoSession)
    {
        requireNonNull(mongoSession, "mongoSession is null");
        this.mongoSession = mongoSession;
    }

    /**
     * Lists schema names by asking the MongoDB session for all of its schemas.
     */
    @Override
    public List listSchemaNames(ConnectorSession session)
    {
        List schemas = mongoSession.getAllSchemas();
        return schemas;
    }

    /**
     * Creates a schema. Schema properties are not supported.
     *
     * @throws IllegalArgumentException if any schema properties are supplied
     */
    @Override
    public void createSchema(ConnectorSession session, String schemaName, Map properties, TrinoPrincipal owner)
    {
        if (!properties.isEmpty()) {
            throw new IllegalArgumentException("Can't have properties for schema creation");
        }
        mongoSession.createSchema(schemaName);
    }

    /**
     * Drops a schema, forwarding the CASCADE flag to the session unchanged.
     */
    @Override
    public void dropSchema(ConnectorSession session, String schemaName, boolean cascade)
    {
        boolean dropContents = cascade;
        mongoSession.dropSchema(schemaName, dropContents);
    }

    /**
     * Resolves a table handle for the given name.
     *
     * @return the handle, or {@code null} if the session reports the table as not found
     * @throws TrinoException with NOT_SUPPORTED if a start or end version is requested
     */
    @Override
    public MongoTableHandle getTableHandle(
            ConnectorSession session,
            SchemaTableName tableName,
            Optional startVersion,
            Optional endVersion)
    {
        boolean versioned = startVersion.isPresent() || endVersion.isPresent();
        if (versioned) {
            throw new TrinoException(NOT_SUPPORTED, "This connector does not support versioned tables");
        }
        requireNonNull(tableName, "tableName is null");

        MongoTableHandle handle;
        try {
            handle = mongoSession.getTable(tableName).tableHandle();
        }
        catch (TableNotFoundException e) {
            // A missing table is reported to the engine as a null handle rather than an error
            log.debug(e, "Table(%s) not found", tableName);
            handle = null;
        }
        return handle;
    }

    /**
     * Returns table metadata for the table referenced by the handle.
     *
     * @throws NullPointerException if {@code tableHandle} is null
     */
    @Override
    public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle)
    {
        SchemaTableName name = getTableName(requireNonNull(tableHandle, "tableHandle is null"));
        return getTableMetadata(name);
    }

    /**
     * Lists tables in one schema, or in every schema when none is given.
     * Table names are lower-cased using the ENGLISH locale; schema names are used as-is.
     */
    @Override
    public List<SchemaTableName> listTables(ConnectorSession session, Optional<String> optionalSchemaName)
    {
        // No cast is needed (or safe): listSchemaNames is only guaranteed to return a List, not an ImmutableList
        List<String> schemaNames = optionalSchemaName
                .<List<String>>map(ImmutableList::of)
                .orElseGet(() -> listSchemaNames(session));
        ImmutableList.Builder<SchemaTableName> tableNames = ImmutableList.builder();
        for (String schemaName : schemaNames) {
            for (String tableName : mongoSession.getAllTables(schemaName)) {
                tableNames.add(new SchemaTableName(schemaName, tableName.toLowerCase(ENGLISH)));
            }
        }
        return tableNames.build();
    }

    /**
     * Maps each column's lower-cased (ENGLISH locale) base name to its handle.
     *
     * @throws IllegalArgumentException if two columns collapse to the same lower-cased name
     */
    @Override
    public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) tableHandle;
        List<MongoColumnHandle> columns = mongoSession.getTable(mongoTable.schemaTableName()).columns();
        return columns.stream()
                .collect(toImmutableMap(column -> column.baseName().toLowerCase(ENGLISH), identity()));
    }

    @Override
    public Iterator streamRelationColumns(ConnectorSession session, Optional schemaName, UnaryOperator> relationFilter)
    {
        Map relationColumns = new HashMap<>();

        SchemaTablePrefix prefix = schemaName.map(SchemaTablePrefix::new)
                .orElseGet(SchemaTablePrefix::new);
        for (SchemaTableName tableName : listTables(session, prefix)) {
            try {
                relationColumns.put(tableName, forTable(tableName, getTableMetadata(tableName).getColumns()));
            }
            catch (NotFoundException e) {
                // table disappeared during listing operation
            }
        }

        return relationFilter.apply(relationColumns.keySet()).stream()
                .map(relationColumns::get)
                .iterator();
    }

    /**
     * Expands a prefix into table names: a prefix naming a table yields exactly that table (without checking
     * that it exists), otherwise all tables of the prefix's schema, or of every schema, are listed.
     */
    private List<SchemaTableName> listTables(ConnectorSession session, SchemaTablePrefix prefix)
    {
        if (prefix.getTable().isPresent()) {
            return ImmutableList.of(prefix.toSchemaTableName());
        }
        return listTables(session, prefix.getSchema());
    }

    /**
     * Converts the MongoDB column handle into engine column metadata.
     */
    @Override
    public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle)
    {
        MongoColumnHandle column = (MongoColumnHandle) columnHandle;
        return column.toColumnMetadata();
    }

    /**
     * Creates an empty table (collection) from the given metadata, including its comment.
     *
     * @throws TrinoException with NOT_SUPPORTED when asked to replace an existing table
     */
    @Override
    public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, SaveMode saveMode)
    {
        if (saveMode != REPLACE) {
            RemoteTableName target = mongoSession.toRemoteSchemaTableName(tableMetadata.getTable());
            mongoSession.createTable(target, buildColumnHandles(tableMetadata), tableMetadata.getComment());
            return;
        }
        throw new TrinoException(NOT_SUPPORTED, "This connector does not support replacing tables");
    }

    /**
     * Drops the remote collection backing the handle.
     */
    @Override
    public void dropTable(ConnectorSession session, ConnectorTableHandle tableHandle)
    {
        mongoSession.dropTable(((MongoTableHandle) tableHandle).remoteTableName());
    }

    /**
     * Sets, or clears when empty, the table comment.
     */
    @Override
    public void setTableComment(ConnectorSession session, ConnectorTableHandle tableHandle, Optional comment)
    {
        mongoSession.setTableComment((MongoTableHandle) tableHandle, comment);
    }

    /**
     * Sets, or clears when empty, the comment of a top-level column (identified by its base name).
     */
    @Override
    public void setColumnComment(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle, Optional comment)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) tableHandle;
        String columnName = ((MongoColumnHandle) columnHandle).baseName();
        mongoSession.setColumnComment(mongoTable, columnName, comment);
    }

    /**
     * Renames a table after checking the UTF-8 length of the new name's string form.
     *
     * @throws TrinoException with NOT_SUPPORTED if that length exceeds MAX_QUALIFIED_IDENTIFIER_BYTE_LENGTH
     */
    @Override
    public void renameTable(ConnectorSession session, ConnectorTableHandle tableHandle, SchemaTableName newTableName)
    {
        int qualifiedNameBytes = newTableName.toString().getBytes(UTF_8).length;
        if (qualifiedNameBytes > MAX_QUALIFIED_IDENTIFIER_BYTE_LENGTH) {
            throw new TrinoException(NOT_SUPPORTED, format("Qualified identifier name must be shorter than or equal to '%s' bytes: '%s'", MAX_QUALIFIED_IDENTIFIER_BYTE_LENGTH, newTableName));
        }
        mongoSession.renameTable((MongoTableHandle) tableHandle, newTableName);
    }

    /**
     * Adds a column to the table.
     */
    @Override
    public void addColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnMetadata column)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) tableHandle;
        mongoSession.addColumn(mongoTable, column);
    }

    /**
     * Renames a top-level column (identified by its base name) to {@code target}.
     */
    @Override
    public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle source, String target)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) tableHandle;
        String sourceName = ((MongoColumnHandle) source).baseName();
        mongoSession.renameColumn(mongoTable, sourceName, target);
    }

    /**
     * Drops a top-level column (identified by its base name).
     */
    @Override
    public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) tableHandle;
        String columnName = ((MongoColumnHandle) column).baseName();
        mongoSession.dropColumn(mongoTable, columnName);
    }

    /**
     * Changes a column's type when canChangeColumnType allows the conversion.
     *
     * @throws TrinoException with NOT_SUPPORTED for a disallowed conversion
     */
    @Override
    public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle, Type type)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) tableHandle;
        MongoColumnHandle column = (MongoColumnHandle) columnHandle;
        Type currentType = column.type();
        if (canChangeColumnType(currentType, type)) {
            mongoSession.setColumnType(mongoTable, column.baseName(), type);
            return;
        }
        throw new TrinoException(NOT_SUPPORTED, format("Cannot change type from %s to %s", currentType, type));
    }

    /**
     * Decides whether a column of {@code sourceType} may be changed to {@code newType}.
     * <p>
     * Allowed: identical types; integer widening (TINYINT to SMALLINT/INTEGER/BIGINT, SMALLINT to INTEGER/BIGINT,
     * INTEGER to BIGINT); REAL to DOUBLE; any VARCHAR/CHAR to any VARCHAR/CHAR; DECIMAL with the same scale and
     * equal or larger precision; arrays whose element types are changeable. For rows, fields present in only one
     * of the two row types are ignored, and fields present in both (matched by name) must themselves be
     * changeable. Row fields must be named: an anonymous field makes {@code orElseThrow} throw.
     */
    private static boolean canChangeColumnType(Type sourceType, Type newType)
    {
        if (sourceType.equals(newType)) {
            return true;
        }
        if (sourceType == TINYINT) {
            return newType == SMALLINT || newType == INTEGER || newType == BIGINT;
        }
        if (sourceType == SMALLINT) {
            return newType == INTEGER || newType == BIGINT;
        }
        if (sourceType == INTEGER) {
            return newType == BIGINT;
        }
        if (sourceType == REAL) {
            return newType == DOUBLE;
        }
        if (sourceType instanceof VarcharType || sourceType instanceof CharType) {
            return newType instanceof VarcharType || newType instanceof CharType;
        }
        if (sourceType instanceof DecimalType sourceDecimal && newType instanceof DecimalType newDecimal) {
            return sourceDecimal.getScale() == newDecimal.getScale()
                    && sourceDecimal.getPrecision() <= newDecimal.getPrecision();
        }
        if (sourceType instanceof ArrayType sourceArrayType && newType instanceof ArrayType newArrayType) {
            return canChangeColumnType(sourceArrayType.getElementType(), newArrayType.getElementType());
        }
        if (sourceType instanceof RowType sourceRowType && newType instanceof RowType newRowType) {
            List<Field> fields = Streams.concat(sourceRowType.getFields().stream(), newRowType.getFields().stream())
                    .distinct()
                    .collect(toImmutableList());
            for (Field field : fields) {
                String fieldName = field.getName().orElseThrow();
                // Only fields shared by both row types constrain the change; added or removed fields are fine
                if (fieldExists(sourceRowType, fieldName) && fieldExists(newRowType, fieldName)) {
                    if (!canChangeColumnType(
                            findFieldByName(sourceRowType.getFields(), fieldName).getType(),
                            findFieldByName(newRowType.getFields(), fieldName).getType())) {
                        return false;
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Returns the single field named {@code fieldName}.
     *
     * @throws java.util.NoSuchElementException if no field matches, or if an anonymous field is encountered
     * @throws IllegalArgumentException if more than one field matches
     */
    private static Field findFieldByName(List<Field> fields, String fieldName)
    {
        return fields.stream()
                .filter(field -> field.getName().orElseThrow().equals(fieldName))
                .collect(onlyElement());
    }

    /**
     * Reports whether the row type has a field named {@code fieldName}. Fields are scanned in order, and an
     * anonymous field reached before a match makes {@code orElseThrow} throw.
     */
    private static boolean fieldExists(RowType structType, String fieldName)
    {
        return structType.getFields().stream()
                .anyMatch(candidate -> candidate.getName().orElseThrow().equals(fieldName));
    }

    /**
     * Starts a CREATE TABLE AS: creates the target collection immediately and installs a rollback that drops it.
     * <p>
     * When retries are enabled, a temporary collection (the table's columns plus a page sink ID column) is also
     * created in the target's remote database and added to the same rollback. The returned handle then names
     * that temporary collection and the page sink ID column.
     *
     * @throws TrinoException with NOT_SUPPORTED when asked to replace an existing table
     */
    @Override
    public ConnectorOutputTableHandle beginCreateTable(
            ConnectorSession session,
            ConnectorTableMetadata tableMetadata,
            Optional layout,
            RetryMode retryMode,
            boolean replace)
    {
        if (replace) {
            throw new TrinoException(NOT_SUPPORTED, "This connector does not support replacing tables");
        }
        RemoteTableName targetTable = mongoSession.toRemoteSchemaTableName(tableMetadata.getTable());
        List<MongoColumnHandle> allColumns = buildColumnHandles(tableMetadata);
        mongoSession.createTable(targetTable, allColumns, tableMetadata.getComment());

        List<MongoColumnHandle> visibleColumns = allColumns.stream()
                .filter(column -> !column.hidden())
                .collect(toImmutableList());

        // The rollback closes this Closer, so anything registered later (the temporary table) is cleaned up too
        Closer cleanup = Closer.create();
        cleanup.register(() -> mongoSession.dropTable(targetTable));
        setRollback(() -> {
            try {
                cleanup.close();
            }
            catch (IOException e) {
                throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
            }
        });

        if (retryMode != RetryMode.NO_RETRIES) {
            MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(allColumns.stream().map(MongoColumnHandle::baseName).collect(toImmutableSet()));
            List<MongoColumnHandle> stagingColumns = ImmutableList.<MongoColumnHandle>builderWithExpectedSize(allColumns.size() + 1)
                    .addAll(allColumns)
                    .add(pageSinkIdColumn)
                    .build();
            RemoteTableName stagingTable = new RemoteTableName(targetTable.databaseName(), generateTemporaryTableName(session));
            mongoSession.createTable(stagingTable, stagingColumns, Optional.empty());
            cleanup.register(() -> mongoSession.dropTable(stagingTable));

            return new MongoOutputTableHandle(
                    targetTable,
                    visibleColumns,
                    Optional.of(stagingTable.collectionName()),
                    Optional.of(pageSinkIdColumn.baseName()));
        }

        return new MongoOutputTableHandle(
                targetTable,
                visibleColumns,
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Completes a CREATE TABLE AS. If rows were staged in a temporary collection, they are copied into the
     * target first. The rollback action is then cleared.
     */
    @Override
    public Optional finishCreateTable(ConnectorSession session, ConnectorOutputTableHandle tableHandle, Collection fragments, Collection computedStatistics)
    {
        MongoOutputTableHandle outputHandle = (MongoOutputTableHandle) tableHandle;
        boolean stagedInTemporaryTable = outputHandle.temporaryTableName().isPresent();
        if (stagedInTemporaryTable) {
            String pageSinkIdColumnName = outputHandle.pageSinkIdColumnName().get();
            finishInsert(session, outputHandle.remoteTableName(), outputHandle.getTemporaryRemoteTableName().get(), pageSinkIdColumnName, fragments);
        }
        clearRollback();
        return Optional.empty();
    }

    /**
     * Starts an INSERT into an existing table. Every visible column name is validated first.
     * <p>
     * Without retries, the returned handle writes straight to the target collection. With retries enabled, a
     * temporary collection (all table columns plus a page sink ID column) is created in the same remote database
     * as the target, and a rollback that drops it is installed. finishInsert later copies the committed rows
     * from it into the target.
     */
    @Override
    public ConnectorInsertTableHandle beginInsert(ConnectorSession session, ConnectorTableHandle tableHandle, List<ColumnHandle> insertedColumns, RetryMode retryMode)
    {
        MongoTable table = mongoSession.getTable(((MongoTableHandle) tableHandle).schemaTableName());
        MongoTableHandle handle = table.tableHandle();
        List<MongoColumnHandle> columns = table.columns();
        List<MongoColumnHandle> handleColumns = columns.stream()
                .filter(column -> !column.hidden())
                .collect(toImmutableList());
        // Validate in an explicit loop: Stream.peek is intended for debugging, not for required side effects
        for (MongoColumnHandle column : handleColumns) {
            validateColumnNameForInsert(column.baseName());
        }

        if (retryMode == RetryMode.NO_RETRIES) {
            return new MongoInsertTableHandle(
                    handle.remoteTableName(),
                    handleColumns,
                    Optional.empty(),
                    Optional.empty());
        }
        MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(columns.stream().map(MongoColumnHandle::baseName).collect(toImmutableSet()));
        List<MongoColumnHandle> allColumns = ImmutableList.<MongoColumnHandle>builderWithExpectedSize(columns.size() + 1)
                .addAll(columns)
                .add(pageSinkIdColumn)
                .build();

        // Use the remote database name, as beginCreateTable does, rather than the Trino schema name. The two may
        // differ (e.g. in letter case), and finishInsert $merge-s the temporary collection into a collection
        // addressed only by name, i.e. in the temporary collection's own database.
        RemoteTableName temporaryTable = new RemoteTableName(handle.remoteTableName().databaseName(), generateTemporaryTableName(session));
        mongoSession.createTable(temporaryTable, allColumns, Optional.empty());

        setRollback(() -> mongoSession.dropTable(temporaryTable));

        return new MongoInsertTableHandle(
                handle.remoteTableName(),
                handleColumns,
                Optional.of(temporaryTable.collectionName()),
                Optional.of(pageSinkIdColumn.baseName()));
    }

    /**
     * Completes an INSERT. If rows were staged in a temporary collection, they are copied into the target first.
     * The rollback action is then cleared.
     */
    @Override
    public Optional finishInsert(
            ConnectorSession session,
            ConnectorInsertTableHandle insertHandle,
            List sourceTableHandles,
            Collection fragments,
            Collection computedStatistics)
    {
        MongoInsertTableHandle insert = (MongoInsertTableHandle) insertHandle;
        boolean stagedInTemporaryTable = insert.temporaryTableName().isPresent();
        if (stagedInTemporaryTable) {
            String pageSinkIdColumnName = insert.pageSinkIdColumnName().get();
            finishInsert(session, insert.remoteTableName(), insert.getTemporaryRemoteTableName().get(), pageSinkIdColumnName, fragments);
        }
        clearRollback();
        return Optional.empty();
    }

    /**
     * Copies the rows written by committed page sinks from {@code temporaryTable} into {@code targetTable}, then
     * drops the temporary collection and the helper collection of page sink IDs.
     * <p>
     * Each fragment carries one page sink ID (read with {@code getLong(0)}). Only temporary rows whose page sink ID
     * is among them are merged into the target, so rows from page sinks that did not report are discarded.
     *
     * @throws TrinoException with GENERIC_INTERNAL_ERROR if cleanup fails with an IOException
     */
    private void finishInsert(
            ConnectorSession session,
            RemoteTableName targetTable,
            RemoteTableName temporaryTable,
            String pageSinkIdColumnName,
            Collection<Slice> fragments)
    {
        Closer closer = Closer.create();
        closer.register(() -> mongoSession.dropTable(temporaryTable));

        // try-with-resources (instead of closing in finally) attaches a cleanup failure as a suppressed exception,
        // so it can no longer mask a failure raised while copying the rows
        try (closer) {
            List<Document> pageSinkIds = fragments.stream()
                    .map(slice -> new Document(pageSinkIdColumnName, slice.getLong(0)))
                    .collect(toImmutableList());
            if (pageSinkIds.isEmpty()) {
                // No committed page sinks means nothing to copy; MongoCollection.insertMany also rejects an empty list
                return;
            }

            // Create the temporary page sink ID table
            RemoteTableName pageSinkIdsTable = new RemoteTableName(temporaryTable.databaseName(), generateTemporaryTableName(session));
            MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty());
            mongoSession.createTable(pageSinkIdsTable, ImmutableList.of(pageSinkIdColumn), Optional.empty());
            closer.register(() -> mongoSession.dropTable(pageSinkIdsTable));

            // Insert all the page sink IDs into the page sink ID table
            MongoCollection<Document> pageSinkIdsCollection = mongoSession.getCollection(pageSinkIdsTable);
            pageSinkIdsCollection.insertMany(pageSinkIds);

            // Keep only rows whose page sink ID has a match, strip the join field, and merge them into the target
            MongoCollection<Document> temporaryCollection = mongoSession.getCollection(temporaryTable);
            temporaryCollection.aggregate(ImmutableList.of(
                    lookup(pageSinkIdsTable.collectionName(), pageSinkIdColumnName, pageSinkIdColumnName, "page_sink_id"),
                    match(ne("page_sink_id", ImmutableList.of())),
                    project(exclude("page_sink_id")),
                    merge(targetTable.collectionName())))
                    .toCollection();
        }
        catch (IOException e) {
            // Only Closer.close() can throw IOException here
            throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
        }
    }

    /**
     * Returns the synthetic BIGINT column named {@code $merge_row_id} used as the merge row ID.
     */
    @Override
    public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle)
    {
        String rowIdColumnName = "$merge_row_id";
        return new MongoColumnHandle(rowIdColumnName, ImmutableList.of(), BIGINT, true, false, Optional.empty());
    }

    /**
     * Accepts delete pushdown for any handle, returning the handle unchanged so executeDelete can act on it.
     */
    @Override
    public Optional applyDelete(ConnectorSession session, ConnectorTableHandle handle)
    {
        Optional unchanged = Optional.of(handle);
        return unchanged;
    }

    /**
     * Deletes the documents matching the handle's pushed-down constraint and returns the value reported by the
     * session (presumably the number of deleted documents).
     */
    @Override
    public OptionalLong executeDelete(ConnectorSession session, ConnectorTableHandle handle)
    {
        MongoTableHandle mongoTable = (MongoTableHandle) handle;
        long deleted = mongoSession.deleteDocuments(mongoTable.remoteTableName(), mongoTable.constraint());
        return OptionalLong.of(deleted);
    }

    /**
     * Reports a sorting local property for every index key that has a sort order and whose name matches a column
     * handle key (column handle keys are lower-cased base names). No predicate or partitioning is reported.
     */
    @Override
    public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table)
    {
        MongoTableHandle tableHandle = (MongoTableHandle) table;

        MongoTable tableInfo = mongoSession.getTable(tableHandle.schemaTableName());
        Map<String, ColumnHandle> columns = getColumnHandles(session, tableHandle);

        ImmutableList.Builder<LocalProperty<ColumnHandle>> localProperties = ImmutableList.builder();
        for (MongoIndex index : tableInfo.indexes()) {
            for (MongodbIndexKey key : index.getKeys()) {
                if (key.getSortOrder().isEmpty()) {
                    continue;
                }
                ColumnHandle column = columns.get(key.getName());
                if (column != null) {
                    localProperties.add(new SortingProperty<>(column, key.getSortOrder().get()));
                }
            }
        }

        return new ConnectorTableProperties(
                TupleDomain.all(),
                Optional.empty(),
                Optional.empty(),
                localProperties.build());
    }

    @Override
    public Optional> applyLimit(ConnectorSession session, ConnectorTableHandle table, long limit)
    {
        MongoTableHandle handle = (MongoTableHandle) table;

        // MongoDB cursor.limit(0) is equivalent to setting no limit
        if (limit == 0) {
            return Optional.empty();
        }

        // MongoDB doesn't support limit number greater than integer max
        if (limit > Integer.MAX_VALUE) {
            return Optional.empty();
        }

        if (handle.limit().isPresent() && handle.limit().getAsInt() <= limit) {
            return Optional.empty();
        }

        return Optional.of(new LimitApplicationResult<>(
                new MongoTableHandle(
                        handle.schemaTableName(),
                        handle.remoteTableName(),
                        handle.filter(),
                        handle.constraint(),
                        handle.projectedColumns(),
                        OptionalInt.of(toIntExact(limit))),
                true,
                false));
    }

    @Override
    public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint)
    {
        MongoTableHandle handle = (MongoTableHandle) table;

        TupleDomain oldDomain = handle.constraint();
        TupleDomain newDomain = oldDomain.intersect(constraint.getSummary());
        TupleDomain remainingFilter;
        if (newDomain.isNone()) {
            remainingFilter = TupleDomain.all();
        }
        else {
            Map domains = newDomain.getDomains().orElseThrow();

            Map supported = new HashMap<>();
            Map unsupported = new HashMap<>();

            for (Map.Entry entry : domains.entrySet()) {
                MongoColumnHandle columnHandle = (MongoColumnHandle) entry.getKey();
                Domain domain = entry.getValue();
                Type columnType = columnHandle.type();
                // TODO: Support predicate pushdown on more types including JSON
                if (isPushdownSupportedType(columnType)) {
                    supported.put(entry.getKey(), entry.getValue());
                }
                else {
                    unsupported.put(columnHandle, domain);
                }
            }
            newDomain = TupleDomain.withColumnDomains(supported);
            remainingFilter = TupleDomain.withColumnDomains(unsupported);
        }

        if (oldDomain.equals(newDomain)) {
            return Optional.empty();
        }

        handle = new MongoTableHandle(
                handle.schemaTableName(),
                handle.remoteTableName(),
                handle.filter(),
                newDomain,
                handle.projectedColumns(),
                handle.limit());

        return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter, constraint.getExpression(), false));
    }

    @Override
    public Optional> applyProjection(
            ConnectorSession session,
            ConnectorTableHandle handle,
            List projections,
            Map assignments)
    {
        if (!isProjectionPushdownEnabled(session)) {
            return Optional.empty();
        }
        // Create projected column representations for supported sub expressions. Simple column references and chain of
        // dereferences on a variable are supported right now.
        Set projectedExpressions = projections.stream()
                .flatMap(expression -> extractSupportedProjectedColumns(expression, MongoMetadata::isSupportedForPushdown).stream())
                .collect(toImmutableSet());

        Map columnProjections = projectedExpressions.stream()
                .collect(toImmutableMap(identity(), ApplyProjectionUtil::createProjectedColumnRepresentation));

        MongoTableHandle mongoTableHandle = (MongoTableHandle) handle;

        // all references are simple variables
        if (columnProjections.values().stream().allMatch(ProjectedColumnRepresentation::isVariable)) {
            Set projectedColumns = assignments.values().stream()
                    .map(MongoColumnHandle.class::cast)
                    .collect(toImmutableSet());
            if (mongoTableHandle.projectedColumns().equals(projectedColumns)) {
                return Optional.empty();
            }
            List assignmentsList = assignments.entrySet().stream()
                    .map(assignment -> new Assignment(
                            assignment.getKey(),
                            assignment.getValue(),
                            ((MongoColumnHandle) assignment.getValue()).type()))
                    .collect(toImmutableList());

            return Optional.of(new ProjectionApplicationResult<>(
                    mongoTableHandle.withProjectedColumns(projectedColumns),
                    projections,
                    assignmentsList,
                    false));
        }

        Map newAssignments = new HashMap<>();
        ImmutableMap.Builder newVariablesBuilder = ImmutableMap.builder();
        ImmutableSet.Builder projectedColumnsBuilder = ImmutableSet.builder();

        for (Map.Entry entry : columnProjections.entrySet()) {
            ConnectorExpression expression = entry.getKey();
            ProjectedColumnRepresentation projectedColumn = entry.getValue();

            MongoColumnHandle baseColumnHandle = (MongoColumnHandle) assignments.get(projectedColumn.getVariable().getName());
            MongoColumnHandle projectedColumnHandle = projectColumn(baseColumnHandle, projectedColumn.getDereferenceIndices(), expression.getType());
            String projectedColumnName = projectedColumnHandle.getQualifiedName();

            Variable projectedColumnVariable = new Variable(projectedColumnName, expression.getType());
            Assignment newAssignment = new Assignment(projectedColumnName, projectedColumnHandle, expression.getType());
            newAssignments.putIfAbsent(projectedColumnName, newAssignment);

            newVariablesBuilder.put(expression, projectedColumnVariable);
            projectedColumnsBuilder.add(projectedColumnHandle);
        }

        // Modify projections to refer to new variables
        Map newVariables = newVariablesBuilder.buildOrThrow();
        List newProjections = projections.stream()
                .map(expression -> replaceWithNewVariables(expression, newVariables))
                .collect(toImmutableList());

        List outputAssignments = newAssignments.values().stream().collect(toImmutableList());
        return Optional.of(new ProjectionApplicationResult<>(
                mongoTableHandle.withProjectedColumns(projectedColumnsBuilder.build()),
                newProjections,
                outputAssignments,
                false));
    }

    /**
     * Decides whether a projection expression may be pushed into the Mongo table handle.
     * <p>
     * A bare {@link Variable} is always accepted. A {@link FieldDereference} is accepted only when its
     * target ROW does not look like a DBRef (see {@code isDBRefField}) and the referenced field has a
     * declared name containing neither '.' nor '$'. Every other expression kind is rejected.
     */
    private static boolean isSupportedForPushdown(ConnectorExpression connectorExpression)
    {
        if (!(connectorExpression instanceof FieldDereference dereference)) {
            // Not a dereference: only a plain variable reference qualifies
            return connectorExpression instanceof Variable;
        }

        // The dereference target is assumed to be a ROW; the cast fails otherwise
        RowType targetRowType = (RowType) dereference.getTarget().getType();
        if (isDBRefField(targetRowType)) {
            return false;
        }

        Optional<String> declaredName = targetRowType.getFields().get(dereference.getField()).getName();
        // Anonymous fields cannot be addressed by name, and '.' / '$' would corrupt the field path
        return declaredName
                .map(name -> !name.contains(".") && !name.contains("$"))
                .orElse(false);
    }

    /**
     * Builds the column handle for a (possibly nested) field dereference of {@code baseColumn}.
     *
     * @param baseColumn column the dereference chain starts from
     * @param indices ROW field indices to follow, outermost first; when empty, {@code baseColumn} is returned unchanged
     * @param projectedColumnType type of the value produced by the full dereference chain
     * @return a handle whose dereference names are the base column's names followed by the names of the
     *         visited fields, flagged as DBRef when the innermost visited ROW looks like a DBRef
     * @throws IllegalArgumentException if an index is applied to a non-ROW type
     * @throws TrinoException with {@code NOT_SUPPORTED} if a visited ROW field has no declared name
     */
    private static MongoColumnHandle projectColumn(MongoColumnHandle baseColumn, List<Integer> indices, Type projectedColumnType)
    {
        if (indices.isEmpty()) {
            return baseColumn;
        }
        ImmutableList.Builder<String> dereferenceNamesBuilder = ImmutableList.builder();
        dereferenceNamesBuilder.addAll(baseColumn.dereferenceNames());

        Type type = baseColumn.type();
        // ROW type that directly contains the last dereferenced field; used for the DBRef flag below
        RowType parentType = null;
        for (int index : indices) {
            checkArgument(type instanceof RowType, "type should be Row type");
            RowType rowType = (RowType) type;
            Field field = rowType.getFields().get(index);
            dereferenceNamesBuilder.add(field.getName()
                    .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "ROW type does not have field names declared: " + rowType)));
            parentType = rowType;
            type = field.getType();
        }
        return new MongoColumnHandle(
                baseColumn.baseName(),
                dereferenceNamesBuilder.build(),
                projectedColumnType,
                baseColumn.hidden(),
                isDBRefField(parentType),
                baseColumn.comment());
    }

    /**
     * Returns whether {@code type} has the structure of a DBRef ROW: exactly three fields named
     * {@code DATABASE_NAME} (VARCHAR), {@code COLLECTION_NAME} (VARCHAR) and {@code ID} (any type),
     * in that order.
     * <p>
     * This is a purely structural check, so it may return a wrong flag when a row type uses the same
     * field names and types as dbref. Non-ROW types, {@code null}, and rows whose fields are anonymous
     * yield {@code false}.
     */
    private static boolean isDBRefField(Type type)
    {
        // instanceof is false for null, so no separate null check is needed
        if (!(type instanceof RowType rowType)) {
            return false;
        }
        // When projected field is inside DBRef type field
        List<Field> fields = rowType.getFields();
        if (fields.size() != 3) {
            return false;
        }
        // Anonymous fields (empty name) simply do not match. The previous getName().orElseThrow()
        // threw NoSuchElementException for them, e.g. from isSupportedForPushdown.
        return fields.get(0).getName().filter(name -> name.equals(DATABASE_NAME)).isPresent()
                && fields.get(0).getType().equals(VARCHAR)
                && fields.get(1).getName().filter(name -> name.equals(COLLECTION_NAME)).isPresent()
                && fields.get(1).getType().equals(VARCHAR)
                && fields.get(2).getName().filter(name -> name.equals(ID)).isPresent();
                // Id type can be of any type
    }

    @Override
    public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle)
    {
        if (!(handle instanceof QueryFunctionHandle)) {
            return Optional.empty();
        }

        ConnectorTableHandle tableHandle = ((QueryFunctionHandle) handle).getTableHandle();
        List columnHandles = getColumnHandles(session, tableHandle).values().stream()
                .filter(column -> !((MongoColumnHandle) column).hidden())
                .collect(toImmutableList());
        return Optional.of(new TableFunctionApplicationResult<>(tableHandle, columnHandles));
    }

    /**
     * Registers {@code action} as the pending rollback action.
     *
     * @throws IllegalStateException if a rollback action is already registered
     */
    private void setRollback(Runnable action)
    {
        boolean registered = rollbackAction.compareAndSet(null, action);
        checkState(registered, "rollback action is already set");
    }

    /**
     * Discards the pending rollback action without running it.
     */
    private void clearRollback()
    {
        rollbackAction.set(null);
    }

    /**
     * Atomically takes the pending rollback action, if any, and runs it.
     */
    public void rollback()
    {
        Runnable pending = rollbackAction.getAndSet(null);
        if (pending != null) {
            pending.run();
        }
    }

    /**
     * Extracts the schema-qualified table name from a handle that must be a {@link MongoTableHandle}.
     */
    private static SchemaTableName getTableName(ConnectorTableHandle tableHandle)
    {
        MongoTableHandle mongoTableHandle = (MongoTableHandle) tableHandle;
        return mongoTableHandle.schemaTableName();
    }

    /**
     * Builds connector table metadata for {@code tableName} from the table loaded via {@code mongoSession}.
     *
     * @return metadata with one column per Mongo column handle, no table properties, and the table's comment
     */
    private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName)
    {
        MongoTable mongoTable = mongoSession.getTable(tableName);

        List<ColumnMetadata> columns = mongoTable.columns().stream()
                .map(MongoColumnHandle::toColumnMetadata)
                .collect(toImmutableList());

        return new ConnectorTableMetadata(tableName, columns, ImmutableMap.of(), mongoTable.comment());
    }

    /**
     * Converts the column definitions of {@code tableMetadata} into top-level (non-dereferenced),
     * non-DBRef Mongo column handles, preserving name, type, hidden flag and comment.
     */
    private static List<MongoColumnHandle> buildColumnHandles(ConnectorTableMetadata tableMetadata)
    {
        return tableMetadata.getColumns().stream()
                .map(column -> new MongoColumnHandle(
                        column.getName(),
                        ImmutableList.of(),
                        column.getType(),
                        column.isHidden(),
                        false,
                        Optional.ofNullable(column.getComment())))
                .collect(toList());
    }

    /**
     * Rejects column names containing '$' or '.' for INSERT. isSupportedForPushdown applies the same
     * character restriction to dereferenced field names.
     *
     * @throws IllegalArgumentException if {@code columnName} contains '$' or '.'
     */
    private static void validateColumnNameForInsert(String columnName)
    {
        boolean hasReservedCharacter = columnName.contains("$") || columnName.contains(".");
        if (!hasReservedCharacter) {
            return;
        }
        throw new IllegalArgumentException("Column name must not contain '$' or '.' for INSERT: " + columnName);
    }

    /**
     * Creates the column handle used to tag rows written by a page sink, choosing a name that does not
     * collide with any of {@code otherColumnNames}.
     *
     * @param otherColumnNames names already used by the target table
     * @return a visible, non-DBRef handle of type {@code TRINO_PAGE_SINK_ID_COLUMN_TYPE} named
     *         {@code trino_page_sink_id}, or {@code trino_page_sink_id_N} for the smallest N >= 1 that is free
     */
    private static MongoColumnHandle buildPageSinkIdColumn(Set<String> otherColumnNames)
    {
        // While it's unlikely this column name will collide with client table columns,
        // guarantee it will not by appending a deterministic suffix to it.
        String baseColumnName = "trino_page_sink_id";
        String columnName = baseColumnName;
        for (int suffix = 1; otherColumnNames.contains(columnName); suffix++) {
            columnName = baseColumnName + "_" + suffix;
        }
        return new MongoColumnHandle(columnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy