All downloads are free. Search and download functionality uses the official Maven repository.

com.aerospike.vector.client.internal.ChannelProvider Maven / Gradle / Ivy

package com.aerospike.vector.client.internal;

import com.aerospike.vector.client.*;
import com.google.protobuf.Empty;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

/**
 * A channel provider for vector db
 */
public class ChannelProvider implements AutoCloseable {
    private record ChannelAndEndpoints(Channel channel, List endpoints) {
    }
    private static final Logger log = LoggerFactory.getLogger(ChannelProvider.class);

    private final Map nodeChannels = new ConcurrentHashMap<>();
    private final List seedChannels = new ArrayList<>();
    private volatile boolean closed = false;
    private long clusterId = 0;
    private final List seeds;
    private final String listenerName;
    private final ScheduledExecutorService tendExecutorService = Executors.newSingleThreadScheduledExecutor();
    private final boolean isLoadBalancer;

    /**
     * Constructor
     * @param seeds list of {@link HostPort}
     * @param listenerName name of the listener
     * @param isLoadBalancer If hostport represents a load balancer
     */
    public ChannelProvider(List seeds, String listenerName, boolean isLoadBalancer) {
        if (seeds == null || seeds.isEmpty()) {
            throw new IllegalArgumentException("At least one seed host needed");
        }
        this.seeds = seeds;
        this.listenerName = listenerName;
        this.isLoadBalancer = isLoadBalancer;
        initializeSeedChannels();
        if(!isLoadBalancer) {
            tend();
        }

    }

    private void initializeSeedChannels() {
        for (HostPort seed : seeds) {
            seedChannels.add(createChannelFromHostPort(seed));
        }
    }

    @Override
    public void close() {
        closed = true;
        for (Channel seedChannel : seedChannels) {
            ((ManagedChannel) seedChannel).shutdown();
        }

        for (ChannelAndEndpoints channelEndpoints : nodeChannels.values()) {
           ManagedChannel channel = (ManagedChannel) channelEndpoints.channel();
           if(!channel.isShutdown()){
               channel.shutdown();
           }
        }

        tendExecutorService.shutdownNow();
    }

    /**
     * Provides a channel
     * @return {@link Channel}
     */
    public Channel getChannel() {
        if (!isLoadBalancer) {
            List discoveredChannels = new ArrayList<>(nodeChannels.values());
            if (discoveredChannels.isEmpty()) {
                return seedChannels.get(0);
            }

            // Return a random channel.
            ChannelAndEndpoints randomChannelEndpoints =
                    discoveredChannels.get(new Random().nextInt(discoveredChannels.size()));
            Channel channel = randomChannelEndpoints.channel();
            if (channel != null) {
                return channel;
            }
        }

        return seedChannels.get(0);
    }

    private void tend() {
        tendExecutorService.scheduleWithFixedDelay(this::tendCluster, 0, 1, TimeUnit.SECONDS);
    }

    private void tendCluster() {

        Map tempEndpoints = new HashMap<>();

        if (closed) {
            return;
        }

        try {
            boolean updateEndpoints = false;
            List channels = new ArrayList<>(seedChannels);
            nodeChannels.values().forEach(channelEndpoints -> channels.add(channelEndpoints.channel()));

            for (Channel seedChannel : channels) {
                try {
                    ClusterInfoGrpc.ClusterInfoBlockingStub stub = ClusterInfoGrpc.newBlockingStub(seedChannel);
                    long newClusterId = stub.getClusterId(Empty.getDefaultInstance()).getId();

                    if (newClusterId == clusterId) {
                        continue;
                    }

                    updateEndpoints = true;
                    clusterId = newClusterId;
                    Map endpoints = stub.getClusterEndpoints(
                            ClusterNodeEndpointsRequest.newBuilder()
                                    .setListenerName(listenerName)
                                    .build()
                    ).getEndpoints();

                    if (endpoints.size() > tempEndpoints.size()) {
                        tempEndpoints = endpoints;
                    }
                } catch (Exception e) {
                    log.error("Error in tend thread processing channel: {}, exception: {}", seedChannel, e);
                    e.printStackTrace();
                }
            }

            if (updateEndpoints) {
                for (Map.Entry entry : tempEndpoints.entrySet()) {
                    long node = entry.getKey();
                    ServerEndpointList newEndpoints = entry.getValue();
                    ChannelAndEndpoints channelEndpoints = nodeChannels.get(node);
                    boolean addNewChannel = true;

                    if (channelEndpoints != null) {
                        if (channelEndpoints.endpoints().equals(newEndpoints)) {
                            addNewChannel = false;
                        } else {
                            ((ManagedChannel) channelEndpoints.channel()).shutdown();
                            addNewChannel = true;
                        }
                    }

                    if (addNewChannel) {
                        Channel newChannel = createChannelFromServerEndpointList(newEndpoints);
                        nodeChannels.put(node, new ChannelAndEndpoints(newChannel, newEndpoints.getEndpointsList()));
                    }
                }

                for (long node : new ArrayList<>(nodeChannels.keySet())) {
                    if (!tempEndpoints.containsKey(node)) {
                        ChannelAndEndpoints channelEndpoints = nodeChannels.get(node);
                        if (channelEndpoints != null) {
                            ((ManagedChannel) channelEndpoints.channel()).shutdown();
                        }
                        nodeChannels.remove(node);
                    }
                }
            }
        } catch (Exception e) {
            // Log this exception
            log.error("Exception in tend thread", e);
            e.printStackTrace();
        }

        if (!closed) {
            tendExecutorService.schedule(this::tend, 1, TimeUnit.SECONDS);
        }
    }


    private Channel createChannelFromHostPort(HostPort hostPort) {
        return createChannel(hostPort.address(), hostPort.port(), hostPort.isTls());
    }

    private Channel createChannelFromServerEndpointList(ServerEndpointList endpoints) {
        for (ServerEndpoint endpoint : endpoints.getEndpointsList()) {
            if (endpoint.getAddress().contains(":")) {
                // Ignoring IPv6 for now, needs fix
                continue;
            }
            try {
                return createChannel(endpoint.getAddress(), endpoint.getPort(), endpoint.getIsTls());
            } catch (Exception e) {
                // Log the exception and continue trying with the next endpoint
                e.printStackTrace();
            }
        }
        throw new RuntimeException("Failed to create channel from server endpoint list");
    }


    private Channel createChannel(String host, int port, boolean isTls) {
        // Remove any characters from the host string after a '%' character
        //NOT supported tls at the moment!
        int percentIndex = host.indexOf('%');
        if (percentIndex != -1) {
            host = host.substring(0, percentIndex);
        }

        // Building a gRPC channel
        return ManagedChannelBuilder.forAddress(host, port)
                .usePlaintext()
                .build();
    }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy