All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aerospike.vector.client.internal.VectorChannelProvider Maven / Gradle / Ivy

Go to download

This project includes the Java client for Aerospike Vector Search for high-performance data interactions.

The newest version!
package com.aerospike.vector.client.internal;

import com.aerospike.vector.client.ClientTlsConfig;
import com.aerospike.vector.client.ConnectionConfig;
import com.aerospike.vector.client.proto.ServerEndpoint;
import com.aerospike.vector.client.proto.ServerEndpointList;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.RemovalListener;
import com.github.benmanes.caffeine.cache.Scheduler;
import io.grpc.ManagedChannel;
import io.grpc.NameResolver;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.io.Closeable;
import java.io.File;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Creates and caches gRPC {@link ManagedChannel}s keyed by the list of server endpoints they
 * connect to.
 *
 * <p>All channels share one Netty event loop group (epoll when available, NIO otherwise) whose
 * daemon threads are named after the application. A cached channel that has not been accessed for
 * 100 minutes is evicted and shut down. {@link #close()} shuts down every cached channel and then
 * the event loop group.
 */
class VectorChannelProvider implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(VectorChannelProvider.class);

    private final ConnectionConfig connectionConfig;
    private final Class<? extends SocketChannel> grpcChannelType;
    private final EventLoopGroup eventLoopGroup;

    /**
     * Channels keyed by the endpoint list they were built for. Entries idle for 100 minutes are
     * evicted and the removal listener shuts the evicted channel down. The system scheduler lets
     * Caffeine remove expired entries promptly, where the JVM supports it, even when the cache is
     * otherwise idle.
     */
    private final Cache<ServerEndpointList, ManagedChannel> endpointsCache =
            Caffeine.newBuilder()
                    .expireAfterAccess(Duration.ofMinutes(100))
                    .scheduler(Scheduler.systemScheduler())
                    .removalListener(
                            (RemovalListener<ServerEndpointList, ManagedChannel>)
                                    (key, value, cause) -> {
                                        if (value != null) {
                                            shutdownManagedChannel(value);
                                        }
                                    })
                    .build();

    private final String applicationName;

    /**
     * Creates a provider and its shared event loop group.
     *
     * @param connectionConfig settings (TLS material, connect timeout) applied to every channel
     *     this provider builds
     * @param applicationName label used in event-loop thread names and in log messages
     */
    public VectorChannelProvider(ConnectionConfig connectionConfig, String applicationName) {
        this.connectionConfig = connectionConfig;
        this.applicationName = applicationName;
        String poolName = String.format("avs-%s-channel-elg-", applicationName);
        // Daemon threads: the event loop never keeps the JVM alive on its own.
        ThreadFactory eventLoopThreadFactory = new DefaultThreadFactory(poolName, true);
        // A thread count of 0 lets Netty pick its default size.
        if (Epoll.isAvailable()) {
            this.eventLoopGroup = new EpollEventLoopGroup(0, eventLoopThreadFactory);
            this.grpcChannelType = EpollSocketChannel.class;
        } else {
            this.eventLoopGroup = new NioEventLoopGroup(0, eventLoopThreadFactory);
            this.grpcChannelType = NioSocketChannel.class;
        }
    }

    /**
     * Returns the channel for {@code serverAddresses}, creating and caching one on first use.
     *
     * <p>Creation goes through {@link Cache#get}, which runs the mapping function atomically per
     * key. Concurrent callers asking for the same endpoints therefore share one channel.
     *
     * @param serverAddresses the endpoints to connect to; also the cache key
     * @return the cached or newly created channel
     * @throws IllegalArgumentException if {@code serverAddresses} contains no endpoints
     */
    public ManagedChannel channelFor(ServerEndpointList serverAddresses) {
        return endpointsCache.get(
                serverAddresses,
                addresses -> {
                    log.debug("{} Creating new channel for {}.", applicationName, addresses);
                    return createGrpcChannel(addresses);
                });
    }

    /**
     * Builds a new channel to the given endpoints.
     *
     * <p>If any endpoint is flagged TLS, only the TLS endpoints are used; otherwise the plaintext
     * ones are. A single endpoint is dialled directly. Several endpoints are handed to a
     * {@code MultiAddressNameResolverFactory} with the {@code "pick_first"} policy.
     *
     * @throws IllegalArgumentException if {@code serverAddresses} has no endpoints
     * @throws UncheckedIOException if the configured TLS material cannot be loaded
     */
    private ManagedChannel createGrpcChannel(ServerEndpointList serverAddresses) {
        List<ServerEndpoint> tlsHosts = new ArrayList<>();
        List<ServerEndpoint> plainTextHosts = new ArrayList<>();
        for (ServerEndpoint endpoint : serverAddresses.getEndpointsList()) {
            if (endpoint.getIsTls()) {
                tlsHosts.add(endpoint);
            } else {
                plainTextHosts.add(endpoint);
            }
        }

        List<ServerEndpoint> hosts = !tlsHosts.isEmpty() ? tlsHosts : plainTextHosts;
        if (hosts.isEmpty()) {
            // Fail with a clear message instead of an IndexOutOfBoundsException on hosts.get(0).
            throw new IllegalArgumentException(
                    "No server endpoints to connect to: " + serverAddresses);
        }

        ServerEndpoint first = hosts.get(0);
        NettyChannelBuilder builder;
        if (hosts.size() == 1) {
            builder = NettyChannelBuilder.forAddress(first.getAddress(), first.getPort());
        } else {
            // Hand every endpoint to the resolver. With "pick_first", gRPC connects to the first
            // reachable address in list order (this is not round-robin balancing).
            List<InetSocketAddress> addresses = new ArrayList<>();
            for (ServerEndpoint endpoint : hosts) {
                addresses.add(new InetSocketAddress(endpoint.getAddress(), endpoint.getPort()));
            }
            NameResolver.Factory nameResolverFactory =
                    new MultiAddressNameResolverFactory(addresses);
            builder =
                    NettyChannelBuilder.forTarget(
                            String.format("%s:%d", first.getAddress(), first.getPort()));
            builder.nameResolverFactory(nameResolverFactory);
            builder.defaultLoadBalancingPolicy("pick_first");
        }

        builder.eventLoopGroup(eventLoopGroup)
                .channelType(grpcChannelType)
                .perRpcBufferLimit(128 * 1024 * 1024L)
                .maxInboundMessageSize(128 * 1024 * 1024)
                .directExecutor()
                .disableRetry()
                .flowControlWindow(2 * 1024 * 1024)
                .keepAliveWithoutCalls(
                        false) // Set it to true to address resource exhaustion during testing
                .keepAliveTime(25, TimeUnit.SECONDS)
                .keepAliveTimeout(1, TimeUnit.MINUTES);

        // NOTE(review): the transport security is chosen by the client TLS config, not by the
        // endpoints' isTls flag selected above. Confirm that a mismatch between the two is
        // intended to be possible.
        ClientTlsConfig tlsConfig = connectionConfig.getClientTlsConfig();
        if (tlsConfig != null) {
            SslContext sslContext;
            try {
                sslContext =
                        buildSslContext(
                                tlsConfig.getRootCertificate(),
                                tlsConfig.getPrivateKey(),
                                tlsConfig.getCertificateChain());
            } catch (SSLException e) {
                // SSLException is an IOException; surface it unchecked and keep the cause.
                throw new UncheckedIOException(e);
            }
            builder.sslContext(sslContext);
            builder.negotiationType(NegotiationType.TLS);
        } else {
            builder.usePlaintext();
        }

        builder.withOption(ChannelOption.SO_SNDBUF, 1048576);
        builder.withOption(ChannelOption.SO_RCVBUF, 1048576);
        builder.withOption(ChannelOption.TCP_NODELAY, true);

        // Integer.MAX_VALUE means "not configured": leave the transport's default in place.
        if (connectionConfig.getConnectTimeout() != Integer.MAX_VALUE) {
            builder.withOption(
                    ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionConfig.getConnectTimeout());
        }

        builder.withOption(
                ChannelOption.WRITE_BUFFER_WATER_MARK,
                new WriteBufferWaterMark(32 * 1024, 64 * 1024));

        return builder.build();
    }

    /**
     * Builds a client {@link SslContext}.
     *
     * @param rootCertPath trusted root certificate file, or {@code null} for the default trust
     *     store
     * @param privateKeyPath client private key file; used only together with {@code certChainPath}
     * @param certChainPath client certificate chain file; used only together with {@code
     *     privateKeyPath}
     * @throws SSLException wrapping any failure while loading the files or building the context
     */
    private SslContext buildSslContext(
            String rootCertPath, String privateKeyPath, String certChainPath) throws SSLException {
        try {
            SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
            if (rootCertPath != null) {
                sslContextBuilder.trustManager(new File(rootCertPath));
            }
            // A client identity is presented only when both the key and its chain are supplied.
            if (privateKeyPath != null && certChainPath != null) {
                sslContextBuilder.keyManager(new File(certChainPath), new File(privateKeyPath));
            }
            return sslContextBuilder.build();
        } catch (Exception e) {
            throw new SSLException("Failed to build SSL context.", e);
        }
    }

    /**
     * Shuts down every cached channel, blocking until each terminates or times out. Then starts
     * (without awaiting) a graceful shutdown of the shared event loop group.
     */
    @Override
    public void close() {
        endpointsCache.asMap().values().forEach(this::shutdownManagedChannel);
        endpointsCache.cleanUp();
        shutdownEventLoopGroup(eventLoopGroup);
    }

    /**
     * Shuts {@code managedChannel} down, first gracefully and then forcefully.
     *
     * <p>A graceful {@code shutdown()} is given up to one minute. If the channel has still not
     * terminated, {@code shutdownNow()} is issued and awaited for up to 15 seconds. Failures are
     * logged, never thrown. If the calling thread is interrupted while waiting, its interrupt status
     * is restored and the channel is force-closed without any further waiting.
     */
    private void shutdownManagedChannel(ManagedChannel managedChannel) {
        // Graceful phase: skipped if shutdown was already requested.
        if (!managedChannel.isShutdown()) {
            try {
                managedChannel.shutdown();
                if (!managedChannel.awaitTermination(1, TimeUnit.MINUTES)) {
                    log.warn("Timed out gracefully shutting down connection: {}.", managedChannel);
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt; the forceful phase below will not wait again.
                Thread.currentThread().interrupt();
                log.warn(
                        "Interrupted while gracefully shutting down connection: {}.",
                        managedChannel);
            } catch (Exception e) {
                log.error(
                        "Unexpected exception while waiting for channel termination: {}.",
                        managedChannel,
                        e);
            }
        }

        // Forceful phase: only needed if the channel has not terminated yet.
        if (!managedChannel.isTerminated()) {
            try {
                managedChannel.shutdownNow();
                // When interrupted, awaitTermination would throw at once; the request suffices.
                if (!Thread.currentThread().isInterrupted()
                        && !managedChannel.awaitTermination(15, TimeUnit.SECONDS)) {
                    log.warn("Timed out forcefully shutting down connection: {}.", managedChannel);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.warn(
                        "Interrupted while forcefully shutting down connection: {}.",
                        managedChannel);
            } catch (Exception e) {
                log.error(
                        "Unexpected exception while waiting for channel termination: {}.",
                        managedChannel,
                        e);
            }
        }
    }

    /**
     * Requests a graceful shutdown of {@code eventLoopGroup} unless one is already in progress.
     * Does not wait for termination.
     */
    private void shutdownEventLoopGroup(EventLoopGroup eventLoopGroup) {
        if (!eventLoopGroup.isShuttingDown()) {
            eventLoopGroup.shutdownGracefully();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy