com.aerospike.vector.client.internal.auth.AuthTokenManager Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of avs-client-java Show documentation
Show all versions of avs-client-java Show documentation
This project includes the Java client for Aerospike Vector Search for high-performance data interactions.
The newest version!
package com.aerospike.vector.client.internal.auth;
import com.aerospike.vector.client.auth.PasswordCredentials;
import com.aerospike.vector.client.internal.ClusterTenderer;
import com.aerospike.vector.client.proto.AuthRequest;
import com.aerospike.vector.client.proto.AuthResponse;
import com.aerospike.vector.client.proto.Credentials;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.*;
import io.grpc.stub.StreamObserver;
import java.io.Closeable;
import java.io.IOException;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** An access token manager for Aerospike proxy. */
public class AuthTokenManager implements Closeable {
private static final Logger log = LoggerFactory.getLogger(AuthTokenManager.class);
private static final int REFRESH_MIN_TIME = 5000;
private static final int MAX_EXPONENTIAL_BACKOFF = 15000;
private static final float REFRESH_AFTER_FRACTION = 0.8f;
private static final ObjectMapper objectMapper = new ObjectMapper();
private final ScheduledExecutorService executor;
private final AtomicBoolean isFetchingToken = new AtomicBoolean(false);
private final AtomicBoolean isClosed = new AtomicBoolean(false);
private final AtomicInteger consecutiveRefreshErrors = new AtomicInteger(0);
private final AtomicReference refreshError = new AtomicReference<>(null);
private volatile AccessToken accessToken;
private volatile boolean fetchScheduled;
private ScheduledFuture> refreshFuture;
private final PasswordCredentials passwordCredentials;
private final ClusterTenderer clusterTenderer;
private final String authIdentifier;
/**
* AuthTokenManager constructor.
*
* @param passwordCredentials credential of user or admin.
* @param clusterTenderer internally constructed clusterTenderer object.
*/
public AuthTokenManager(
PasswordCredentials passwordCredentials, ClusterTenderer clusterTenderer) {
this.passwordCredentials = passwordCredentials;
this.clusterTenderer = clusterTenderer;
authIdentifier = String.format("avs-%s-authmanager", clusterTenderer.identifier);
this.executor =
Executors.newScheduledThreadPool(
Runtime.getRuntime().availableProcessors(),
new ThreadFactoryBuilder().setNameFormat(authIdentifier).build());
updateAccessToken(null);
// first time generate token which will be refreshed periodically
fetchToken(true);
}
// make update thread safe
private synchronized void updateAccessToken(AccessToken newToken) {
this.accessToken = newToken;
}
// read is threadsafe
private synchronized AccessToken getAccessToken() {
return accessToken;
}
private void fetchToken(boolean forceRefresh) {
fetchScheduled = false;
if (isClosed.get() || isFetchingToken.get()) {
return;
}
if (shouldRefresh(forceRefresh)) {
try {
log.info("{}; Starting token refresh.", authIdentifier);
AuthRequest aerospikeAuthRequest =
AuthRequest.newBuilder()
.setCredentials(
Credentials.newBuilder()
.setUsername(passwordCredentials.username())
.setPasswordCredentials(
com.aerospike.vector.client.proto
.PasswordCredentials.newBuilder()
.setPassword(
passwordCredentials
.password())
.build())
.build())
.build();
ManagedChannel channel;
try {
channel = clusterTenderer.getChannel();
log.info("{}; Got successfully channel for token refresh.", authIdentifier);
} catch (Exception e) {
log.error(
"{}; Error getting in getting tend channel, will referesh in 10"
+ " milliseconds.",
authIdentifier,
e);
isFetchingToken.set(false);
unsafeScheduleRefresh(10, true);
return;
}
isFetchingToken.set(true);
clusterTenderer
.getAuthStub(channel)
.withDeadline(Deadline.after(REFRESH_MIN_TIME, TimeUnit.MILLISECONDS))
.authenticate(
aerospikeAuthRequest,
new StreamObserver<>() {
@Override
public void onNext(AuthResponse aerospikeAuthResponse) {
try {
updateAccessToken(
parseToken(aerospikeAuthResponse.getToken()));
log.info(
"{}; Fetched token successfully with TTL {}ms.",
authIdentifier,
accessToken.ttl);
unsafeScheduleNextRefresh();
clearRefreshErrors();
} catch (Exception e) {
log.error(
"{}; Error in fetching token.",
authIdentifier,
e);
onFetchError(e);
}
}
@Override
public void onError(Throwable t) {
onFetchError(t);
}
@Override
public void onCompleted() {
isFetchingToken.set(false);
}
});
} catch (Exception e) {
onFetchError(e);
}
}
}
private void clearRefreshErrors() {
consecutiveRefreshErrors.set(0);
refreshError.set(null);
}
private void updateRefreshErrors(Throwable t) {
consecutiveRefreshErrors.incrementAndGet();
refreshError.set(t);
}
private void onFetchError(Throwable t) {
updateRefreshErrors(t);
Exception e = new Exception("Error fetching access token.", t);
log.error("{}; onFetchError exception.", authIdentifier, e);
unsafeScheduleNextRefresh();
isFetchingToken.set(false);
}
private boolean shouldRefresh(boolean forceRefresh) {
boolean shouldRefresh = forceRefresh || !isTokenValid();
log.debug(
"{}; shouldRefresh: {}, isTokenValid:{}.",
authIdentifier,
shouldRefresh,
isTokenValid());
return shouldRefresh;
}
private void unsafeScheduleNextRefresh() {
long ttl = getAccessToken() != null ? getAccessToken().ttl : REFRESH_MIN_TIME;
long delay = (long) Math.floor(ttl * REFRESH_AFTER_FRACTION);
if (ttl - delay < REFRESH_MIN_TIME) {
delay = ttl - REFRESH_MIN_TIME;
}
if (!isTokenValid()) {
log.warn("{}; Token not valid, setting delay to zero.", authIdentifier);
delay = 0;
}
if (delay == 0 && consecutiveRefreshErrors.get() > 0) {
delay = (long) (Math.pow(2, consecutiveRefreshErrors.get()) * 1000);
if (delay > MAX_EXPONENTIAL_BACKOFF) {
delay = MAX_EXPONENTIAL_BACKOFF;
}
if (delay < 0) {
delay = 0;
}
}
log.info("{}; delay:{}", authIdentifier, delay);
unsafeScheduleRefresh(delay, true);
}
private void unsafeScheduleRefresh(long delay, boolean forceRefresh) {
if (isClosed.get() || !forceRefresh || fetchScheduled) {
return;
}
if (!executor.isShutdown()) {
refreshFuture =
executor.schedule(() -> fetchToken(forceRefresh), delay, TimeUnit.MILLISECONDS);
fetchScheduled = true;
log.info("{}; Scheduled token refresh after {} millis.", authIdentifier, delay);
}
}
/**
* Provides auth token status
*
* @return TokenStatus hinting if it is valid ot invalid with the corresponding exception
*/
public TokenStatus getTokenStatus() {
if (isTokenValid()) {
return new TokenStatus();
}
Throwable error = refreshError.get();
if (null != error) {
return new TokenStatus(error);
}
AccessToken token = getAccessToken();
if ((token != null && token.hasExpired())) {
return new TokenStatus(
Status.UNAUTHENTICATED
.withDescription("token has expired.")
.asRuntimeException());
}
return new TokenStatus(Status.UNAUTHENTICATED.asRuntimeException());
}
@Override
public void close() {
if (isClosed.getAndSet(true)) {
return;
}
boolean terminated = executor.isTerminated();
if (!terminated) {
if (refreshFuture != null) {
refreshFuture.cancel(true);
}
executor.shutdown();
boolean interrupted = false;
while (!terminated) {
try {
terminated = executor.awaitTermination(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
if (!interrupted) {
executor.shutdownNow();
interrupted = true;
}
}
}
if (interrupted) {
Thread.currentThread().interrupt();
}
}
}
private AccessToken parseToken(String token) throws IOException {
String[] parts = token.split("\\.");
if (parts.length < 2) {
throw new IllegalArgumentException("Invalid token format.");
}
String claims = new String(Base64.getUrlDecoder().decode(parts[1]));
// Suppressing unchecked cast as we're handling a raw type from ObjectMapper
@SuppressWarnings("unchecked")
Map parsedClaims = objectMapper.readValue(claims, Map.class);
Object expiryToken = parsedClaims.get("exp");
Object iat = parsedClaims.get("iat");
if (expiryToken instanceof Number && iat instanceof Number) {
long ttl = (((Number) expiryToken).longValue() - ((Number) iat).longValue()) * 1000;
if (ttl <= 0) {
throw new IllegalArgumentException("Token 'iat' > 'exp'");
}
long expiry = System.currentTimeMillis() + ttl;
return new AccessToken(expiry, ttl, token);
} else {
throw new IllegalArgumentException("Unsupported access token format.");
}
}
/**
* Returns gRPC call credentials
*
* @return gRPC CallCredentials
* @throws StatusRuntimeException throws exception if token is unauthenticated or expired
*/
public CallCredentials getCallCredentials() throws StatusRuntimeException {
if (!isTokenValid()) {
log.info("{}; Starting a call with invalid token.", authIdentifier);
unsafeScheduleRefresh(0, false);
}
if (!isTokenValid()) {
Throwable lastError = refreshError.get();
if (lastError != null) {
if (lastError instanceof StatusRuntimeException
|| lastError instanceof StatusException) {
throw (RuntimeException) lastError;
} else {
throw Status.UNAUTHENTICATED
.withDescription(lastError.getMessage())
.asRuntimeException();
}
} else {
throw Status.UNAUTHENTICATED
.withDescription(
getAccessToken() == null
? "Access token not fetched."
: "Access token has expired.")
.asRuntimeException();
}
}
if (getAccessToken() != null) {
return new BearerTokenCallCredentials(getAccessToken().token);
}
throw new IllegalStateException("Access token has expired.");
}
private boolean isTokenValid() {
AccessToken token = getAccessToken();
boolean tokenValid = token != null && !token.hasExpired();
log.debug("{}; tokenValid: {}, token: {}.", authIdentifier, tokenValid, token);
return tokenValid;
}
private static class AccessToken {
private final long expiry;
private final long ttl;
private final String token;
public AccessToken(long expiry, long ttl, String token) {
this.expiry = expiry;
this.ttl = ttl;
this.token = token;
}
public boolean hasExpired() {
boolean hasExpired = System.currentTimeMillis() > expiry;
return hasExpired;
}
@Override
public String toString() {
return "AccessToken{"
+ "expiry="
+ expiry
+ ", ttl="
+ ttl
+ ", token='"
+ token
+ '\''
+ '}';
}
}
/** Internal class maintains auth token status */
public static class TokenStatus {
private final Throwable error;
private final boolean isValid;
private TokenStatus() {
this.isValid = true;
this.error = null;
}
private TokenStatus(Throwable error) {
this.isValid = false;
this.error = error;
}
public boolean isValid() {
return isValid;
}
private Throwable getError() {
return error;
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy