All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.plugin.hive.metastore.glue.InMemoryGlueCache Maven / Gradle / Ivy

There is a newer version: 468
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.plugin.hive.metastore.glue;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import io.airlift.jmx.CacheStatsMBean;
import io.airlift.units.Duration;
import io.trino.cache.SafeCaches;
import io.trino.plugin.hive.metastore.Database;
import io.trino.plugin.hive.metastore.HiveColumnStatistics;
import io.trino.plugin.hive.metastore.Partition;
import io.trino.plugin.hive.metastore.Table;
import io.trino.plugin.hive.metastore.TableInfo;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.function.LanguageFunction;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiFunction;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

class InMemoryGlueCache
        implements GlueCache
{
    private enum Global { GLOBAL }

    private record PartitionKey(String databaseName, String tableName, PartitionName partitionName) {}
    private record PartitionNamesKey(String databaseName, String tableName, String glueFilterExpression) {}
    private record FunctionKey(String databaseName, String functionName) {}

    private final LoadingCache>> databaseNamesCache;
    private final LoadingCache>> databaseCache;
    private final LoadingCache>> tableNamesCache;
    private final LoadingCache>> tableCache;
    private final LoadingCache tableColumnStatsCache;
    private final LoadingCache>> partitionNamesCache;
    private final LoadingCache>> partitionCache;
    private final LoadingCache partitionColumnStatsCache;
    private final LoadingCache>> allFunctionsCache;
    private final LoadingCache>> functionCache;

    private final AtomicLong databaseInvalidationCounter = new AtomicLong();
    private final AtomicLong tableInvalidationCounter = new AtomicLong();
    private final AtomicLong partitionInvalidationCounter = new AtomicLong();

    public InMemoryGlueCache(Duration metadataCacheTtl, Duration statsCacheTtl, long maximumSize)
    {
        OptionalLong metadataCacheTtlMillis = OptionalLong.of(metadataCacheTtl.toMillis());
        this.databaseNamesCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.databaseCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.tableNamesCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.tableCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.partitionNamesCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.partitionCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.allFunctionsCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);
        this.functionCache = buildCache(metadataCacheTtlMillis, maximumSize, ValueHolder::new);

        OptionalLong statsCacheTtlMillis = OptionalLong.of(statsCacheTtl.toMillis());
        this.tableColumnStatsCache = buildCache(statsCacheTtlMillis, maximumSize, ColumnStatisticsHolder::new);
        this.partitionColumnStatsCache = buildCache(statsCacheTtlMillis, maximumSize, ColumnStatisticsHolder::new);
    }

    @Override
    public List getDatabaseNames(Function, List> loader)
    {
        long invalidationCounter = databaseInvalidationCounter.get();
        return databaseNamesCache.getUnchecked(Global.GLOBAL).getValue(() -> loader.apply(database -> cacheDatabase(invalidationCounter, database)));
    }

    private void cacheDatabase(long invalidationCounter, Database database)
    {
        cacheValue(databaseCache, database.getDatabaseName(), Optional.of(database), () -> invalidationCounter == databaseInvalidationCounter.get());
    }

    @Override
    public void invalidateDatabase(String databaseName)
    {
        databaseInvalidationCounter.incrementAndGet();
        databaseCache.invalidate(databaseName);
        for (SchemaTableName schemaTableName : Sets.union(tableCache.asMap().keySet(), tableColumnStatsCache.asMap().keySet())) {
            if (schemaTableName.getSchemaName().equals(databaseName)) {
                invalidateTable(schemaTableName.getSchemaName(), schemaTableName.getTableName(), true);
            }
        }
        for (PartitionKey partitionKey : Sets.union(partitionCache.asMap().keySet(), partitionColumnStatsCache.asMap().keySet())) {
            if (partitionKey.databaseName().equals(databaseName)) {
                invalidatePartition(partitionKey);
            }
        }
        for (PartitionNamesKey partitionNamesKey : partitionNamesCache.asMap().keySet()) {
            if (partitionNamesKey.databaseName().equals(databaseName)) {
                partitionNamesCache.invalidate(partitionNamesKey);
            }
        }
        for (FunctionKey functionKey : functionCache.asMap().keySet()) {
            if (functionKey.databaseName().equals(databaseName)) {
                functionCache.invalidate(functionKey);
            }
        }
        allFunctionsCache.invalidate(databaseName);
    }

    @Override
    public void invalidateDatabaseNames()
    {
        databaseNamesCache.invalidate(Global.GLOBAL);
    }

    @Override
    public Optional getDatabase(String databaseName, Supplier> loader)
    {
        return databaseCache.getUnchecked(databaseName).getValue(loader);
    }

    @Override
    public List getTables(String databaseName, Function, List> loader)
    {
        long invalidationCounter = tableInvalidationCounter.get();
        return tableNamesCache.getUnchecked(databaseName).getValue(() -> loader.apply(table -> cacheTable(invalidationCounter, table)));
    }

    private void cacheTable(long invalidationCounter, Table table)
    {
        cacheValue(tableCache, table.getSchemaTableName(), Optional.of(table), () -> invalidationCounter == tableInvalidationCounter.get());
    }

    @Override
    public void invalidateTables(String databaseName)
    {
        tableNamesCache.invalidate(databaseName);
    }

    @Override
    public Optional getTable(String databaseName, String tableName, Supplier> loader)
    {
        return tableCache.getUnchecked(new SchemaTableName(databaseName, tableName)).getValue(loader);
    }

    @Override
    public void invalidateTable(String databaseName, String tableName, boolean cascade)
    {
        tableInvalidationCounter.incrementAndGet();
        SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName);
        tableCache.invalidate(schemaTableName);
        tableColumnStatsCache.invalidate(schemaTableName);
        if (cascade) {
            for (PartitionKey partitionKey : Sets.union(partitionCache.asMap().keySet(), partitionColumnStatsCache.asMap().keySet())) {
                if (partitionKey.databaseName().equals(databaseName) && partitionKey.tableName().equals(tableName)) {
                    invalidatePartition(partitionKey);
                }
            }
            invalidatePartitionNames(databaseName, tableName);
        }
    }

    @Override
    public Map getTableColumnStatistics(String databaseName, String tableName, Set columnNames, Function, Map> loader)
    {
        return tableColumnStatsCache.getUnchecked(new SchemaTableName(databaseName, tableName))
                .getColumnStatistics(columnNames, loader);
    }

    @Override
    public void invalidateTableColumnStatistics(String databaseName, String tableName)
    {
        SchemaTableName schemaTableName = new SchemaTableName(databaseName, tableName);
        tableColumnStatsCache.invalidate(schemaTableName);
    }

    @Override
    public Set getPartitionNames(String databaseName, String tableName, String glueExpression, Function, Set> loader)
    {
        long invalidationCounter = partitionInvalidationCounter.get();
        return partitionNamesCache.getUnchecked(new PartitionNamesKey(databaseName, tableName, glueExpression))
                .getValue(() -> loader.apply(partition -> cachePartition(invalidationCounter, partition)));
    }

    private void invalidatePartitionNames(String databaseName, String tableName)
    {
        for (PartitionNamesKey partitionNamesKey : partitionNamesCache.asMap().keySet()) {
            if (partitionNamesKey.databaseName().equals(databaseName) && partitionNamesKey.tableName().equals(tableName)) {
                partitionNamesCache.invalidate(partitionNamesKey);
            }
        }
    }

    @Override
    public Optional getPartition(String databaseName, String tableName, PartitionName partitionName, Supplier> loader)
    {
        return partitionCache.getUnchecked(new PartitionKey(databaseName, tableName, partitionName)).getValue(loader);
    }

    @Override
    public Collection batchGetPartitions(
            String databaseName,
            String tableName,
            Collection partitionNames,
            BiFunction, Collection, Collection> loader)
    {
        ImmutableList.Builder partitions = ImmutableList.builder();
        Set missingPartitionNames = new HashSet<>();
        for (PartitionName partitionName : partitionNames) {
            ValueHolder> valueHolder = partitionCache.getIfPresent(new PartitionKey(databaseName, tableName, partitionName));
            if (valueHolder != null) {
                Optional partition = valueHolder.getValueIfPresent().flatMap(Function.identity());
                if (partition.isPresent()) {
                    partitions.add(partition.get());
                    continue;
                }
            }
            missingPartitionNames.add(partitionName);
        }
        if (!missingPartitionNames.isEmpty()) {
            // NOTE: loader is expected to directly insert the partitions into the cache, so there is no need to do it here
            long invalidationCounter = partitionInvalidationCounter.get();
            partitions.addAll(loader.apply(partition -> cachePartition(invalidationCounter, partition), missingPartitionNames));
        }
        return partitions.build();
    }

    private void cachePartition(long invalidationCounter, Partition partition)
    {
        PartitionKey partitionKey = new PartitionKey(partition.getDatabaseName(), partition.getTableName(), new PartitionName(partition.getValues()));
        cacheValue(partitionCache, partitionKey, Optional.of(partition), () -> invalidationCounter == partitionInvalidationCounter.get());
    }

    @Override
    public void invalidatePartition(String databaseName, String tableName, PartitionName partitionName)
    {
        invalidatePartition(new PartitionKey(databaseName, tableName, partitionName));
    }

    private void invalidatePartition(PartitionKey partitionKey)
    {
        partitionInvalidationCounter.incrementAndGet();
        partitionCache.invalidate(partitionKey);
        partitionColumnStatsCache.invalidate(partitionKey);
    }

    @Override
    public Map getPartitionColumnStatistics(
            String databaseName,
            String tableName,
            PartitionName partitionName,
            Set columnNames,
            Function, Map> loader)
    {
        return partitionColumnStatsCache.getUnchecked(new PartitionKey(databaseName, tableName, partitionName))
                .getColumnStatistics(columnNames, loader);
    }

    @Override
    public Collection getAllFunctions(String databaseName, Supplier> loader)
    {
        return allFunctionsCache.getUnchecked(databaseName).getValue(loader);
    }

    @Override
    public Collection getFunction(String databaseName, String functionName, Supplier> loader)
    {
        return functionCache.getUnchecked(new FunctionKey(databaseName, functionName)).getValue(loader);
    }

    @Override
    public void invalidateFunction(String databaseName, String functionName)
    {
        functionCache.invalidate(new FunctionKey(databaseName, functionName));
        allFunctionsCache.invalidate(databaseName);
    }

    @Managed
    @Nested
    public CacheStatsMBean getDatabaseNamesCacheStats()
    {
        return new CacheStatsMBean(databaseNamesCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getDatabaseCacheStats()
    {
        return new CacheStatsMBean(databaseCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getTableNamesCacheStats()
    {
        return new CacheStatsMBean(tableNamesCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getTableCacheStats()
    {
        return new CacheStatsMBean(tableCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getTableColumnStatsCacheStats()
    {
        return new CacheStatsMBean(tableColumnStatsCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getPartitionNamesCacheStats()
    {
        return new CacheStatsMBean(partitionNamesCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getPartitionCacheStats()
    {
        return new CacheStatsMBean(partitionCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getPartitionColumnStatsCacheStats()
    {
        return new CacheStatsMBean(partitionColumnStatsCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getAllFunctionsCacheStats()
    {
        return new CacheStatsMBean(allFunctionsCache);
    }

    @Managed
    @Nested
    public CacheStatsMBean getFunctionCacheStats()
    {
        return new CacheStatsMBean(functionCache);
    }

    @SuppressModernizer
    private static  LoadingCache buildCache(OptionalLong expiresAfterWriteMillis, long maximumSize, Supplier loader)
    {
        if (expiresAfterWriteMillis.isEmpty()) {
            return SafeCaches.emptyLoadingCache(CacheLoader.from(ignores -> loader.get()), true);
        }

        // this does not use EvictableCache because we want to inject values directly into the cache,
        // and we want a lock per key, instead of striped locks
        return CacheBuilder.newBuilder()
                .expireAfterWrite(expiresAfterWriteMillis.getAsLong(), MILLISECONDS)
                .maximumSize(maximumSize)
                .recordStats()
                .build(CacheLoader.from(loader::get));
    }

    private static  void cacheValue(LoadingCache> cache, K key, V value, BooleanSupplier test)
    {
        // get the current value before checking the invalidation counter
        ValueHolder valueHolder = cache.getUnchecked(key);
        if (!test.getAsBoolean()) {
            return;
        }
        // at this point, we know our value is ok to use in the value cache we fetched before the check
        valueHolder.tryOverwrite(value);
        // The value is updated, but Guava does not know the update happened, so the expiration time is not extended.
        // We need to replace the value in the cache to extend the expiration time iff this is still the latest value.
        cache.asMap().replace(key, valueHolder, valueHolder);
    }

    private static class ValueHolder
    {
        private final Lock writeLock = new ReentrantLock();
        private volatile V value;

        public ValueHolder() {}

        public V getValue(Supplier loader)
        {
            if (value == null) {
                writeLock.lock();
                try {
                    if (value == null) {
                        value = loader.get();
                        if (value == null) {
                            throw new IllegalStateException("Value loader returned null");
                        }
                    }
                }
                finally {
                    writeLock.unlock();
                }
            }
            return value;
        }

        public Optional getValueIfPresent()
        {
            return Optional.ofNullable(value);
        }

        /**
         * Overwrite the value unless it is currently being loaded by another thread.
         */
        public void tryOverwrite(V value)
        {
            if (writeLock.tryLock()) {
                try {
                    this.value = value;
                }
                finally {
                    writeLock.unlock();
                }
            }
        }
    }

    private static class ColumnStatisticsHolder
    {
        private final Lock writeLock = new ReentrantLock();
        private final Map> cache = new ConcurrentHashMap<>();

        public Map getColumnStatistics(Set columnNames, Function, Map> loader)
        {
            Set missingColumnNames = new HashSet<>();
            Map result = new ConcurrentHashMap<>();
            for (String columnName : columnNames) {
                Optional columnStatistics = cache.get(columnName);
                if (columnStatistics == null) {
                    missingColumnNames.add(columnName);
                }
                else {
                    columnStatistics.ifPresent(value -> result.put(columnName, value));
                }
            }
            if (!missingColumnNames.isEmpty()) {
                writeLock.lock();
                try {
                    Map loadedColumnStatistics = loader.apply(missingColumnNames);
                    for (String missingColumnName : missingColumnNames) {
                        HiveColumnStatistics value = loadedColumnStatistics.get(missingColumnName);
                        cache.put(missingColumnName, Optional.ofNullable(value));
                        if (value != null) {
                            result.put(missingColumnName, value);
                        }
                    }
                }
                finally {
                    writeLock.unlock();
                }
            }
            return result;
        }
    }
}