All downloads are free. Search and download functionality uses the official Maven repository.

org.neo4j.driver.internal.bolt.routedimpl.RoutedBoltConnectionProvider Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.neo4j.driver.internal.bolt.routedimpl;

import static java.lang.String.format;

import java.time.Clock;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.neo4j.driver.Value;
import org.neo4j.driver.exceptions.SecurityException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.exceptions.SessionExpiredException;
import org.neo4j.driver.internal.bolt.api.AccessMode;
import org.neo4j.driver.internal.bolt.api.BoltAgent;
import org.neo4j.driver.internal.bolt.api.BoltConnection;
import org.neo4j.driver.internal.bolt.api.BoltConnectionProvider;
import org.neo4j.driver.internal.bolt.api.BoltProtocolVersion;
import org.neo4j.driver.internal.bolt.api.BoltServerAddress;
import org.neo4j.driver.internal.bolt.api.DatabaseName;
import org.neo4j.driver.internal.bolt.api.DatabaseNameUtil;
import org.neo4j.driver.internal.bolt.api.DomainNameResolver;
import org.neo4j.driver.internal.bolt.api.LoggingProvider;
import org.neo4j.driver.internal.bolt.api.MetricsListener;
import org.neo4j.driver.internal.bolt.api.NotificationConfig;
import org.neo4j.driver.internal.bolt.api.RoutingContext;
import org.neo4j.driver.internal.bolt.api.SecurityPlan;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.Rediscovery;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.RoutingTable;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.RoutingTableHandler;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.RoutingTableRegistry;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.RoutingTableRegistryImpl;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.loadbalancing.LeastConnectedLoadBalancingStrategy;
import org.neo4j.driver.internal.bolt.routedimpl.cluster.loadbalancing.LoadBalancingStrategy;
import org.neo4j.driver.internal.bolt.routedimpl.util.FutureUtil;

public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
    private static final String CONNECTION_ACQUISITION_COMPLETION_FAILURE_MESSAGE =
            "Connection acquisition failed for all available addresses.";
    private static final String CONNECTION_ACQUISITION_COMPLETION_EXCEPTION_MESSAGE =
            "Failed to obtain connection towards %s server. Known routing table is: %s";
    private static final String CONNECTION_ACQUISITION_ATTEMPT_FAILURE_MESSAGE =
            "Failed to obtain a connection towards address %s, will try other addresses if available. Complete failure is reported separately from this entry.";
    private final LoggingProvider logging;
    private final System.Logger log;
    private final Supplier boltConnectionProviderSupplier;

    private final Map addressToProvider = new HashMap<>();
    private final Function> resolver;
    private final DomainNameResolver domainNameResolver;
    private final Map addressToInUseCount = new HashMap<>();

    private final LoadBalancingStrategy loadBalancingStrategy;
    private final long routingTablePurgeDelayMs;

    private Rediscovery rediscovery;
    private RoutingTableRegistry registry;

    private RoutingContext routingContext;
    private BoltAgent boltAgent;
    private String userAgent;
    private int connectTimeoutMillis;
    private CompletableFuture closeFuture;
    private final Clock clock;
    private MetricsListener metricsListener;

    public RoutedBoltConnectionProvider(
            Supplier boltConnectionProviderSupplier,
            Function> resolver,
            DomainNameResolver domainNameResolver,
            long routingTablePurgeDelayMs,
            Rediscovery rediscovery,
            Clock clock,
            LoggingProvider logging) {
        this.boltConnectionProviderSupplier = Objects.requireNonNull(boltConnectionProviderSupplier);
        this.resolver = Objects.requireNonNull(resolver);
        this.logging = Objects.requireNonNull(logging);
        this.log = logging.getLog(getClass());
        this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(this::getInUseCount, logging);
        this.domainNameResolver = Objects.requireNonNull(domainNameResolver);
        this.routingTablePurgeDelayMs = routingTablePurgeDelayMs;
        this.rediscovery = rediscovery;
        this.clock = Objects.requireNonNull(clock);
    }

    @Override
    public synchronized CompletionStage init(
            BoltServerAddress address,
            RoutingContext routingContext,
            BoltAgent boltAgent,
            String userAgent,
            int connectTimeoutMillis,
            MetricsListener metricsListener) {
        this.routingContext = routingContext;
        this.boltAgent = boltAgent;
        this.userAgent = userAgent;
        this.connectTimeoutMillis = connectTimeoutMillis;
        if (this.rediscovery == null) {
            this.rediscovery = new RediscoveryImpl(address, resolver, logging, domainNameResolver);
        }
        this.registry = new RoutingTableRegistryImpl(
                this::get, rediscovery, clock, logging, routingTablePurgeDelayMs, this::shutdownUnusedProviders);
        this.metricsListener = Objects.requireNonNull(metricsListener);

        return CompletableFuture.completedStage(null);
    }

    @Override
    public CompletionStage connect(
            SecurityPlan securityPlan,
            DatabaseName databaseName,
            Supplier>> authMapStageSupplier,
            AccessMode mode,
            Set bookmarks,
            String impersonatedUser,
            BoltProtocolVersion minVersion,
            NotificationConfig notificationConfig,
            Consumer databaseNameConsumer) {
        RoutingTableRegistry registry;
        synchronized (this) {
            if (closeFuture != null) {
                return CompletableFuture.failedFuture(new IllegalStateException("Connection provider is closed."));
            }
            registry = this.registry;
        }

        var handlerRef = new AtomicReference();
        var databaseNameFuture = databaseName == null
                ? new CompletableFuture()
                : CompletableFuture.completedFuture(databaseName);
        databaseNameFuture.whenComplete((name, throwable) -> {
            if (name != null) {
                databaseNameConsumer.accept(name);
            }
        });
        return registry.ensureRoutingTable(
                        securityPlan,
                        databaseNameFuture,
                        mode,
                        bookmarks,
                        impersonatedUser,
                        authMapStageSupplier,
                        minVersion)
                .thenApply(routingTableHandler -> {
                    handlerRef.set(routingTableHandler);
                    return routingTableHandler;
                })
                .thenCompose(routingTableHandler -> acquire(
                        securityPlan,
                        mode,
                        routingTableHandler.routingTable(),
                        authMapStageSupplier,
                        routingTableHandler.routingTable().database(),
                        Set.of(),
                        impersonatedUser,
                        minVersion,
                        notificationConfig))
                .thenApply(boltConnection -> new RoutedBoltConnection(boltConnection, handlerRef.get(), mode, this));
    }

    @Override
    public CompletionStage verifyConnectivity(SecurityPlan securityPlan, Map authMap) {
        RoutingTableRegistry registry;
        synchronized (this) {
            registry = this.registry;
        }
        return supportsMultiDb(securityPlan, authMap)
                .thenCompose(supports -> registry.ensureRoutingTable(
                        securityPlan,
                        supports
                                ? CompletableFuture.completedFuture(DatabaseNameUtil.database("system"))
                                : CompletableFuture.completedFuture(DatabaseNameUtil.defaultDatabase()),
                        AccessMode.READ,
                        Collections.emptySet(),
                        null,
                        () -> CompletableFuture.completedStage(authMap),
                        null))
                .handle((ignored, error) -> {
                    if (error != null) {
                        var cause = FutureUtil.completionExceptionCause(error);
                        if (cause instanceof ServiceUnavailableException) {
                            throw FutureUtil.asCompletionException(new ServiceUnavailableException(
                                    "Unable to connect to database management service, ensure the database is running and that there is a working network connection to it.",
                                    cause));
                        }
                        throw FutureUtil.asCompletionException(cause);
                    }
                    return null;
                });
    }

    @Override
    public CompletionStage supportsMultiDb(SecurityPlan securityPlan, Map authMap) {
        return detectFeature(
                securityPlan,
                authMap,
                "Failed to perform multi-databases feature detection with the following servers: ",
                (boltConnection -> boltConnection.protocolVersion().compareTo(new BoltProtocolVersion(4, 0)) >= 0));
    }

    @Override
    public CompletionStage supportsSessionAuth(SecurityPlan securityPlan, Map authMap) {
        return detectFeature(
                securityPlan,
                authMap,
                "Failed to perform session auth feature detection with the following servers: ",
                (boltConnection -> new BoltProtocolVersion(5, 1).compareTo(boltConnection.protocolVersion()) <= 0));
    }

    private synchronized void shutdownUnusedProviders(Set addressesToRetain) {
        var iterator = addressToProvider.entrySet().iterator();
        while (iterator.hasNext()) {
            var entry = iterator.next();
            var address = entry.getKey();
            if (!addressesToRetain.contains(address) && getInUseCount(address) == 0) {
                entry.getValue().close();
                iterator.remove();
            }
        }
    }

    private CompletionStage detectFeature(
            SecurityPlan securityPlan,
            Map authMap,
            String baseErrorMessagePrefix,
            Function featureDetectionFunction) {
        Rediscovery rediscovery;
        synchronized (this) {
            rediscovery = this.rediscovery;
        }

        List addresses;
        try {
            addresses = rediscovery.resolve();
        } catch (Throwable error) {
            return CompletableFuture.failedFuture(error);
        }
        CompletableFuture result = CompletableFuture.completedFuture(null);
        Throwable baseError = new ServiceUnavailableException(baseErrorMessagePrefix + addresses);

        for (var address : addresses) {
            result = FutureUtil.onErrorContinue(result, baseError, completionError -> {
                // We fail fast on security errors
                var error = FutureUtil.completionExceptionCause(completionError);
                if (error instanceof SecurityException) {
                    return CompletableFuture.failedFuture(error);
                }
                return get(address)
                        .connect(
                                securityPlan,
                                null,
                                () -> CompletableFuture.completedStage(authMap),
                                AccessMode.WRITE,
                                Collections.emptySet(),
                                null,
                                null,
                                null,
                                (ignored) -> {})
                        .thenCompose(boltConnection -> {
                            var featureDetected = featureDetectionFunction.apply(boltConnection);
                            return boltConnection.close().thenApply(ignored -> featureDetected);
                        });
            });
        }
        return FutureUtil.onErrorContinue(result, baseError, completionError -> {
            // If we failed with security errors, then we rethrow the security error out, otherwise we throw the chained
            // errors.
            var error = FutureUtil.completionExceptionCause(completionError);
            if (error instanceof SecurityException) {
                return CompletableFuture.failedFuture(error);
            }
            return CompletableFuture.failedFuture(baseError);
        });
    }

    private CompletionStage acquire(
            SecurityPlan securityPlan,
            AccessMode mode,
            RoutingTable routingTable,
            Supplier>> authMapStageSupplier,
            DatabaseName database,
            Set bookmarks,
            String impersonatedUser,
            BoltProtocolVersion minVersion,
            NotificationConfig notificationConfig) {
        var result = new CompletableFuture();
        List attemptExceptions = new ArrayList<>();
        acquire(
                securityPlan,
                mode,
                routingTable,
                result,
                authMapStageSupplier,
                attemptExceptions,
                database,
                bookmarks,
                impersonatedUser,
                minVersion,
                notificationConfig);
        return result;
    }

    private void acquire(
            SecurityPlan securityPlan,
            AccessMode mode,
            RoutingTable routingTable,
            CompletableFuture result,
            Supplier>> authMapStageSupplier,
            List attemptErrors,
            DatabaseName database,
            Set bookmarks,
            String impersonatedUser,
            BoltProtocolVersion minVersion,
            NotificationConfig notificationConfig) {
        var addresses = getAddressesByMode(mode, routingTable);
        log.log(System.Logger.Level.DEBUG, "Addresses: " + addresses);
        var address = selectAddress(mode, addresses);
        log.log(System.Logger.Level.DEBUG, "Selected address: " + address);

        if (address == null) {
            var completionError = new SessionExpiredException(
                    format(CONNECTION_ACQUISITION_COMPLETION_EXCEPTION_MESSAGE, mode, routingTable));
            attemptErrors.forEach(completionError::addSuppressed);
            log.log(System.Logger.Level.ERROR, CONNECTION_ACQUISITION_COMPLETION_FAILURE_MESSAGE, completionError);
            result.completeExceptionally(completionError);
            return;
        }

        get(address)
                .connect(
                        securityPlan,
                        database,
                        authMapStageSupplier,
                        mode,
                        bookmarks,
                        impersonatedUser,
                        minVersion,
                        notificationConfig,
                        (ignored) -> {})
                .whenComplete((connection, completionError) -> {
                    var error = FutureUtil.completionExceptionCause(completionError);
                    if (error != null) {
                        if (error instanceof ServiceUnavailableException) {
                            var attemptMessage = format(CONNECTION_ACQUISITION_ATTEMPT_FAILURE_MESSAGE, address);
                            log.log(System.Logger.Level.WARNING, attemptMessage);
                            log.log(System.Logger.Level.DEBUG, attemptMessage, error);
                            attemptErrors.add(error);
                            routingTable.forget(address);
                            CompletableFuture.runAsync(() -> acquire(
                                    securityPlan,
                                    mode,
                                    routingTable,
                                    result,
                                    authMapStageSupplier,
                                    attemptErrors,
                                    database,
                                    bookmarks,
                                    impersonatedUser,
                                    minVersion,
                                    notificationConfig));
                        } else {
                            result.completeExceptionally(error);
                        }
                    } else {
                        incrementInUseCount(address);
                        result.complete(connection);
                    }
                });
    }

    private BoltServerAddress selectAddress(AccessMode mode, List addresses) {
        return switch (mode) {
            case READ -> loadBalancingStrategy.selectReader(addresses);
            case WRITE -> loadBalancingStrategy.selectWriter(addresses);
        };
    }

    private static List getAddressesByMode(AccessMode mode, RoutingTable routingTable) {
        return switch (mode) {
            case READ -> routingTable.readers();
            case WRITE -> routingTable.writers();
        };
    }

    private synchronized int getInUseCount(BoltServerAddress address) {
        return addressToInUseCount.getOrDefault(address, 0);
    }

    private synchronized void incrementInUseCount(BoltServerAddress address) {
        addressToInUseCount.merge(address, 1, Integer::sum);
    }

    synchronized void decrementInUseCount(BoltServerAddress address) {
        addressToInUseCount.compute(address, (ignored, value) -> {
            if (value == null) {
                return null;
            } else {
                value--;
                return value > 0 ? value : null;
            }
        });
    }

    @Override
    public CompletionStage close() {
        CompletableFuture closeFuture;
        synchronized (this) {
            if (this.closeFuture == null) {
                @SuppressWarnings({"rawtypes", "RedundantSuppression"})
                var futures = new CompletableFuture[addressToProvider.size()];
                var iterator = addressToProvider.values().iterator();
                var index = 0;
                while (iterator.hasNext()) {
                    futures[index++] = iterator.next().close().toCompletableFuture();
                    iterator.remove();
                }
                this.closeFuture = CompletableFuture.allOf(futures);
            }
            closeFuture = this.closeFuture;
        }
        return closeFuture;
    }

    private synchronized BoltConnectionProvider get(BoltServerAddress address) {
        var provider = addressToProvider.get(address);
        if (provider == null) {
            provider = boltConnectionProviderSupplier.get();
            provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
            addressToProvider.put(address, provider);
        }
        return provider;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy