All downloads are free. Search and download functionality uses the official Maven repository.

com.aerospike.vector.client.internal.ClusterTenderer Maven / Gradle / Ivy

Go to download

This project includes the Java client for Aerospike Vector Search for high-performance data interactions.

The newest version!
package com.aerospike.vector.client.internal;

import com.aerospike.vector.client.ConnectionConfig;
import com.aerospike.vector.client.HostPort;
import com.aerospike.vector.client.auth.PasswordCredentials;
import com.aerospike.vector.client.internal.auth.AuthTokenManager;
import com.aerospike.vector.client.proto.*;
import com.google.protobuf.Empty;
import io.grpc.Channel;
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.stub.AbstractStub;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;


/**
 * Internal class for providing gRPC tls/non-tls channels
 */

public class ClusterTenderer implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(ClusterTenderer.class);
    private final ConnectionConfig connectionConfig;
    private final List seedChannels = new ArrayList<>();
    private final ScheduledExecutorService tendExecutorService = Executors.newSingleThreadScheduledExecutor();

    private AuthTokenManager authTokenManager;
    private Exception lastTendException = null;
    private final VectorChannelProvider vectorChannelProvider;
    private final AtomicReference defaultClusterid = new AtomicReference<>(ClusterId.getDefaultInstance());
    private final AtomicReference> nodeEndpoints = new  AtomicReference<>(new HashMap<>());
    private final AtomicBoolean closed = new AtomicBoolean();
    private final Random random = new Random();
    ExecutorService executor = Executors.newSingleThreadExecutor();

    public final String identifier;

    public ClusterTenderer(ConnectionConfig connectionConfig, String identifier) {
        this.identifier = identifier;
        this.connectionConfig = connectionConfig;
        vectorChannelProvider = new VectorChannelProvider(connectionConfig, identifier);

        if(!connectionConfig.isLoadBalancer()) {
            tendExecutorService.scheduleWithFixedDelay(this::refreshClusterState, 0, 1, TimeUnit.SECONDS);
        }

        if (connectionConfig.getCredentials() instanceof PasswordCredentials) {
            this.authTokenManager = new AuthTokenManager((PasswordCredentials) connectionConfig.getCredentials(), this);
        }

        long start = System.currentTimeMillis();
        waitTillReady();
        log.info("{}; waitTillReady took: {} seconds", identifier, (System.currentTimeMillis() - start)/60.0 );

    }

    private ManagedChannel getTendChannel() {
        if (!connectionConfig.isLoadBalancer()) {
            ManagedChannel channel = getChannelFromEndpoints();
            if (channel != null) {
                return channel;
            }
        }
        HostPort seed = connectionConfig.getSeeds().get(random.nextInt(connectionConfig.getSeeds().size()));
        return vectorChannelProvider.channelFor(
                ServerEndpointList.newBuilder().addEndpoints(
                        ServerEndpoint.newBuilder().setAddress(seed.address())
                                .setPort(seed.port())
                                .setIsTls(connectionConfig.getClientTlsConfig() != null)
                ).build()
        );
    }

    private ManagedChannel getChannelFromEndpoints() {
        Map endpoints = nodeEndpoints.get();
        if (endpoints.isEmpty()) {
            return null;
        }

        List valuesList = new ArrayList<>(endpoints.values());
        ServerEndpointList randomEndpoints = valuesList.get(random.nextInt(valuesList.size()));

        if (randomEndpoints.getEndpointsList().isEmpty()) {
            return null;
        }

        return vectorChannelProvider.channelFor(randomEndpoints);
    }

    private void refreshClusterState() {
        try {
            HashMap tempEndpoints = new HashMap<>();
            boolean updateEndpoints = false;

            for (HostPort seed : connectionConfig.getSeeds()) {
                ManagedChannel tendChannel = vectorChannelProvider.channelFor(
                        ServerEndpointList.newBuilder().addEndpoints(
                                ServerEndpoint.newBuilder().setAddress(seed.address())
                                        .setPort(seed.port())
                                        .setIsTls(connectionConfig.getClientTlsConfig() != null)
                        ).build()
                );

                ClusterInfoServiceGrpc.ClusterInfoServiceBlockingStub clusterInfoStub = clusterInfoBlockingStub(tendChannel);
                ClusterId newClusterId = clusterInfoStub.getClusterId(Empty.getDefaultInstance());

                if (defaultClusterid.get().equals(newClusterId)) {
                    continue;  // Skip to the next iteration in the loop, equivalent to return@forEach in Kotlin
                }

                updateEndpoints = true;
                defaultClusterid.set(newClusterId);

                ClusterNodeEndpointsRequest.Builder requestBuilder = ClusterNodeEndpointsRequest.newBuilder();
                if (connectionConfig.getListenerName() != null) {
                    requestBuilder.setListenerName(connectionConfig.getListenerName());
                }

                ClusterNodeEndpoints endpoints = clusterInfoStub.getClusterEndpoints(requestBuilder.build());
                if (endpoints.getEndpointsMap().size() > tempEndpoints.size()) {
                    tempEndpoints = new HashMap<>(endpoints.getEndpointsMap());
                }
            }

            if (updateEndpoints) {
                nodeEndpoints.set(tempEndpoints);
            }
        } catch (Exception e) {
            lastTendException = e;
            log.debug("Error getting node endpoints.", e);
        }
    }

    public ManagedChannel getChannel() {
        if (connectionConfig.isLoadBalancer()) {
            return getTendChannel();
        }

        long t1 = System.currentTimeMillis();
        ManagedChannel channel = getChannelFromEndpoints();
        while (channel !=null &&  (System.currentTimeMillis() - t1) < connectionConfig.getConnectTimeout() ) {
            try {
                TimeUnit.MILLISECONDS.sleep(100);
                channel = getChannelFromEndpoints();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }

        //TBD need to investigate why channel comes as null instead of
//        log.info("channel was {}", channel);
        return channel != null ? channel : getTendChannel();
    }


    @Override
    public void close() {
        closed.set(true);
        for (Channel seedChannel : seedChannels) {
            ((ManagedChannel) seedChannel).shutdown();
        }
        authTokenManager.close();
        tendExecutorService.shutdownNow();
        vectorChannelProvider.close();
        executor.close();
    }

    private void waitTillReady() {

        log.info("{}: waiting for auth-manager to get ready", identifier);
        long timeout = connectionConfig.getConnectTimeout();
        long deadline = System.currentTimeMillis() + timeout;
        boolean notReady = !isReady();
        boolean notPastWaitTime = System.currentTimeMillis() < deadline;
        while (notReady && notPastWaitTime) {
            try {
                log.debug("{}; notReady:{}, notPastWaitTime:{}" , identifier, notReady, notPastWaitTime);
                Thread.sleep(100);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while waiting", e);
            }finally {

                notReady = !isReady();
                notPastWaitTime = System.currentTimeMillis() < deadline;
            }
        }

        if (!isReady()) {
            Exception augmentedException = new TimeoutException("Connect timed out after " + timeout + " ms");
            if (lastTendException != null) {
                augmentedException.initCause(lastTendException);
            }

            if (connectionConfig.isFailIfNotConnected()) {
                throw new RuntimeException(augmentedException);
            } else {
                log.warn("Client not connected in {}", identifier, augmentedException);
            }
        }
    }

    private boolean isReady() {
        return  (connectionConfig.isLoadBalancer() || !nodeEndpoints.get().isEmpty())
                && (authTokenManager == null || authTokenManager.getTokenStatus().isValid());
    }

    public TransactServiceGrpc.TransactServiceBlockingStub getTransactBlockingStub() {
        TransactServiceGrpc.TransactServiceBlockingStub stub = TransactServiceGrpc.newBlockingStub(getChannel());
        return addCallOptions(stub);
    }

    public AuthServiceGrpc.AuthServiceStub getAuthStub(ManagedChannel channel) {
        AuthServiceGrpc.AuthServiceStub stub = AuthServiceGrpc.newStub(channel);
        return addCallOptions(stub);
    }

    public TransactServiceGrpc.TransactServiceStub getTransactNonBlockingStub() {
        TransactServiceGrpc.TransactServiceStub stub = TransactServiceGrpc.newStub(getChannel());
        return addCallOptions(stub);
    }

    public IndexServiceGrpc.IndexServiceBlockingStub getIndexServiceBlockingStub() {
        IndexServiceGrpc.IndexServiceBlockingStub stub = IndexServiceGrpc.newBlockingStub(getChannel());
        return addCallOptions(stub);
    }

    public UserAdminServiceGrpc.UserAdminServiceBlockingStub getUserAdminServiceBlockingStub() {
        UserAdminServiceGrpc.UserAdminServiceBlockingStub stub = UserAdminServiceGrpc.newBlockingStub(getChannel());
        return addCallOptions(stub);
    }

    public ClusterInfoServiceGrpc.ClusterInfoServiceBlockingStub clusterInfoBlockingStub(Channel tendChannel) {
        return addCallOptions(ClusterInfoServiceGrpc.newBlockingStub(tendChannel));
    }

    private > T addCallOptions(T stub) {
        T result = stub;
        if (authTokenManager != null) {
            result = stub.withCallCredentials(authTokenManager.getCallCredentials());
        }
        if (connectionConfig.getDefaultTimeout() != Integer.MAX_VALUE) {
            result = stub.withDeadline(Deadline.after(connectionConfig.getDefaultTimeout(), TimeUnit.MILLISECONDS));
        }
        return result;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy