All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aerospike.vector.client.internal.auth.AuthTokenManager Maven / Gradle / Ivy

Go to download

This project includes the Java client for Aerospike Vector Search for high-performance data interactions.

The newest version!
package com.aerospike.vector.client.internal.auth;

import com.aerospike.vector.client.auth.PasswordCredentials;
import com.aerospike.vector.client.internal.ClusterTenderer;
import com.aerospike.vector.client.proto.AuthRequest;
import com.aerospike.vector.client.proto.AuthResponse;
import com.aerospike.vector.client.proto.Credentials;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.*;
import io.grpc.stub.StreamObserver;
import java.io.Closeable;
import java.io.IOException;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** An access token manager for Aerospike proxy. */
public class AuthTokenManager implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(AuthTokenManager.class);

    /**
     * Milliseconds. Used as the deadline of one authenticate call, as the assumed TTL while no
     * token has been fetched yet, and as the minimum lead time before expiry at which a refresh
     * is scheduled.
     */
    private static final int REFRESH_MIN_TIME = 5000;

    /** Upper bound, in milliseconds, of the retry delay after consecutive refresh failures. */
    private static final int MAX_EXPONENTIAL_BACKOFF = 15000;

    /** Fraction of a token's TTL after which the next refresh is scheduled. */
    private static final float REFRESH_AFTER_FRACTION = 0.8f;

    /** Shared JSON mapper used to decode the claims section of a token. */
    private static final ObjectMapper objectMapper = new ObjectMapper();

    /** Runs the delayed token refresh tasks. */
    private final ScheduledExecutorService executor;

    /** True while an authenticate call is in flight. */
    private final AtomicBoolean isFetchingToken = new AtomicBoolean(false);

    /** Set once by {@link #close()}; no refresh is scheduled afterwards. */
    private final AtomicBoolean isClosed = new AtomicBoolean(false);

    /** Number of refresh failures since the last successful fetch; drives the retry backoff. */
    private final AtomicInteger consecutiveRefreshErrors = new AtomicInteger(0);

    /** Most recent refresh failure, or {@code null} after a successful fetch. */
    private final AtomicReference<Throwable> refreshError = new AtomicReference<>(null);

    /** Current token, or {@code null} until the first fetch succeeds. */
    private volatile AccessToken accessToken;

    /** True while a refresh task is queued on the executor and has not started running. */
    private volatile boolean fetchScheduled;

    /**
     * Handle of the most recently scheduled refresh task. Volatile because it is written by
     * whichever thread schedules a refresh and read by the thread calling {@link #close()}.
     */
    private volatile ScheduledFuture<?> refreshFuture;

    /** Credentials sent with every authenticate request. */
    private final PasswordCredentials passwordCredentials;

    /** Supplies the channel and auth stub used for authenticate calls. */
    private final ClusterTenderer clusterTenderer;

    /** Prefix for log lines and name format of the executor threads. */
    private final String authIdentifier;

    /**
     * AuthTokenManager constructor.
     *
     * @param passwordCredentials credential of user or admin.
     * @param clusterTenderer internally constructed clusterTenderer object.
     */
    public AuthTokenManager(
            PasswordCredentials passwordCredentials, ClusterTenderer clusterTenderer) {
        this.passwordCredentials = passwordCredentials;
        this.clusterTenderer = clusterTenderer;
        authIdentifier = String.format("avs-%s-authmanager", clusterTenderer.identifier);
        this.executor =
                Executors.newScheduledThreadPool(
                        Runtime.getRuntime().availableProcessors(),
                        new ThreadFactoryBuilder().setNameFormat(authIdentifier).build());
        updateAccessToken(null);
        // first time generate token which will be refreshed periodically
        fetchToken(true);
    }

    // Stores the given token (may be null) while holding this instance's monitor.
    private void updateAccessToken(AccessToken newToken) {
        synchronized (this) {
            accessToken = newToken;
        }
    }

    // Returns the current token (null if none was fetched yet), read under this instance's
    // monitor.
    private AccessToken getAccessToken() {
        synchronized (this) {
            return accessToken;
        }
    }

    /**
     * Starts an asynchronous authenticate call to obtain a new access token.
     *
     * <p>Does nothing when the manager is closed, when another fetch is already in flight, or when
     * no refresh is needed. The outcome is handled by the response observer: success stores the
     * token and schedules the next refresh, failure records the error and schedules a retry.
     *
     * @param forceRefresh fetch even if the current token is still valid.
     */
    private void fetchToken(boolean forceRefresh) {
        // The scheduled task (if any) is running now, so a new one may be scheduled again.
        fetchScheduled = false;
        if (isClosed.get() || isFetchingToken.get()) {
            return;
        }
        if (!shouldRefresh(forceRefresh)) {
            return;
        }
        // Claim the in-flight flag atomically so two concurrent callers cannot both start a fetch.
        if (!isFetchingToken.compareAndSet(false, true)) {
            return;
        }
        try {
            log.info("{}; Starting token refresh.", authIdentifier);
            AuthRequest aerospikeAuthRequest = buildAuthRequest();
            ManagedChannel channel;
            try {
                channel = clusterTenderer.getChannel();
                log.info("{}; Got successfully channel for token refresh.", authIdentifier);
            } catch (Exception e) {
                log.error(
                        "{}; Error getting in getting tend channel, will referesh in 10"
                                + " milliseconds.",
                        authIdentifier,
                        e);
                // No call was issued: release the flag before queueing the quick retry.
                isFetchingToken.set(false);
                unsafeScheduleRefresh(10, true);
                return;
            }
            clusterTenderer
                    .getAuthStub(channel)
                    .withDeadline(Deadline.after(REFRESH_MIN_TIME, TimeUnit.MILLISECONDS))
                    .authenticate(aerospikeAuthRequest, newAuthResponseObserver());
        } catch (Exception e) {
            onFetchError(e);
        }
    }

    /** Builds the authenticate request carrying the configured username and password. */
    private AuthRequest buildAuthRequest() {
        return AuthRequest.newBuilder()
                .setCredentials(
                        Credentials.newBuilder()
                                .setUsername(passwordCredentials.username())
                                .setPasswordCredentials(
                                        com.aerospike.vector.client.proto.PasswordCredentials
                                                .newBuilder()
                                                .setPassword(passwordCredentials.password())
                                                .build())
                                .build())
                .build();
    }

    /** Creates the observer that turns an authenticate response into the current token. */
    private StreamObserver<AuthResponse> newAuthResponseObserver() {
        return new StreamObserver<AuthResponse>() {
            @Override
            public void onNext(AuthResponse aerospikeAuthResponse) {
                try {
                    AccessToken fetched = parseToken(aerospikeAuthResponse.getToken());
                    updateAccessToken(fetched);
                    log.info(
                            "{}; Fetched token successfully with TTL {}ms.",
                            authIdentifier,
                            fetched.ttl);
                    // Clear the failure state first so the delay computation below does not
                    // apply backoff based on errors that preceded this success.
                    clearRefreshErrors();
                    // Release the flag before scheduling so a promptly-run refresh task is not
                    // turned away by a fetch that has in fact finished.
                    isFetchingToken.set(false);
                    unsafeScheduleNextRefresh();
                } catch (Exception e) {
                    log.error("{}; Error in fetching token.", authIdentifier, e);
                    onFetchError(e);
                }
            }

            @Override
            public void onError(Throwable t) {
                onFetchError(t);
            }

            @Override
            public void onCompleted() {
                isFetchingToken.set(false);
            }
        };
    }

    /** Resets the failure state: zeroes the consecutive-error counter and drops the last error. */
    private void clearRefreshErrors() {
        consecutiveRefreshErrors.getAndSet(0);
        refreshError.getAndSet(null);
    }

    /**
     * Records one more refresh failure: bumps the consecutive-error counter and remembers the
     * given throwable as the most recent error.
     */
    private void updateRefreshErrors(Throwable t) {
        consecutiveRefreshErrors.getAndIncrement();
        refreshError.getAndSet(t);
    }

    /**
     * Handles a failed token fetch: records the failure, logs it, releases the in-flight flag and
     * schedules the next attempt.
     *
     * @param t the failure reported by the authenticate call or raised while handling it.
     */
    private void onFetchError(Throwable t) {
        updateRefreshErrors(t);
        Exception e = new Exception("Error fetching access token.", t);
        log.error("{}; onFetchError exception.", authIdentifier, e);
        // Release the flag BEFORE scheduling: a refresh task that starts while the flag is still
        // set returns without rescheduling, which would end the refresh cycle. Doing it first
        // also guarantees the flag is cleared even if scheduling throws.
        isFetchingToken.set(false);
        unsafeScheduleNextRefresh();
    }

    /**
     * Decides whether a token fetch is needed.
     *
     * @param forceRefresh when true a fetch is always needed.
     * @return true if forced or if the current token is missing or expired.
     */
    private boolean shouldRefresh(boolean forceRefresh) {
        // Evaluate validity once so the logged value is the one the decision was based on.
        boolean tokenValid = isTokenValid();
        boolean shouldRefresh = forceRefresh || !tokenValid;
        log.debug(
                "{}; shouldRefresh: {}, isTokenValid:{}.",
                authIdentifier,
                shouldRefresh,
                tokenValid);
        return shouldRefresh;
    }

    /**
     * Computes when the next refresh should run and schedules it.
     *
     * <p>For a valid token the refresh point is {@code REFRESH_AFTER_FRACTION} of its TTL, moved
     * earlier so that at least {@code REFRESH_MIN_TIME} remains before expiry when the TTL is long
     * enough to allow that. Time already elapsed since the token was obtained is subtracted, so a
     * retry after a failed refresh is not pushed past the token's expiry. When the resulting delay
     * is zero and previous attempts failed, an exponential backoff capped at
     * {@code MAX_EXPONENTIAL_BACKOFF} is applied instead.
     */
    private void unsafeScheduleNextRefresh() {
        // Snapshot the token once so every decision below is based on the same instance.
        AccessToken token = getAccessToken();
        long delay;
        if (token == null || token.hasExpired()) {
            log.warn("{}; Token not valid, setting delay to zero.", authIdentifier);
            delay = 0;
        } else {
            long ttl = token.ttl;
            delay = (long) Math.floor(ttl * REFRESH_AFTER_FRACTION);

            // Keep REFRESH_MIN_TIME of lead time before expiry, but only when the TTL is long
            // enough; otherwise the delay would be zero or negative and every successful fetch
            // would immediately trigger the next one.
            long leadTimeDelay = ttl - REFRESH_MIN_TIME;
            if (ttl - delay < REFRESH_MIN_TIME && leadTimeDelay > 0) {
                delay = leadTimeDelay;
            }

            // The delay above is relative to when the token was obtained; subtract what has
            // already passed since then.
            long elapsed = ttl - (token.expiry - System.currentTimeMillis());
            delay = Math.max(0, delay - Math.max(0, elapsed));
        }

        int errors = consecutiveRefreshErrors.get();
        if (delay == 0 && errors > 0) {
            // 2^errors seconds; the shift is bounded because 2^4 s already exceeds the cap.
            delay = Math.min(MAX_EXPONENTIAL_BACKOFF, 1000L << Math.min(errors, 4));
        }
        log.info("{}; delay:{}", authIdentifier, delay);
        unsafeScheduleRefresh(delay, true);
    }

    /**
     * Queues a token fetch on the executor after the given delay.
     *
     * <p>Nothing is scheduled when the manager is closed, when {@code forceRefresh} is false, when
     * a refresh task is already queued, or when the executor has been shut down.
     *
     * @param delay delay in milliseconds before the fetch runs.
     * @param forceRefresh passed on to the fetch; a value of false makes this call a no-op.
     */
    private void unsafeScheduleRefresh(long delay, boolean forceRefresh) {
        if (isClosed.get() || !forceRefresh || fetchScheduled) {
            return;
        }
        if (executor.isShutdown()) {
            return;
        }
        // Mark as scheduled BEFORE handing the task to the executor. The task clears the flag
        // when it starts; with a zero delay it can start before this thread continues, and
        // setting the flag afterwards would leave it stuck at true and block all future refreshes.
        fetchScheduled = true;
        try {
            refreshFuture =
                    executor.schedule(() -> fetchToken(forceRefresh), delay, TimeUnit.MILLISECONDS);
            log.info("{}; Scheduled token refresh after {} millis.", authIdentifier, delay);
        } catch (RejectedExecutionException e) {
            // The executor was shut down between the check above and the schedule call.
            fetchScheduled = false;
            log.debug("{}; Token refresh not scheduled, executor rejected it.", authIdentifier, e);
        }
    }

    /**
     * Reports whether a usable token is currently held.
     *
     * @return a valid TokenStatus when the token is present and not expired; otherwise an invalid
     *     TokenStatus carrying the last refresh error, or an UNAUTHENTICATED exception when no
     *     refresh error has been recorded.
     */
    public TokenStatus getTokenStatus() {
        if (!isTokenValid()) {
            // A recorded refresh failure takes precedence over the generic status.
            Throwable error = refreshError.get();
            if (error != null) {
                return new TokenStatus(error);
            }

            final AccessToken current = getAccessToken();
            final boolean expired = current != null && current.hasExpired();
            final Status unauthenticated =
                    expired
                            ? Status.UNAUTHENTICATED.withDescription("token has expired.")
                            : Status.UNAUTHENTICATED;
            return new TokenStatus(unauthenticated.asRuntimeException());
        }
        return new TokenStatus();
    }

    /**
     * Stops token refreshing and shuts the executor down, waiting until it has terminated.
     *
     * <p>Only the first call does any work. If the waiting thread is interrupted, the executor is
     * asked once to stop immediately, waiting continues, and the interrupt flag is restored before
     * returning.
     */
    @Override
    public void close() {
        if (!isClosed.compareAndSet(false, true)) {
            return;
        }
        if (executor.isTerminated()) {
            return;
        }

        if (refreshFuture != null) {
            refreshFuture.cancel(true);
        }
        executor.shutdown();

        boolean wasInterrupted = false;
        for (; ; ) {
            try {
                if (executor.awaitTermination(5, TimeUnit.SECONDS)) {
                    break;
                }
            } catch (InterruptedException e) {
                // Escalate to shutdownNow only on the first interrupt, then keep waiting.
                if (!wasInterrupted) {
                    executor.shutdownNow();
                    wasInterrupted = true;
                }
            }
        }

        if (wasInterrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Decodes the claims section of a dot-separated token and derives its lifetime.
     *
     * <p>The second segment is Base64-URL decoded and parsed as JSON. The TTL is the difference
     * between the numeric {@code exp} and {@code iat} claims, converted from seconds to
     * milliseconds; the local expiry is the current time plus that TTL.
     *
     * @param token the raw token string.
     * @return the parsed token with its TTL and local expiry time.
     * @throws IOException if the claims section is not valid JSON.
     * @throws IllegalArgumentException if the token has fewer than two segments, is not valid
     *     Base64, lacks numeric {@code exp}/{@code iat} claims, or has a non-positive lifetime.
     */
    private AccessToken parseToken(String token) throws IOException {
        String[] parts = token.split("\\.");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Invalid token format.");
        }
        // Hand the decoded bytes straight to the JSON parser instead of building a String with
        // the platform default charset, which would corrupt non-ASCII claims on some systems.
        byte[] claims = Base64.getUrlDecoder().decode(parts[1]);

        Map<?, ?> parsedClaims = objectMapper.readValue(claims, Map.class);
        if (parsedClaims == null) {
            // The payload was the JSON literal null.
            throw new IllegalArgumentException("Unsupported access token format.");
        }

        Object expiryToken = parsedClaims.get("exp");
        Object iat = parsedClaims.get("iat");
        if (expiryToken instanceof Number && iat instanceof Number) {
            long ttl = (((Number) expiryToken).longValue() - ((Number) iat).longValue()) * 1000;
            if (ttl <= 0) {
                throw new IllegalArgumentException("Token 'iat' > 'exp'");
            }
            // Expiry is tracked on the local clock, relative to the time of receipt.
            long expiry = System.currentTimeMillis() + ttl;
            return new AccessToken(expiry, ttl, token);
        } else {
            throw new IllegalArgumentException("Unsupported access token format.");
        }
    }

    /**
     * Returns gRPC call credentials carrying the current access token.
     *
     * @return gRPC CallCredentials built from the current token.
     * @throws StatusRuntimeException if no valid token is held: the last refresh error when it is
     *     a gRPC status exception, otherwise an UNAUTHENTICATED status describing the last error
     *     or stating that the token was not fetched or has expired.
     */
    public CallCredentials getCallCredentials() throws StatusRuntimeException {
        if (!isTokenValid()) {
            log.info("{}; Starting a call with invalid token.", authIdentifier);
            // NOTE(review): unsafeScheduleRefresh returns immediately when forceRefresh is false,
            // so this call currently schedules nothing; kept as-is to preserve behaviour.
            unsafeScheduleRefresh(0, false);
        }

        // Snapshot the token once so the validity check and the returned credentials refer to the
        // same instance.
        AccessToken token = getAccessToken();
        if (token != null && !token.hasExpired()) {
            return new BearerTokenCallCredentials(token.token);
        }

        Throwable lastError = refreshError.get();
        if (lastError instanceof StatusRuntimeException) {
            throw (StatusRuntimeException) lastError;
        }
        if (lastError instanceof StatusException) {
            // StatusException is a checked exception and cannot be cast to RuntimeException;
            // rethrow its status and trailers as the unchecked equivalent.
            StatusException checked = (StatusException) lastError;
            throw checked.getStatus()
                    .withCause(checked)
                    .asRuntimeException(checked.getTrailers());
        }
        if (lastError != null) {
            throw Status.UNAUTHENTICATED
                    .withDescription(lastError.getMessage())
                    .withCause(lastError)
                    .asRuntimeException();
        }
        throw Status.UNAUTHENTICATED
                .withDescription(
                        token == null ? "Access token not fetched." : "Access token has expired.")
                .asRuntimeException();
    }

    /** Returns true when a token is held and it has not expired; logs the outcome at debug. */
    private boolean isTokenValid() {
        final AccessToken current = getAccessToken();
        final boolean usable = !(current == null || current.hasExpired());
        log.debug("{}; tokenValid: {}, token: {}.", authIdentifier, usable, current);
        return usable;
    }

    /** Immutable holder of a fetched token together with its lifetime information. */
    private static class AccessToken {
        /** Local wall-clock time, in epoch milliseconds, after which the token is expired. */
        private final long expiry;

        /** Lifetime of the token in milliseconds. */
        private final long ttl;

        /** The raw token string sent as the bearer credential. */
        private final String token;

        public AccessToken(long expiry, long ttl, String token) {
            this.expiry = expiry;
            this.ttl = ttl;
            this.token = token;
        }

        /** Returns true once the current time is past the expiry time. */
        public boolean hasExpired() {
            return System.currentTimeMillis() > expiry;
        }

        /**
         * Describes the token for logging. The token value is a bearer credential and this object
         * is passed to log statements, so the value itself is never included.
         */
        @Override
        public String toString() {
            return "AccessToken{"
                    + "expiry="
                    + expiry
                    + ", ttl="
                    + ttl
                    + ", token='"
                    + (token == null ? "null" : "<redacted>")
                    + '\''
                    + '}';
        }
    }

    /**
     * Snapshot of the auth token state: either valid, or invalid together with the error that
     * explains why.
     */
    public static class TokenStatus {
        private final Throwable error;
        private final boolean isValid;

        /** Creates a status describing a valid token. */
        private TokenStatus() {
            this.isValid = true;
            this.error = null;
        }

        /** Creates a status describing an invalid token and the reason for it. */
        private TokenStatus(Throwable error) {
            this.isValid = false;
            this.error = error;
        }

        /**
         * Tells whether the token was valid when this status was created.
         *
         * @return true if the token was valid.
         */
        public boolean isValid() {
            return isValid;
        }

        /**
         * Returns the reason the token is not valid. Public so that code receiving this status
         * can inspect the failure.
         *
         * @return the error, or {@code null} when the status is valid.
         */
        public Throwable getError() {
            return error;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy