com.aerospike.vector.client.internal.VectorChannelProvider Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of avs-client-java Show documentation
Show all versions of avs-client-java Show documentation
This project includes the Java client for Aerospike Vector Search for high-performance data interactions.
The newest version!
package com.aerospike.vector.client.internal;
import com.aerospike.vector.client.ClientTlsConfig;
import com.aerospike.vector.client.ConnectionConfig;
import com.aerospike.vector.client.proto.ServerEndpoint;
import com.aerospike.vector.client.proto.ServerEndpointList;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.RemovalListener;
import com.github.benmanes.caffeine.cache.Scheduler;
import io.grpc.ManagedChannel;
import io.grpc.NameResolver;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLException;
import java.io.Closeable;
import java.io.File;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
class VectorChannelProvider implements Closeable {
private static final Logger log = LoggerFactory.getLogger(VectorChannelProvider.class);
private final ConnectionConfig connectionConfig;
private final Class extends SocketChannel> grpcChannelType;
private final EventLoopGroup eventLoopGroup;
private final Cache endpointsCache = Caffeine.newBuilder()
.expireAfterAccess(Duration.ofMinutes(100))
.scheduler(Scheduler.systemScheduler())
.removalListener((RemovalListener) (key, value, cause) -> {
if (value != null) {
shutdownManagedChannel(value);
}
})
.build();
private final ThreadFactory eventLoopThreadFactory;
private final String applicationName;
public VectorChannelProvider(ConnectionConfig connectionConfig, String applicationName) {
this.connectionConfig = connectionConfig;
this.applicationName = applicationName;
String poolName = String.format("avs-%s-channel-elg-", applicationName);
eventLoopThreadFactory = new DefaultThreadFactory(poolName, true);
if (Epoll.isAvailable()) {
this.eventLoopGroup = new EpollEventLoopGroup(0, eventLoopThreadFactory);
this.grpcChannelType = EpollSocketChannel.class;
} else {
this.eventLoopGroup = new NioEventLoopGroup(0, eventLoopThreadFactory);
this.grpcChannelType = NioSocketChannel.class;
}
}
public ManagedChannel channelFor(ServerEndpointList serverAddresses) {
ManagedChannel channel = endpointsCache.getIfPresent(serverAddresses);
if (channel != null) {
return channel;
}
// Locking on the cache object to ensure thread safety
synchronized (endpointsCache) {
ManagedChannel newChannel = endpointsCache.getIfPresent(serverAddresses);
if (newChannel != null) {
return newChannel;
}
log.debug("{} Creating new channel for {}", applicationName ,serverAddresses);
newChannel = createGrpcChannel(serverAddresses);
endpointsCache.put(serverAddresses, newChannel);
return newChannel;
}
}
private ManagedChannel createGrpcChannel(ServerEndpointList serverAddresses) {
List tlsHosts = new ArrayList<>();
List plainTextHosts = new ArrayList<>();
for (ServerEndpoint endpoint : serverAddresses.getEndpointsList()) {
if (endpoint.getIsTls()) {
tlsHosts.add(endpoint);
} else {
plainTextHosts.add(endpoint);
}
}
List hosts = !tlsHosts.isEmpty() ? tlsHosts : plainTextHosts;
NettyChannelBuilder builder;
if (hosts.size() == 1) {
builder = NettyChannelBuilder.forAddress(hosts.getFirst().getAddress(), hosts.getFirst().getPort());
} else {
// Setup round-robin load balancing.
List addresses = new ArrayList<>();
for (ServerEndpoint endpoint : hosts) {
addresses.add(new InetSocketAddress(endpoint.getAddress(), endpoint.getPort()));
}
NameResolver.Factory nameResolverFactory = new MultiAddressNameResolverFactory(addresses);
builder = NettyChannelBuilder.forTarget(String.format("%s:%d", hosts.getFirst().getAddress(), hosts.getFirst().getPort()));
builder.nameResolverFactory(nameResolverFactory);
builder.defaultLoadBalancingPolicy("pick_first");
}
builder.eventLoopGroup(eventLoopGroup)
.channelType(grpcChannelType)
.perRpcBufferLimit(128 * 1024 * 1024L)
.negotiationType(NegotiationType.PLAINTEXT)
.maxInboundMessageSize(128 * 1024 * 1024)
.directExecutor()
.disableRetry()
.flowControlWindow(2 * 1024 * 1024)
.keepAliveWithoutCalls(true)
.keepAliveTime(25, TimeUnit.SECONDS)
.keepAliveTimeout(1, TimeUnit.MINUTES);
if (connectionConfig.getClientTlsConfig() != null) {
ClientTlsConfig config = connectionConfig.getClientTlsConfig();
SslContext sslContext;
try {
sslContext = buildSslContext(config.getRootCertificate(), config.getPrivateKey(), config.getCertificateChain());
} catch (SSLException e) {
throw new RuntimeException(e);
}
builder.sslContext(sslContext);
builder.negotiationType(NegotiationType.TLS);
} else {
builder.usePlaintext();
}
builder.withOption(ChannelOption.SO_SNDBUF, 1048576);
builder.withOption(ChannelOption.SO_RCVBUF, 1048576);
builder.withOption(ChannelOption.TCP_NODELAY, true);
if (connectionConfig.getConnectTimeout() != Integer.MAX_VALUE) {
builder.withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionConfig.getConnectTimeout());
}
builder.withOption(ChannelOption.WRITE_BUFFER_WATER_MARK,
new WriteBufferWaterMark(32 * 1024, 64 * 1024));
return builder.build();
}
private SslContext buildSslContext(String rootCertPath, String privateKeyPath, String certChainPath) throws SSLException {
try {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
if (rootCertPath != null) {
sslContextBuilder.trustManager(new File(rootCertPath));
}
if (privateKeyPath != null && certChainPath != null) {
sslContextBuilder.keyManager(new File(certChainPath), new File(privateKeyPath));
}
return sslContextBuilder.build();
} catch (Exception e) {
throw new SSLException("Failed to build SSL context", e);
}
}
@Override
public void close() {
endpointsCache.asMap().forEach((k, v) -> {
shutdownManagedChannel(v);
});
endpointsCache.cleanUp();
shutdownEventLoopGroup(eventLoopGroup);
}
private void shutdownManagedChannel(ManagedChannel managedChannel) {
// Close the gRPC managed-channel if not shut down already.
if (!managedChannel.isShutdown()) {
try {
managedChannel.shutdown();
if (!managedChannel.awaitTermination(1, TimeUnit.MINUTES)) {
log.warn("Timed out gracefully shutting down connection: {}", managedChannel);
}
} catch (Exception e) {
log.error("Unexpected exception while waiting for channel termination: {}", managedChannel, e);
}
}
// Forceful shut down if still not terminated.
if (!managedChannel.isTerminated()) {
try {
managedChannel.shutdownNow();
if (!managedChannel.awaitTermination(15, TimeUnit.SECONDS)) {
log.warn("Timed out forcefully shutting down connection: {}", managedChannel);
}
} catch (Exception e) {
log.error("Unexpected exception while waiting for channel termination: {}", managedChannel, e);
}
}
}
private void shutdownEventLoopGroup(EventLoopGroup eventLoopGroup) {
// Close the event loop group if not shut down already.
if (!eventLoopGroup.isShuttingDown()) {
eventLoopGroup.shutdownGracefully();
}
}
}