All downloads are free. The search and download features use the official Maven repository.

io.zonky.test.db.provider.common.PrefetchingDatabaseProvider Maven / Gradle / Ivy

Go to download

A library for creating isolated embedded databases for Spring-powered integration tests.

The newest version!
/*
 * Copyright 2020 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.zonky.test.db.provider.common;

import com.google.common.base.MoreObjects;
import com.google.common.base.Stopwatch;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import io.zonky.test.db.preparer.CompositeDatabasePreparer;
import io.zonky.test.db.preparer.DatabasePreparer;
import io.zonky.test.db.provider.DatabaseProvider;
import io.zonky.test.db.provider.EmbeddedDatabase;
import io.zonky.test.db.provider.ProviderException;
import io.zonky.test.db.util.RandomStringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.ListenableFutureTask;

import java.util.Comparator;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.collect.Maps.immutableEntry;
import static io.zonky.test.db.provider.common.PrefetchingDatabaseProvider.DatabasePipeline.State.INITIALIZED;
import static io.zonky.test.db.provider.common.PrefetchingDatabaseProvider.DatabasePipeline.State.INITIALIZING;
import static io.zonky.test.db.provider.common.PrefetchingDatabaseProvider.DatabasePipeline.State.NEW;
import static io.zonky.test.db.provider.common.PrefetchingDatabaseProvider.PrefetchingTask.TaskType.EXISTING_DATABASE;
import static io.zonky.test.db.provider.common.PrefetchingDatabaseProvider.PrefetchingTask.TaskType.NEW_DATABASE;
import static java.util.Collections.newSetFromMap;
import static java.util.stream.Collectors.toList;
import static org.springframework.core.Ordered.HIGHEST_PRECEDENCE;
import static org.springframework.core.Ordered.LOWEST_PRECEDENCE;

public class PrefetchingDatabaseProvider implements DatabaseProvider {

    private static final Logger logger = LoggerFactory.getLogger(PrefetchingDatabaseProvider.class);

    protected static final ThreadPoolTaskExecutor taskExecutor = new PriorityThreadPoolTaskExecutor();
    protected static final ConcurrentMap pipelines = new ConcurrentHashMap<>();
    protected static final AtomicLong databaseCount = new AtomicLong();

    static {
        taskExecutor.setThreadNamePrefix("prefetching-");
        taskExecutor.setAllowCoreThreadTimeOut(true);
        taskExecutor.setKeepAliveSeconds(60);
        taskExecutor.setCorePoolSize(1);
        taskExecutor.initialize();
    }

    protected final DatabaseProvider provider;
    protected final Config config;

    public PrefetchingDatabaseProvider(DatabaseProvider provider) {
        this(provider, Config.builder().build());
    }

    public PrefetchingDatabaseProvider(DatabaseProvider provider, Config config) {
        this.provider = provider;
        this.config = config;

        taskExecutor.setThreadNamePrefix(config.getThreadNamePrefix());
        taskExecutor.setCorePoolSize(config.getConcurrency());
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        PrefetchingDatabaseProvider that = (PrefetchingDatabaseProvider) o;
        return Objects.equals(provider, that.provider) &&
                Objects.equals(config, that.config);
    }

    @Override
    public int hashCode() {
        return Objects.hash(provider, config);
    }

    @Override
    public EmbeddedDatabase createDatabase(DatabasePreparer preparer) throws ProviderException {
        Stopwatch stopwatch = Stopwatch.createStarted();
        logger.trace("Prefetching pipelines: {}", pipelines.values());
        databaseCount.decrementAndGet();

        PipelineKey key = new PipelineKey(provider, preparer);
        DatabasePipeline pipeline = pipelines.computeIfAbsent(key, k -> new DatabasePipeline());
        PreparedResult result = pipeline.results.poll();

        if (result != null) {
            prepareDatabase(key, LOWEST_PRECEDENCE);
        } else {
            boolean pipelineInitMode = pipeline.state.compareAndSet(NEW, INITIALIZING);
            Optional task = prepareExistingDatabase(key, HIGHEST_PRECEDENCE);
            if (pipelineInitMode || !task.isPresent()) {
                prepareNewDatabase(key, HIGHEST_PRECEDENCE);
            }
        }

        long invocationCount = pipeline.requests.incrementAndGet();
        long databasesCount = pipeline.tasks.size() + pipeline.results.size();
        if (result == null) databasesCount--;

        if (databasesCount < invocationCount - 1 && databasesCount < config.getPipelineMaxCacheSize()) {
            prepareDatabase(key, -1);
        }
        reschedulePipeline(key);

        if (result == null) {
            try {
                result = pipeline.results.take();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new ProviderException("Provider interrupted", e);
            }
        }

        EmbeddedDatabase database = result.get();
        logger.debug("Database has been successfully fetched in {} - pipelineKey={}", stopwatch, pipeline.key);
        return database;
    }

    protected PrefetchingTask prepareDatabase(PipelineKey key, int priority) {
        DatabasePipeline pipeline = pipelines.get(key);

        if (pipeline.state.get() != INITIALIZED) {
            return prepareExistingDatabase(key, priority)
                    .orElseGet(() -> prepareNewDatabase(key, priority));
        }

        return prepareNewDatabase(key, priority);
    }

    protected PrefetchingTask prepareNewDatabase(PipelineKey key, int priority) {
        databaseCount.incrementAndGet();

        Entry databaseToRemove = findDatabaseToRemove().orElse(null);
        if (databaseToRemove != null) {
            databaseCount.decrementAndGet();

            if (databaseToRemove.getKey().equals(key)) {
                return executeTask(key, PrefetchingTask.withDatabase(databaseToRemove.getValue(), priority));
            } else {
                databaseToRemove.getValue().close();
                DatabasePipeline pipeline = pipelines.get(databaseToRemove.getKey());
                logger.trace("Prepared database has been cleaned: {}", pipeline.key);
            }
        }

        return executeTask(key, PrefetchingTask.forPreparer(key.provider, key.preparer, priority));
    }

    protected Optional prepareExistingDatabase(PipelineKey key, int priority) {
        CompositeDatabasePreparer compositePreparer = key.preparer instanceof CompositeDatabasePreparer ?
                (CompositeDatabasePreparer) key.preparer : new CompositeDatabasePreparer(ImmutableList.of(key.preparer));
        List preparers = compositePreparer.getPreparers();

        for (int i = preparers.size() - 1; i > 0; i--) {
            CompositeDatabasePreparer pipelinePreparer = new CompositeDatabasePreparer(preparers.subList(0, i));
            PipelineKey pipelineKey = new PipelineKey(provider, pipelinePreparer);
            DatabasePipeline existingPipeline = pipelines.get(pipelineKey);

            if (existingPipeline != null) {
                if (key.preparer.estimatedDuration() - pipelinePreparer.estimatedDuration() > 600) {
                    return Optional.empty();
                }

                PreparedResult result = existingPipeline.results.poll();
                if (result != null) {
                    CompositeDatabasePreparer complementaryPreparer = new CompositeDatabasePreparer(preparers.subList(i, preparers.size()));
                    logger.trace("Preparing existing database from {} pipeline by using the complementary preparer {}", existingPipeline.key, complementaryPreparer);
                    PrefetchingTask task = executeTask(key, PrefetchingTask.withDatabase(result.get(), complementaryPreparer, priority));

                    prepareDatabase(pipelineKey, LOWEST_PRECEDENCE);
                    reschedulePipeline(pipelineKey);

                    return Optional.of(task);
                }
            }
        }

        return Optional.empty();
    }

    protected void reschedulePipeline(PipelineKey key) {
        DatabasePipeline pipeline = pipelines.get(key);

        synchronized (pipeline.tasks) {
            long invocationCount = pipeline.requests.get();

            List cancelledTasks = pipeline.tasks.stream()
                    .filter(t -> t.priority > HIGHEST_PRECEDENCE)
                    .filter(t -> t.cancel(false))
                    .collect(toList());

            for (int i = 0; i < cancelledTasks.size(); i++) {
                int priority = -1 * (int) (invocationCount / cancelledTasks.size() * (i + 1));
                executeTask(key, PrefetchingTask.fromTask(cancelledTasks.get(i), priority));
            }
        }
    }

    protected PrefetchingTask executeTask(PipelineKey key, PrefetchingTask task) {
        DatabasePipeline pipeline = pipelines.get(key);

        task.addCallback(new ListenableFutureCallback() {
            @Override
            public void onSuccess(EmbeddedDatabase result) {
                if (task.type == NEW_DATABASE) {
                    pipeline.state.set(INITIALIZED);
                }
                pipeline.tasks.remove(task);
                pipeline.results.offer(PreparedResult.success(result));
            }

            @Override
            public void onFailure(Throwable error) {
                pipeline.tasks.remove(task);
                if (!(error instanceof CancellationException)) {
                    pipeline.results.offer(PreparedResult.failure(error));
                }
            }
        });

        pipeline.tasks.add(task);
        taskExecutor.execute(task);
        return task;
    }

    protected Optional> findDatabaseToRemove() {
        while (databaseCount.get() > config.getMaxPreparedDatabases()) {
            long timestampThreshold = System.currentTimeMillis() - 10_000;

            PipelineKey key = pipelines.entrySet().stream()
                    .map(e -> immutableEntry(e.getKey(), e.getValue().results.peek()))
                    .filter(e -> e.getValue() != null && e.getValue().getTimestamp() < timestampThreshold)
                    .min(Comparator.comparing(e -> e.getValue().getTimestamp()))
                    .map(Entry::getKey).orElse(null);

            if (key == null) {
                return Optional.empty();
            }

            DatabasePipeline pipeline = pipelines.get(key);
            if (pipeline != null) {
                PreparedResult result = pipeline.results.poll();
                if (result != null) {
                    if (result.hasResult()) {
                        return Optional.of(immutableEntry(key, result.get()));
                    } else {
                        databaseCount.decrementAndGet();
                    }
                }
            }
        }
        return Optional.empty();
    }

    protected static class PipelineKey {

        public final DatabaseProvider provider;
        public final DatabasePreparer preparer;

        protected PipelineKey(DatabaseProvider provider, DatabasePreparer preparer) {
            this.provider = provider;
            this.preparer = preparer;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            PipelineKey that = (PipelineKey) o;
            return Objects.equals(provider, that.provider) &&
                    Objects.equals(preparer, that.preparer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(provider, preparer);
        }
    }

    protected static class DatabasePipeline {

        public final String key = RandomStringUtils.randomAlphabetic(8);
        public final AtomicReference state = new AtomicReference<>(NEW);
        public final AtomicLong requests = new AtomicLong();
        public final Set tasks = newSetFromMap(new ConcurrentHashMap<>());
        public final BlockingQueue results = new LinkedBlockingQueue<>();

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("pipelineKey", key)
                    .add("pipelineState", state.get())
                    .add("totalRequests", requests.get())
                    .add("prefetchingQueue", tasks.size())
                    .add("preparedResults", results.size())
                    .toString();
        }

        protected enum State {

            NEW, INITIALIZING, INITIALIZED

        }
    }

    protected static class PreparedResult {

        private final long timestamp = System.currentTimeMillis();
        private final EmbeddedDatabase result;
        private final Throwable error;

        public static PreparedResult success(EmbeddedDatabase result) {
            return new PreparedResult(result, null);
        }

        public static PreparedResult failure(Throwable error) {
            return new PreparedResult(null, error);
        }

        protected PreparedResult(EmbeddedDatabase result, Throwable error) {
            this.result = result;
            this.error = error;
        }

        public long getTimestamp() {
            return timestamp;
        }

        public boolean hasResult() {
            return result != null;
        }

        public EmbeddedDatabase get() throws ProviderException {
            if (result != null) {
                return result;
            }
            Throwables.throwIfInstanceOf(error, ProviderException.class);
            throw new ProviderException("Unexpected error when prefetching a database", error);
        }
    }

    protected static class PriorityThreadPoolTaskExecutor extends ThreadPoolTaskExecutor {

        @Override
        protected BlockingQueue createQueue(int queueCapacity) {
            return new PriorityBlockingQueue<>();
        }
    }

    protected static class PrefetchingTask extends ListenableFutureTask implements Comparable {

        private final AtomicBoolean executed = new AtomicBoolean(false);

        public final Callable action;
        public final TaskType type;
        public final int priority;

        public static PrefetchingTask forPreparer(DatabaseProvider provider, DatabasePreparer preparer, int priority) {
            return new PrefetchingTask(priority, NEW_DATABASE, () -> provider.createDatabase(preparer));
        }

        public static PrefetchingTask withDatabase(EmbeddedDatabase database, DatabasePreparer preparer, int priority) {
            return new PrefetchingTask(priority, EXISTING_DATABASE, () -> {
                preparer.prepare(database);
                return database;
            });
        }

        public static PrefetchingTask withDatabase(EmbeddedDatabase database, int priority) {
            return new PrefetchingTask(priority, EXISTING_DATABASE, () -> database);
        }

        public static PrefetchingTask fromTask(PrefetchingTask task, int priority) {
            return new PrefetchingTask(priority, task.type, task.action);
        }

        private PrefetchingTask(int priority, TaskType type, Callable action) {
            super(action);

            this.action = action;
            this.type = type;
            this.priority = priority;
        }

        @Override
        public void run() {
            if (executed.compareAndSet(false, true)) {
                super.run();
            }
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            if (mayInterruptIfRunning || executed.compareAndSet(false, true)) {
                return super.cancel(mayInterruptIfRunning);
            } else {
                return false;
            }
        }

        @Override
        public int compareTo(PrefetchingTask task) {
            return Integer.compare(priority, task.priority);
        }

        protected enum TaskType {

            NEW_DATABASE, EXISTING_DATABASE

        }
    }

    public static class Config {

        private final String threadNamePrefix;
        private final int concurrency;
        private final int pipelineMaxCacheSize;
        private final int maxPreparedDatabases;

        private Config(Config.Builder builder) {
            this.threadNamePrefix = builder.threadNamePrefix;
            this.concurrency = builder.concurrency;
            this.pipelineMaxCacheSize = builder.pipelineMaxCacheSize;
            this.maxPreparedDatabases = builder.maxPreparedDatabases;
        }

        public String getThreadNamePrefix() {
            return threadNamePrefix;
        }

        public int getConcurrency() {
            return concurrency;
        }

        public int getPipelineMaxCacheSize() {
            return pipelineMaxCacheSize;
        }

        public int getMaxPreparedDatabases() {
            return maxPreparedDatabases;
        }

        public static Builder builder() {
            return new Builder();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Config config = (Config) o;
            return pipelineMaxCacheSize == config.pipelineMaxCacheSize;
        }

        @Override
        public int hashCode() {
            return Objects.hash(pipelineMaxCacheSize);
        }

        public static class Builder {

            private String threadNamePrefix = "prefetching-";
            private int concurrency = 3;
            private int pipelineMaxCacheSize = 5;
            private int maxPreparedDatabases = 15;

            private Builder() {}

            public Builder withThreadNamePrefix(String threadNamePrefix) {
                this.threadNamePrefix = threadNamePrefix;
                return this;
            }

            public Builder withConcurrency(int concurrency) {
                this.concurrency = concurrency;
                return this;
            }

            public Builder withPipelineMaxCacheSize(int pipelineMaxCacheSize) {
                this.pipelineMaxCacheSize = pipelineMaxCacheSize;
                return this;
            }

            public Builder withMaxPreparedDatabases(int maxPreparedDatabases) {
                this.maxPreparedDatabases = maxPreparedDatabases;
                return this;
            }

            public Config build() {
                return new Config(this);
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy