All downloads are free. Search and download functionality uses the official Maven repository.

com.aerospike.vector.client.internal.auth.AuthTokenManager Maven / Gradle / Ivy

package com.aerospike.vector.client.internal.auth;

import com.aerospike.vector.client.auth.PasswordCredentials;
import com.aerospike.vector.client.internal.ClusterTenderer;
import com.aerospike.vector.client.proto.AuthRequest;
import com.aerospike.vector.client.proto.AuthResponse;
import com.aerospike.vector.client.proto.Credentials;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.*;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Obtains and periodically refreshes the access token used to authenticate gRPC calls.
 *
 * <p>The constructor starts an asynchronous token fetch using the supplied password credentials.
 * Each successful fetch schedules the next one on an internal {@link ScheduledExecutorService}
 * after {@code REFRESH_AFTER_FRACTION} of the token's TTL, moved earlier when needed so that
 * roughly {@code REFRESH_MIN_TIME} ms of validity remain. Failed fetches are retried with
 * exponential backoff capped at {@code MAX_EXPONENTIAL_BACKOFF} ms. Call {@link #close()} to stop
 * the refresh loop.
 */
public class AuthTokenManager implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(AuthTokenManager.class);

    /** Head-room (ms) kept before expiry when scheduling a refresh; also the auth RPC deadline. */
    private static final int REFRESH_MIN_TIME = 5000;
    /** Upper bound (ms) on the retry delay after consecutive fetch failures. */
    private static final int MAX_EXPONENTIAL_BACKOFF = 15000;
    /** Fraction of a token's TTL after which a proactive refresh is scheduled. */
    private static final float REFRESH_AFTER_FRACTION = 0.8f;
    private static final ObjectMapper objectMapper = new ObjectMapper();

    private final ScheduledExecutorService executor;
    /** True while an auth RPC is in flight; prevents overlapping fetches. */
    private final AtomicBoolean isFetchingToken = new AtomicBoolean(false);
    private final AtomicBoolean isClosed = new AtomicBoolean(false);
    private final AtomicInteger consecutiveRefreshErrors = new AtomicInteger(0);
    /** Most recent fetch failure, or {@code null} after a successful fetch. */
    private final AtomicReference<Throwable> refreshError = new AtomicReference<>(null);
    private volatile AccessToken accessToken;
    /** True from the moment a refresh task is submitted until that task starts running. */
    private volatile boolean fetchScheduled;
    // Written by whichever thread schedules a refresh and read by close(), hence volatile.
    private volatile ScheduledFuture<?> refreshFuture;

    private final PasswordCredentials passwordCredentials;
    private final ClusterTenderer clusterTenderer;

    /** Prefix for log messages; also used as the executor's thread name. */
    private final String authIdentifier;

    /**
     * Creates the manager and immediately starts the first (asynchronous) token fetch.
     *
     * @param passwordCredentials credential of user or admin
     * @param clusterTenderer internally constructed clusterTenderer object, used to obtain the
     *     channel and auth stub for token requests
     */
    public AuthTokenManager(PasswordCredentials passwordCredentials, ClusterTenderer clusterTenderer) {
        this.passwordCredentials = passwordCredentials;
        this.clusterTenderer = clusterTenderer;
        authIdentifier = String.format("avs-%s-authmanager", clusterTenderer.identifier);
        this.executor = Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors(),
                new ThreadFactoryBuilder().setNameFormat(authIdentifier).build());
        updateAccessToken(null);
        // First token fetch; every completed fetch schedules the next one.
        fetchToken(true);
    }

    // The field is volatile; synchronized additionally serializes writers against readers.
    private synchronized void updateAccessToken(AccessToken newToken) {
        this.accessToken = newToken;
    }

    private synchronized AccessToken getAccessToken() {
        return accessToken;
    }

    /**
     * Runs one token fetch, unless the manager is closed, no refresh is needed, or a fetch is
     * already in flight. Every fetch that actually starts ends by scheduling the next refresh,
     * either after success or after failure.
     *
     * @param forceRefresh fetch even if the current token is still valid
     */
    private void fetchToken(boolean forceRefresh) {
        // This task is now running, so a new one may be scheduled.
        fetchScheduled = false;
        if (isClosed.get() || !shouldRefresh(forceRefresh)
                || !isFetchingToken.compareAndSet(false, true)) {
            // An in-flight fetch (if any) will schedule the next refresh when it finishes.
            return;
        }
        try {
            log.info("{}; Starting token refresh", authIdentifier);
            AuthRequest aerospikeAuthRequest = AuthRequest.newBuilder()
                    .setCredentials(Credentials.newBuilder()
                            .setUsername(passwordCredentials.username())
                            .setPasswordCredentials(com.aerospike.vector.client.proto.PasswordCredentials.newBuilder()
                                    .setPassword(passwordCredentials.password()).build())
                            .build())
                    .build();
            ManagedChannel channel;
            try {
                channel = clusterTenderer.getChannel();
                log.info("{}; Got successfully channel for token refresh", authIdentifier);
            } catch (Exception e) {
                log.error("{}; Error getting tend channel, will refresh in 10 milliseconds", authIdentifier, e);
                isFetchingToken.set(false);
                unsafeScheduleRefresh(10, true);
                return;
            }
            clusterTenderer.getAuthStub(channel)
                    .withDeadline(Deadline.after(REFRESH_MIN_TIME, TimeUnit.MILLISECONDS))
                    .authenticate(aerospikeAuthRequest, new StreamObserver<AuthResponse>() {
                        // Ensures only the first terminal outcome of this unary call is acted upon.
                        private final AtomicBoolean handled = new AtomicBoolean(false);

                        @Override
                        public void onNext(AuthResponse aerospikeAuthResponse) {
                            if (!handled.compareAndSet(false, true)) {
                                return;
                            }
                            try {
                                AccessToken token = parseToken(aerospikeAuthResponse.getToken());
                                updateAccessToken(token);
                                log.info("{}; Fetched token successfully with TTL {}ms", authIdentifier, token.ttl);
                                // Reset the backoff state before computing the next delay.
                                clearRefreshErrors();
                                // Release the in-flight flag before scheduling; otherwise a
                                // refresh that runs immediately would be skipped.
                                isFetchingToken.set(false);
                                unsafeScheduleNextRefresh();
                            } catch (Exception e) {
                                log.error("{}; Error in fetching token", authIdentifier, e);
                                onFetchError(e);
                            }
                        }

                        @Override
                        public void onError(Throwable t) {
                            if (handled.compareAndSet(false, true)) {
                                onFetchError(t);
                            } else {
                                log.debug("{}; Ignoring auth call error after response", authIdentifier, t);
                            }
                        }

                        @Override
                        public void onCompleted() {
                            // A call that completes without a response must still keep the refresh loop alive.
                            if (handled.compareAndSet(false, true)) {
                                onFetchError(new IllegalStateException("Auth call completed without returning a token"));
                            }
                        }
                    });
        } catch (Exception e) {
            onFetchError(e);
        }
    }

    private void clearRefreshErrors() {
        consecutiveRefreshErrors.set(0);
        refreshError.set(null);
    }

    private void updateRefreshErrors(Throwable t) {
        consecutiveRefreshErrors.incrementAndGet();
        refreshError.set(t);
    }

    /** Records a failed fetch, releases the in-flight flag and schedules a retry. */
    private void onFetchError(Throwable t) {
        updateRefreshErrors(t);
        Exception e = new Exception("Error fetching access token", t);
        log.error("{}; onFetchError exception", authIdentifier, e);
        // Release before scheduling so a retry that runs immediately is not skipped.
        isFetchingToken.set(false);
        unsafeScheduleNextRefresh();
    }

    private boolean shouldRefresh(boolean forceRefresh) {
        boolean tokenValid = isTokenValid();
        boolean shouldRefresh = forceRefresh || !tokenValid;
        log.debug("{}; shouldRefresh: {}, isTokenValid:{}", authIdentifier, shouldRefresh, tokenValid);
        return shouldRefresh;
    }

    /**
     * Schedules the next fetch.
     *
     * <p>After consecutive failures the exponential backoff delay is used, even if the current
     * token is still valid. Otherwise the delay is derived from the current token's TTL, or is
     * zero when there is no valid token.
     */
    private void unsafeScheduleNextRefresh() {
        AccessToken token = getAccessToken();
        long delay;
        if (token == null || token.hasExpired()) {
            log.warn("{}; Token not valid, setting delay to zero!", authIdentifier);
            delay = 0;
        } else {
            delay = refreshDelay(token.ttl);
        }

        int errors = consecutiveRefreshErrors.get();
        if (errors > 0) {
            // Waiting another TTL-based interval after a failure would let the still-valid token
            // expire long before the next attempt, so retry on the backoff schedule instead.
            delay = backoffDelay(errors);
        }
        log.info("{}; delay:{}", authIdentifier, delay);
        unsafeScheduleRefresh(delay, true);
    }

    /**
     * Returns the delay (ms) until the proactive refresh of a token with the given TTL.
     *
     * <p>The base delay is {@code REFRESH_AFTER_FRACTION * ttl}. It is moved earlier so that
     * {@code REFRESH_MIN_TIME} ms of validity remain, and it is never negative.
     */
    private static long refreshDelay(long ttl) {
        long delay = (long) Math.floor(ttl * REFRESH_AFTER_FRACTION);
        if (ttl - delay < REFRESH_MIN_TIME) {
            delay = ttl - REFRESH_MIN_TIME;
        }
        return Math.max(0, delay);
    }

    /**
     * Returns the retry delay (ms) after {@code errors} consecutive failures.
     *
     * <p>The delay is {@code 2^errors} seconds, clamped to {@code [0, MAX_EXPONENTIAL_BACKOFF]}.
     */
    private static long backoffDelay(int errors) {
        long delay = (long) (Math.pow(2, errors) * 1000);
        return Math.max(0, Math.min(delay, MAX_EXPONENTIAL_BACKOFF));
    }

    /**
     * Submits a {@link #fetchToken} task after {@code delay} ms.
     *
     * <p>Nothing is submitted if the manager is closed, if {@code forceRefresh} is false, or if a
     * task is already pending.
     */
    private void unsafeScheduleRefresh(long delay, boolean forceRefresh) {
        if (isClosed.get() || !forceRefresh || fetchScheduled) {
            return;
        }
        if (executor.isShutdown()) {
            return;
        }
        // Mark as scheduled *before* submitting: a task that starts immediately resets the flag,
        // and setting it afterwards would leave it stuck at true and stall all future refreshes.
        fetchScheduled = true;
        try {
            refreshFuture = executor.schedule(() -> fetchToken(forceRefresh), delay, TimeUnit.MILLISECONDS);
            log.info("{}; Scheduled token refresh after {} millis", authIdentifier, delay);
        } catch (RejectedExecutionException e) {
            // close() shut the executor down between the checks above and schedule().
            fetchScheduled = false;
            log.debug("{}; Token refresh not scheduled, executor is shut down", authIdentifier);
        }
    }

    /**
     * Provides auth token status.
     *
     * @return a valid status if the current token has not expired; otherwise an invalid status
     *     carrying the last fetch error, or an {@code UNAUTHENTICATED} status exception
     */
    public TokenStatus getTokenStatus() {
        if (isTokenValid()) {
            return new TokenStatus();
        }
        Throwable error = refreshError.get();
        if (error != null) {
            return new TokenStatus(error);
        }
        AccessToken token = getAccessToken();
        if (token != null && token.hasExpired()) {
            return new TokenStatus(
                    Status.UNAUTHENTICATED.withDescription("token has expired").asRuntimeException());
        }
        return new TokenStatus(Status.UNAUTHENTICATED.asRuntimeException());
    }

    /**
     * Stops the refresh loop: cancels the pending refresh and shuts the executor down.
     *
     * <p>Waits until the executor has terminated. If interrupted while waiting, it forces a
     * shutdown, keeps waiting, and restores the interrupt flag before returning.
     */
    @Override
    public void close() {
        if (isClosed.getAndSet(true)) {
            return;
        }
        boolean terminated = executor.isTerminated();
        if (!terminated) {
            ScheduledFuture<?> pending = refreshFuture;
            if (pending != null) {
                pending.cancel(true);
            }
            executor.shutdown();
            boolean interrupted = false;
            while (!terminated) {
                try {
                    terminated = executor.awaitTermination(5, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    if (!interrupted) {
                        executor.shutdownNow();
                        interrupted = true;
                    }
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Parses a JWT-style token: {@code header.claims[.signature]}, with base64url-encoded JSON
     * claims.
     *
     * <p>The TTL is {@code (exp - iat)} seconds, converted to ms. The expiry is measured from the
     * local clock at parse time.
     *
     * @throws IllegalArgumentException if the token is malformed, the claims lack numeric
     *     {@code exp}/{@code iat}, or the TTL is not positive
     * @throws IOException if the claims are not valid JSON
     */
    private AccessToken parseToken(String token) throws IOException {
        String[] parts = token.split("\\.");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Invalid token format.");
        }
        // JSON is UTF-8; do not depend on the platform default charset.
        String claims = new String(Base64.getUrlDecoder().decode(parts[1]), StandardCharsets.UTF_8);

        // ObjectMapper returns a raw Map here; JSON object keys are always strings.
        @SuppressWarnings("unchecked")
        Map<String, Object> parsedClaims = objectMapper.readValue(claims, Map.class);

        Object expiryToken = parsedClaims.get("exp");
        Object iat = parsedClaims.get("iat");
        if (expiryToken instanceof Number && iat instanceof Number) {
            long ttl = (((Number) expiryToken).longValue() - ((Number) iat).longValue()) * 1000;
            if (ttl <= 0) {
                throw new IllegalArgumentException("Token 'iat' > 'exp'");
            }
            long expiry = System.currentTimeMillis() + ttl;
            return new AccessToken(expiry, ttl, token);
        } else {
            throw new IllegalArgumentException("Unsupported access token format");
        }
    }

    /**
     * Returns gRPC call credentials carrying the current bearer token.
     *
     * @return gRPC CallCredentials
     * @throws StatusRuntimeException if the token was never fetched or has expired. The last
     *     fetch error is rethrown (as a runtime exception) when one is recorded.
     */
    public CallCredentials getCallCredentials() throws StatusRuntimeException {
        if (!isTokenValid()) {
            log.info("{}; Starting a call with invalid token", authIdentifier);
            // NOTE(review): with forceRefresh=false unsafeScheduleRefresh returns immediately, so
            // this does not trigger a fetch; recovery relies on the scheduled refresh loop.
            unsafeScheduleRefresh(0, false);
        }
        // Snapshot once so the validity check and the returned credentials use the same token.
        AccessToken token = getAccessToken();
        if (token == null || token.hasExpired()) {
            throw unauthenticatedError(token);
        }
        return new BearerTokenCallCredentials(token.token);
    }

    /** Builds the exception thrown by {@link #getCallCredentials()} when no valid token exists. */
    private StatusRuntimeException unauthenticatedError(AccessToken token) {
        Throwable lastError = refreshError.get();
        if (lastError instanceof StatusRuntimeException) {
            return (StatusRuntimeException) lastError;
        }
        if (lastError instanceof StatusException) {
            // StatusException is checked, so it cannot be rethrown as a RuntimeException directly.
            StatusException statusException = (StatusException) lastError;
            return statusException.getStatus().asRuntimeException(statusException.getTrailers());
        }
        if (lastError != null) {
            return Status.UNAUTHENTICATED.withDescription(lastError.getMessage())
                    .withCause(lastError)
                    .asRuntimeException();
        }
        return Status.UNAUTHENTICATED.withDescription(
                token == null ? "Access token not fetched" : "Access token has expired").asRuntimeException();
    }

    private boolean isTokenValid() {
        AccessToken token = getAccessToken();
        boolean tokenValid = token != null && !token.hasExpired();
        log.debug("{}; tokenValid: {}, token: {}", authIdentifier, tokenValid, token);
        return tokenValid;
    }

    /** An immutable fetched token with its absolute expiry and TTL, both in milliseconds. */
    private static class AccessToken {
        private final long expiry;
        private final long ttl;
        private final String token;

        public AccessToken(long expiry, long ttl, String token) {
            this.expiry = expiry;
            this.ttl = ttl;
            this.token = token;
        }

        public boolean hasExpired() {
            return System.currentTimeMillis() > expiry;
        }

        /** Describes the token without revealing it; this string is written to debug logs. */
        @Override
        public String toString() {
            return "AccessToken{" +
                    "expiry=" + expiry +
                    ", ttl=" + ttl +
                    ", token=<redacted, " + (token == null ? 0 : token.length()) + " chars>" +
                    '}';
        }
    }

    /**
     * Snapshot of the auth token state returned by {@link AuthTokenManager#getTokenStatus()}.
     */
    public static class TokenStatus {
        private final Throwable error;
        private final boolean isValid;

        private TokenStatus() {
            this.isValid = true;
            this.error = null;
        }

        private TokenStatus(Throwable error) {
            this.isValid = false;
            this.error = error;
        }

        /** @return true if the token was valid when this status was created */
        public boolean isValid() {
            return isValid;
        }

        /** @return the reason the token is not valid, or {@code null} for a valid status */
        public Throwable getError() {
            return error;
        }
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy