All downloads are free. Search and download functionality uses the official Maven repository.

com.aerospike.vector.client.internal.ClusterTenderer Maven / Gradle / Ivy

Go to download

This project includes the Java client for Aerospike Vector Search for high-performance data interactions.

The newest version!
package com.aerospike.vector.client.internal;

import com.aerospike.vector.client.ConnectionConfig;
import com.aerospike.vector.client.HostPort;
import com.aerospike.vector.client.auth.PasswordCredentials;
import com.aerospike.vector.client.internal.auth.AuthTokenManager;
import com.aerospike.vector.client.proto.*;
import com.google.protobuf.Empty;
import io.grpc.Channel;
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.stub.AbstractStub;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Internal class for providing gRPC tls/non-tls channels */
public class ClusterTenderer implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(ClusterTenderer.class);
    private final ConnectionConfig connectionConfig;
    private final List seedChannels = new ArrayList<>();
    private final ScheduledExecutorService tendExecutorService =
            Executors.newSingleThreadScheduledExecutor();

    private AuthTokenManager authTokenManager;
    private Exception lastTendException = null;
    private final VectorChannelProvider vectorChannelProvider;
    private final AtomicReference defaultClusterid =
            new AtomicReference<>(ClusterId.getDefaultInstance());
    private final AtomicReference> nodeEndpoints =
            new AtomicReference<>(new HashMap<>());
    private final AtomicBoolean closed = new AtomicBoolean();
    private final Random random = new Random();
    ExecutorService executor = Executors.newSingleThreadExecutor();

    public final String identifier;

    public ClusterTenderer(ConnectionConfig connectionConfig, String identifier) {
        this.identifier = identifier;
        this.connectionConfig = connectionConfig;
        vectorChannelProvider = new VectorChannelProvider(connectionConfig, identifier);

        if (!connectionConfig.isLoadBalancer()) {
            tendExecutorService.scheduleWithFixedDelay(
                    this::refreshClusterState, 0, 1, TimeUnit.SECONDS);
        }

        if (connectionConfig.getCredentials() instanceof PasswordCredentials) {
            this.authTokenManager =
                    new AuthTokenManager(
                            (PasswordCredentials) connectionConfig.getCredentials(), this);
            long start = System.currentTimeMillis();
            // wait is only necessary when authTokenManager is created
            waitTillReady();
            log.info(
                    "{}; waitTillReady took: {} seconds.",
                    identifier,
                    (System.currentTimeMillis() - start) / 60.0);
        }
    }

    private ManagedChannel getTendChannel() {
        if (!connectionConfig.isLoadBalancer()) {
            ManagedChannel channel = getChannelFromEndpoints();
            if (channel != null) {
                return channel;
            }
        }
        HostPort seed =
                connectionConfig.getSeeds().get(random.nextInt(connectionConfig.getSeeds().size()));
        return vectorChannelProvider.channelFor(
                ServerEndpointList.newBuilder()
                        .addEndpoints(
                                ServerEndpoint.newBuilder()
                                        .setAddress(seed.address())
                                        .setPort(seed.port())
                                        .setIsTls(connectionConfig.getClientTlsConfig() != null))
                        .build());
    }

    private ManagedChannel getChannelFromEndpoints() {
        Map endpoints = nodeEndpoints.get();
        if (endpoints.isEmpty()) {
            return null;
        }

        List valuesList = new ArrayList<>(endpoints.values());
        ServerEndpointList randomEndpoints = valuesList.get(random.nextInt(valuesList.size()));

        if (randomEndpoints.getEndpointsList().isEmpty()) {
            return null;
        }

        return vectorChannelProvider.channelFor(randomEndpoints);
    }

    private void refreshClusterState() {
        try {
            HashMap tempEndpoints = new HashMap<>();
            boolean updateEndpoints = false;

            for (HostPort seed : connectionConfig.getSeeds()) {
                ManagedChannel tendChannel =
                        vectorChannelProvider.channelFor(
                                ServerEndpointList.newBuilder()
                                        .addEndpoints(
                                                ServerEndpoint.newBuilder()
                                                        .setAddress(seed.address())
                                                        .setPort(seed.port())
                                                        .setIsTls(
                                                                connectionConfig
                                                                                .getClientTlsConfig()
                                                                        != null))
                                        .build());

                ClusterInfoServiceGrpc.ClusterInfoServiceBlockingStub clusterInfoStub =
                        clusterInfoBlockingStub(tendChannel);
                ClusterId newClusterId = clusterInfoStub.getClusterId(Empty.getDefaultInstance());

                if (defaultClusterid.get().equals(newClusterId)) {
                    continue; // Skip to the next iteration in the loop, equivalent to
                    // return@forEach in Kotlin
                }

                updateEndpoints = true;
                defaultClusterid.set(newClusterId);

                ClusterNodeEndpointsRequest.Builder requestBuilder =
                        ClusterNodeEndpointsRequest.newBuilder();
                if (connectionConfig.getListenerName() != null) {
                    requestBuilder.setListenerName(connectionConfig.getListenerName());
                }

                ClusterNodeEndpoints endpoints =
                        clusterInfoStub.getClusterEndpoints(requestBuilder.build());
                if (endpoints.getEndpointsMap().size() > tempEndpoints.size()) {
                    tempEndpoints = new HashMap<>(endpoints.getEndpointsMap());
                }
            }

            if (updateEndpoints) {
                nodeEndpoints.set(tempEndpoints);
            }
        } catch (Exception e) {
            lastTendException = e;
            log.debug("Error getting node endpoints.", e);
        }
    }

    public ManagedChannel getChannel() {
        if (connectionConfig.isLoadBalancer()) {
            return getTendChannel();
        }

        long t1 = System.currentTimeMillis();
        ManagedChannel channel = getChannelFromEndpoints();
        while (channel != null
                && (System.currentTimeMillis() - t1) < connectionConfig.getConnectTimeout()) {
            try {
                TimeUnit.MILLISECONDS.sleep(100);
                channel = getChannelFromEndpoints();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }

        // TBD need to investigate why channel comes as null instead of
        //        log.info("channel was {}", channel);
        return channel != null ? channel : getTendChannel();
    }

    @Override
    public void close() {
        closed.set(true);
        for (Channel seedChannel : seedChannels) {
            ((ManagedChannel) seedChannel).shutdown();
        }

        if (authTokenManager != null) {
            authTokenManager.close();
        }

        tendExecutorService.shutdownNow();
        vectorChannelProvider.close();
        executor.shutdown();
    }

    private void waitTillReady() {

        log.info("{}: waiting for auth-manager to get ready.", identifier);
        long timeout = connectionConfig.getConnectTimeout();
        long deadline = System.currentTimeMillis() + timeout;
        boolean notReady = !isReady();
        boolean notPastWaitTime = System.currentTimeMillis() < deadline;
        while (notReady && notPastWaitTime) {
            try {
                log.debug(
                        "{}; notReady:{}, notPastWaitTime:{}.",
                        identifier,
                        notReady,
                        notPastWaitTime);
                Thread.sleep(100);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while waiting.", e);
            } finally {

                notReady = !isReady();
                notPastWaitTime = System.currentTimeMillis() < deadline;
            }
        }

        if (!isReady()) {
            Exception augmentedException =
                    new TimeoutException("Connect timed out after " + timeout + " ms.");
            if (lastTendException != null) {
                augmentedException.initCause(lastTendException);
            }

            if (connectionConfig.isFailIfNotConnected()) {
                throw new RuntimeException(augmentedException);
            } else {
                log.warn("Client not connected in {}.", identifier, augmentedException);
            }
        }
    }

    private boolean isReady() {
        return (connectionConfig.isLoadBalancer() || !nodeEndpoints.get().isEmpty())
                && (authTokenManager == null || authTokenManager.getTokenStatus().isValid());
    }

    public TransactServiceGrpc.TransactServiceBlockingStub getTransactBlockingStub() {
        TransactServiceGrpc.TransactServiceBlockingStub stub =
                TransactServiceGrpc.newBlockingStub(getChannel());
        return addCallOptions(stub);
    }

    public AuthServiceGrpc.AuthServiceStub getAuthStub(ManagedChannel channel) {
        AuthServiceGrpc.AuthServiceStub stub = AuthServiceGrpc.newStub(channel);
        return addCallOptions(stub);
    }

    public TransactServiceGrpc.TransactServiceStub getTransactNonBlockingStub() {
        TransactServiceGrpc.TransactServiceStub stub = TransactServiceGrpc.newStub(getChannel());
        return addCallOptions(stub);
    }

    public IndexServiceGrpc.IndexServiceBlockingStub getIndexServiceBlockingStub() {
        IndexServiceGrpc.IndexServiceBlockingStub stub =
                IndexServiceGrpc.newBlockingStub(getChannel());
        return addCallOptions(stub);
    }

    public UserAdminServiceGrpc.UserAdminServiceBlockingStub getUserAdminServiceBlockingStub() {
        UserAdminServiceGrpc.UserAdminServiceBlockingStub stub =
                UserAdminServiceGrpc.newBlockingStub(getChannel());
        return addCallOptions(stub);
    }

    public ClusterInfoServiceGrpc.ClusterInfoServiceBlockingStub clusterInfoBlockingStub(
            Channel tendChannel) {
        return addCallOptions(ClusterInfoServiceGrpc.newBlockingStub(tendChannel));
    }

    private > T addCallOptions(T stub) {
        T result = stub;
        if (authTokenManager != null) {
            result = stub.withCallCredentials(authTokenManager.getCallCredentials());
        }
        if (connectionConfig.getDefaultTimeout() != Integer.MAX_VALUE) {
            result =
                    stub.withDeadline(
                            Deadline.after(
                                    connectionConfig.getDefaultTimeout(), TimeUnit.MILLISECONDS));
        }
        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy