
com.aerospike.vector.client.internal.auth.AuthTokenManager Maven / Gradle / Ivy
package com.aerospike.vector.client.internal.auth;
import com.aerospike.vector.client.auth.PasswordCredentials;
import com.aerospike.vector.client.internal.ClusterTenderer;
import com.aerospike.vector.client.proto.AuthRequest;
import com.aerospike.vector.client.proto.AuthResponse;
import com.aerospike.vector.client.proto.Credentials;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.*;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
* An access token manager for Aerospike proxy.
*/
public class AuthTokenManager implements Closeable {
private static final Logger log = LoggerFactory.getLogger(AuthTokenManager.class);
private static final int REFRESH_MIN_TIME = 5000;
private static final int MAX_EXPONENTIAL_BACKOFF = 15000;
private static final float REFRESH_AFTER_FRACTION = 0.8f;
private static final ObjectMapper objectMapper = new ObjectMapper();
private final ScheduledExecutorService executor;
private final AtomicBoolean isFetchingToken = new AtomicBoolean(false);
private final AtomicBoolean isClosed = new AtomicBoolean(false);
private final AtomicInteger consecutiveRefreshErrors = new AtomicInteger(0);
private final AtomicReference refreshError = new AtomicReference<>(null);
private volatile AccessToken accessToken;
private volatile boolean fetchScheduled;
private ScheduledFuture> refreshFuture;
private final PasswordCredentials passwordCredentials;
private final ClusterTenderer clusterTenderer;
private final String authIdentifier;
/**
* AuthTokenManager constructor
* @param passwordCredentials credential of user or admin
* @param clusterTenderer internally constructed clusterTenderer object
*/
public AuthTokenManager(PasswordCredentials passwordCredentials, ClusterTenderer clusterTenderer) {
this.passwordCredentials = passwordCredentials;
this.clusterTenderer = clusterTenderer;
authIdentifier = String.format("avs-%s-authmanager", clusterTenderer.identifier);
this.executor = Executors.newScheduledThreadPool( Runtime.getRuntime().availableProcessors(),
new ThreadFactoryBuilder().setNameFormat(authIdentifier).build());
updateAccessToken(null);
//first time generate token which will be refreshed periodically
fetchToken(true);
}
//make update thread safe
private synchronized void updateAccessToken(AccessToken newToken) {
this.accessToken = newToken;
}
//read is threadsafe
private synchronized AccessToken getAccessToken() {
return accessToken;
}
private void fetchToken(boolean forceRefresh) {
fetchScheduled = false;
if (isClosed.get() || isFetchingToken.get()) {
return;
}
if (shouldRefresh(forceRefresh)) {
try {
log.info("{}; Starting token refresh", authIdentifier);
AuthRequest aerospikeAuthRequest = AuthRequest.newBuilder()
.setCredentials(Credentials.newBuilder()
.setUsername(passwordCredentials.username())
.setPasswordCredentials(com.aerospike.vector.client.proto.PasswordCredentials.newBuilder()
.setPassword(passwordCredentials.password()).build())
.build())
.build();
ManagedChannel channel;
try {
channel = clusterTenderer.getChannel();
log.info("{}; Got successfully channel for token refresh", authIdentifier);
} catch (Exception e) {
log.error("{}; Error getting in getting tend channel, will referesh in 10 milliseconds", authIdentifier, e);
isFetchingToken.set(false);
unsafeScheduleRefresh(10, true);
return;
}
isFetchingToken.set(true);
clusterTenderer.getAuthStub(channel).withDeadline(
Deadline.after(REFRESH_MIN_TIME, TimeUnit.MILLISECONDS))
.authenticate(aerospikeAuthRequest, new StreamObserver<>() {
@Override
public void onNext(AuthResponse aerospikeAuthResponse) {
try {
updateAccessToken(parseToken(aerospikeAuthResponse.getToken()));
log.info("{}; Fetched token successfully with TTL {}ms", authIdentifier, accessToken.ttl);
unsafeScheduleNextRefresh();
clearRefreshErrors();
} catch (Exception e) {
log.error("{}; Error in fetching token", authIdentifier, e);
onFetchError(e);
}
}
@Override
public void onError(Throwable t) {
onFetchError(t);
}
@Override
public void onCompleted() {
isFetchingToken.set(false);
}
});
} catch (Exception e) {
onFetchError(e);
}
}
}
private void clearRefreshErrors() {
consecutiveRefreshErrors.set(0);
refreshError.set(null);
}
private void updateRefreshErrors(Throwable t) {
consecutiveRefreshErrors.incrementAndGet();
refreshError.set(t);
}
private void onFetchError(Throwable t) {
updateRefreshErrors(t);
Exception e = new Exception("Error fetching access token", t);
log.error("{}; onFetchError exception",authIdentifier,e);
unsafeScheduleNextRefresh();
isFetchingToken.set(false);
}
private boolean shouldRefresh(boolean forceRefresh) {
boolean shouldRefresh = forceRefresh || !isTokenValid();
log.debug("{}; shouldRefresh: {}, isTokenValid:{}", authIdentifier, shouldRefresh, isTokenValid());
return shouldRefresh;
}
private void unsafeScheduleNextRefresh() {
long ttl = getAccessToken() != null ? getAccessToken().ttl : REFRESH_MIN_TIME;
long delay = (long) Math.floor(ttl * REFRESH_AFTER_FRACTION);
if (ttl - delay < REFRESH_MIN_TIME) {
delay = ttl - REFRESH_MIN_TIME;
}
if (!isTokenValid()) {
log.warn("{}; Token not valid, setting delay to zero!", authIdentifier);
delay = 0;
}
if (delay == 0 && consecutiveRefreshErrors.get() > 0) {
delay = (long) (Math.pow(2, consecutiveRefreshErrors.get()) * 1000);
if (delay > MAX_EXPONENTIAL_BACKOFF) {
delay = MAX_EXPONENTIAL_BACKOFF;
}
if (delay < 0) {
delay = 0;
}
}
log.info("{}; delay:{}", authIdentifier, delay);
unsafeScheduleRefresh(delay, true);
}
private void unsafeScheduleRefresh(long delay, boolean forceRefresh) {
if (isClosed.get() || !forceRefresh || fetchScheduled) {
return;
}
if (!executor.isShutdown()) {
refreshFuture = executor.schedule(() -> fetchToken(forceRefresh), delay, TimeUnit.MILLISECONDS);
fetchScheduled = true;
log.info("{}; Scheduled token refresh after {} millis", authIdentifier, delay);
}
}
/**
* Provides auth token status
* @return TokenStatus hinting if it is valid ot invalid with the corresponding exception
*/
public TokenStatus getTokenStatus() {
if(isTokenValid()) {
return new TokenStatus();
}
Throwable error = refreshError.get();
if(null != error) {
return new TokenStatus(error);
}
AccessToken token = getAccessToken();
if((token != null && token.hasExpired())) {
return new TokenStatus(
Status.UNAUTHENTICATED.withDescription("token has expired").asRuntimeException());
}
return new TokenStatus(Status.UNAUTHENTICATED.asRuntimeException());
}
@Override
public void close() {
if (isClosed.getAndSet(true)) {
return;
}
boolean terminated = executor.isTerminated();
if (!terminated) {
if (refreshFuture != null) {
refreshFuture.cancel(true);
}
executor.shutdown();
boolean interrupted = false;
while (!terminated) {
try {
terminated = executor.awaitTermination(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
if (!interrupted) {
executor.shutdownNow();
interrupted = true;
}
}
}
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}
private AccessToken parseToken(String token) throws IOException {
String[] parts = token.split("\\.");
if (parts.length < 2) {
throw new IllegalArgumentException("Invalid token format.");
}
String claims = new String(Base64.getUrlDecoder().decode(parts[1]));
// Suppressing unchecked cast as we're handling a raw type from ObjectMapper
@SuppressWarnings("unchecked")
Map parsedClaims = objectMapper.readValue(claims, Map.class);
Object expiryToken = parsedClaims.get("exp");
Object iat = parsedClaims.get("iat");
if (expiryToken instanceof Number && iat instanceof Number) {
long ttl = (((Number) expiryToken).longValue() - ((Number) iat).longValue()) * 1000;
if (ttl <= 0) {
throw new IllegalArgumentException("Token 'iat' > 'exp'");
}
long expiry = System.currentTimeMillis() + ttl;
return new AccessToken(expiry, ttl, token);
} else {
throw new IllegalArgumentException("Unsupported access token format");
}
}
/**
* Returns gRPC call credentials
* @return gRPC CallCredentials
* @throws StatusRuntimeException throws exception if token is unauthenticated or expired
*/
public CallCredentials getCallCredentials() throws StatusRuntimeException {
if (!isTokenValid()) {
log.info("{}; Starting a call with invalid token", authIdentifier);
unsafeScheduleRefresh(0, false);
}
if (!isTokenValid()) {
Throwable lastError = refreshError.get();
if (lastError != null) {
if (lastError instanceof StatusRuntimeException || lastError instanceof StatusException) {
throw (RuntimeException) lastError;
} else {
throw Status.UNAUTHENTICATED.withDescription(lastError.getMessage()).asRuntimeException();
}
} else {
throw Status.UNAUTHENTICATED.withDescription(getAccessToken() == null ? "Access token not fetched" : "Access token has expired").asRuntimeException();
}
}
if(getAccessToken() != null){
return new BearerTokenCallCredentials(getAccessToken().token);
}
throw new IllegalStateException("Access token has expired");
}
private boolean isTokenValid() {
AccessToken token = getAccessToken();
boolean tokenValid = token != null && !token.hasExpired();
log.debug("{}; tokenValid: {}, token: {}", authIdentifier, tokenValid, token);
return tokenValid;
}
private static class AccessToken {
private final long expiry;
private final long ttl;
private final String token;
public AccessToken(long expiry, long ttl, String token) {
this.expiry = expiry;
this.ttl = ttl;
this.token = token;
}
public boolean hasExpired() {
boolean hasExpired = System.currentTimeMillis() > expiry;
return hasExpired;
}
@Override
public String toString() {
return "AccessToken{" +
"expiry=" + expiry +
", ttl=" + ttl +
", token='" + token + '\'' +
'}';
}
}
/**
* Internal class maintains auth token status
*/
public static class TokenStatus {
private final Throwable error;
private final boolean isValid;
private TokenStatus() {
this.isValid = true;
this.error = null;
}
private TokenStatus(Throwable error) {
this.isValid = false;
this.error = error;
}
public boolean isValid() {
return isValid;
}
private Throwable getError() {
return error;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy