
com.aerospike.vector.client.internal.ChannelProvider Maven / Gradle / Ivy
package com.aerospike.vector.client.internal;
import com.aerospike.vector.client.*;
import com.google.protobuf.Empty;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
/**
 * A channel provider for vector db.
 *
 * <p>One gRPC channel is created per seed host. Unless the seeds represent a load balancer, a
 * single background thread "tends" the cluster once per second: it asks every known channel for
 * the cluster id and, when the id differs from the last one seen, fetches the node endpoints and
 * opens/closes per-node channels to match.
 *
 * <p>{@link #getChannel()} may be called from any thread. Instances must be {@link #close()
 * closed} to release the channels and the tend thread.
 */
public class ChannelProvider implements AutoCloseable {

    /** A node channel together with the endpoint list it was created from. */
    private record ChannelAndEndpoints(ManagedChannel channel, List<ServerEndpoint> endpoints) {
    }

    private static final Logger log = LoggerFactory.getLogger(ChannelProvider.class);

    /** Channels to discovered cluster nodes, keyed by node id. Written only by the tend thread. */
    private final Map<Long, ChannelAndEndpoints> nodeChannels = new ConcurrentHashMap<>();
    /** Channels to the seed hosts; populated once in the constructor and never modified after. */
    private final List<ManagedChannel> seedChannels = new ArrayList<>();
    private volatile boolean closed = false;
    /** Last cluster id for which endpoints were fetched. Accessed only by the tend thread. */
    private long clusterId = 0;
    private final List<HostPort> seeds;
    private final String listenerName;
    private final ScheduledExecutorService tendExecutorService = Executors.newSingleThreadScheduledExecutor();
    private final boolean isLoadBalancer;

    /**
     * Constructor
     *
     * @param seeds list of {@link HostPort}
     * @param listenerName name of the listener
     * @param isLoadBalancer If hostport represents a load balancer
     * @throws IllegalArgumentException if {@code seeds} is null or empty
     */
    public ChannelProvider(List<HostPort> seeds, String listenerName, boolean isLoadBalancer) {
        if (seeds == null || seeds.isEmpty()) {
            throw new IllegalArgumentException("At least one seed host needed");
        }
        this.seeds = List.copyOf(seeds);
        this.listenerName = listenerName;
        this.isLoadBalancer = isLoadBalancer;
        try {
            initializeSeedChannels();
        } catch (RuntimeException e) {
            // Do not leak the seed channels that were already opened.
            close();
            throw e;
        }
        if (!isLoadBalancer) {
            tend();
        }
    }

    /** Opens one channel per seed host. */
    private void initializeSeedChannels() {
        for (HostPort seed : seeds) {
            seedChannels.add(createChannelFromHostPort(seed));
        }
    }

    /**
     * Stops the tend thread and shuts down all seed and node channels. Safe to call more than
     * once.
     */
    @Override
    public void close() {
        closed = true;
        // Stop tending first so that no new node channels are registered while we shut down.
        tendExecutorService.shutdownNow();
        for (ManagedChannel seedChannel : seedChannels) {
            seedChannel.shutdown();
        }
        for (ChannelAndEndpoints channelEndpoints : nodeChannels.values()) {
            ManagedChannel channel = channelEndpoints.channel();
            if (!channel.isShutdown()) {
                channel.shutdown();
            }
        }
    }

    /**
     * Provides a channel.
     *
     * <p>When not behind a load balancer and node channels have been discovered, a randomly
     * chosen node channel is returned; otherwise the first seed channel is returned.
     *
     * @return {@link Channel}
     */
    public Channel getChannel() {
        if (!isLoadBalancer) {
            // Snapshot, so the size and the index refer to the same set of channels.
            List<ChannelAndEndpoints> discoveredChannels = new ArrayList<>(nodeChannels.values());
            if (!discoveredChannels.isEmpty()) {
                int index = ThreadLocalRandom.current().nextInt(discoveredChannels.size());
                Channel channel = discoveredChannels.get(index).channel();
                if (channel != null) {
                    return channel;
                }
            }
        }
        return seedChannels.get(0);
    }

    /** Starts the periodic tend task: first run immediately, then 1 second after each run ends. */
    private void tend() {
        tendExecutorService.scheduleWithFixedDelay(this::tendCluster, 0, 1, TimeUnit.SECONDS);
    }

    /**
     * One tend pass: polls every known channel for the cluster id and, if any channel reports an
     * id different from the last one seen, applies the largest endpoint map returned by those
     * channels.
     *
     * <p>Never throws an exception: an exception escaping a task scheduled with
     * {@code scheduleWithFixedDelay} would cancel all further runs. The task must not re-schedule
     * itself; the fixed-delay schedule already repeats it.
     */
    private void tendCluster() {
        if (closed) {
            return;
        }
        try {
            // Compare every channel against the id known when this pass started, so that one
            // channel reporting a change does not hide the answers of the remaining channels.
            final long knownClusterId = clusterId;
            long latestClusterId = knownClusterId;
            Map<Long, ServerEndpointList> latestEndpoints = null;

            List<Channel> channels = new ArrayList<>(seedChannels);
            nodeChannels.values().forEach(channelEndpoints -> channels.add(channelEndpoints.channel()));

            for (Channel channel : channels) {
                try {
                    ClusterInfoGrpc.ClusterInfoBlockingStub stub = ClusterInfoGrpc.newBlockingStub(channel);
                    long newClusterId = stub.getClusterId(Empty.getDefaultInstance()).getId();
                    if (newClusterId == knownClusterId) {
                        continue;
                    }
                    Map<Long, ServerEndpointList> endpoints = stub.getClusterEndpoints(
                            ClusterNodeEndpointsRequest.newBuilder()
                                    .setListenerName(listenerName)
                                    .build()
                    ).getEndpoints();
                    // Only a successful fetch counts as an update; keep the largest view.
                    if (latestEndpoints == null || endpoints.size() > latestEndpoints.size()) {
                        latestEndpoints = endpoints;
                        latestClusterId = newClusterId;
                    }
                } catch (Exception e) {
                    log.error("Error in tend thread processing channel: {}", channel, e);
                }
            }

            if (latestEndpoints != null) {
                clusterId = latestClusterId;
                applyEndpoints(latestEndpoints);
            }
        } catch (Exception e) {
            log.error("Exception in tend thread", e);
        }
    }

    /**
     * Makes {@code nodeChannels} match the given node-to-endpoints map: opens channels for new or
     * changed nodes and shuts down channels of nodes that changed or disappeared.
     */
    private void applyEndpoints(Map<Long, ServerEndpointList> latestEndpoints) {
        for (Map.Entry<Long, ServerEndpointList> entry : latestEndpoints.entrySet()) {
            Long node = entry.getKey();
            List<ServerEndpoint> newEndpoints = entry.getValue().getEndpointsList();

            ChannelAndEndpoints current = nodeChannels.get(node);
            if (current != null && current.endpoints().equals(newEndpoints)) {
                // Endpoints unchanged, keep the existing channel.
                continue;
            }

            // Create the replacement before retiring the old channel, so a failure never leaves
            // a shut-down channel registered in the map.
            ChannelAndEndpoints previous;
            try {
                ManagedChannel newChannel = createChannelFromServerEndpointList(entry.getValue());
                previous = nodeChannels.put(node, new ChannelAndEndpoints(newChannel, newEndpoints));
                if (closed) {
                    // close() may already have iterated over the map; do not leak this channel.
                    newChannel.shutdown();
                }
            } catch (RuntimeException e) {
                log.error("Failed to create channel for node {}", node, e);
                // The old endpoints are no longer advertised; drop the node.
                previous = nodeChannels.remove(node);
            }
            if (previous != null) {
                previous.channel().shutdown();
            }
        }

        // Retire channels of nodes that are no longer part of the cluster.
        for (Long node : new ArrayList<>(nodeChannels.keySet())) {
            if (!latestEndpoints.containsKey(node)) {
                ChannelAndEndpoints removed = nodeChannels.remove(node);
                if (removed != null) {
                    removed.channel().shutdown();
                }
            }
        }
    }

    /** Creates a channel to the given seed host. */
    private ManagedChannel createChannelFromHostPort(HostPort hostPort) {
        return createChannel(hostPort.address(), hostPort.port(), hostPort.isTls());
    }

    /**
     * Creates a channel to the first usable endpoint in the list.
     *
     * @throws RuntimeException if no channel could be created from any endpoint
     */
    private ManagedChannel createChannelFromServerEndpointList(ServerEndpointList endpoints) {
        Exception lastFailure = null;
        for (ServerEndpoint endpoint : endpoints.getEndpointsList()) {
            if (endpoint.getAddress().contains(":")) {
                // Ignoring IPv6 for now, needs fix
                continue;
            }
            try {
                return createChannel(endpoint.getAddress(), endpoint.getPort(), endpoint.getIsTls());
            } catch (Exception e) {
                // Log the exception and continue trying with the next endpoint
                log.warn("Failed to create channel to {}:{}", endpoint.getAddress(), endpoint.getPort(), e);
                lastFailure = e;
            }
        }
        throw new RuntimeException("Failed to create channel from server endpoint list", lastFailure);
    }

    /**
     * Builds a plaintext gRPC channel to {@code host:port}.
     *
     * @param host host name or address; anything from a '%' character onwards is dropped
     * @param port port to connect to
     * @param isTls currently ignored, TLS is NOT supported at the moment
     */
    private ManagedChannel createChannel(String host, int port, boolean isTls) {
        // Remove any characters from the host string after a '%' character
        int percentIndex = host.indexOf('%');
        if (percentIndex != -1) {
            host = host.substring(0, percentIndex);
        }
        // Building a gRPC channel
        return ManagedChannelBuilder.forAddress(host, port)
                .usePlaintext()
                .build();
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy