// com.uid2.shared.auth.AuthorizableStore Maven / Gradle / Ivy
package com.uid2.shared.auth;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.uid2.shared.secret.KeyHasher;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
public class AuthorizableStore {
private static final Logger LOGGER = LoggerFactory.getLogger(AuthorizableStore.class);
private static final Pattern KEY_PATTERN = Pattern.compile("(?:UID2|EUID)-[CO]-[LTIP]-([0-9]+)-.{6}\\..{38}");
private static final KeyHasher KEY_HASHER = new KeyHasher();
private static final int CACHE_MAX_SIZE = 100_000;
private final AtomicReference authorizables;
private final Cache keyToHashCache;
private final Counter keyToHashTotalCounter;
private final Counter keyToHashMissCounter;
public AuthorizableStore(Class cls) {
this.authorizables = new AtomicReference<>(new AuthorizableStoreSnapshot(new ArrayList<>()));
this.keyToHashCache = createCache();
String cacheName = cls.getName().toLowerCase();
keyToHashTotalCounter = Counter.builder("uid2.cache.total_count")
.description("counter for " + cacheName + " cache total count")
.tag("cache", cacheName)
.register(Metrics.globalRegistry);
keyToHashMissCounter = Counter.builder("uid2.cache.miss_count")
.description("counter for " + cacheName + " cache miss count")
.tag("cache", cacheName)
.register(Metrics.globalRegistry);
}
public void refresh(Collection authorizablesToRefresh) {
authorizables.set(new AuthorizableStoreSnapshot(authorizablesToRefresh));
invalidateInvalidKeys();
}
public T getAuthorizableByKey(String key) {
if (key == null) {
return null;
}
AuthorizableStoreSnapshot latest = authorizables.get();
String cachedHash = keyToHashCache.getIfPresent(key);
keyToHashTotalCounter.increment();
if (cachedHash != null) {
return cachedHash.isBlank() ? null : latest.getAuthorizableByHash(wrapHashToByteBuffer(cachedHash));
} else {
keyToHashMissCounter.increment();
}
Integer siteId = getSiteIdFromKey(key);
List salts = siteId == null ? latest.getSalts() : latest.getSaltsBySiteId(siteId);
T authorizable = getAuthorizable(key, salts, latest);
keyToHashCache.put(key, authorizable == null ? "" : authorizable.getKeyHash());
return authorizable;
}
public T getAuthorizableByHash(String hash) {
if (hash == null) {
return null;
}
ByteBuffer hashBytes = wrapHashToByteBuffer(hash);
if (hashBytes == null) {
return null;
}
AuthorizableStoreSnapshot latest = authorizables.get();
return latest.getAuthorizableByHash(hashBytes);
}
private T getAuthorizable(String key, List salts, AuthorizableStoreSnapshot snapshot) {
for (byte[] salt : salts) {
byte[] keyHash = KEY_HASHER.hashKey(key, salt);
T authorizable = snapshot.getAuthorizableByHash(ByteBuffer.wrap(keyHash));
if (authorizable != null) {
return authorizable;
}
}
return null;
}
private Cache createCache() {
return Caffeine.newBuilder()
.maximumSize(CACHE_MAX_SIZE)
.build();
}
private void invalidateInvalidKeys() {
List invalidKeys = keyToHashCache.asMap()
.entrySet()
.stream()
.filter(entry -> entry.getValue().isBlank())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
invalidKeys.forEach(keyToHashCache::invalidate);
}
private ByteBuffer wrapHashToByteBuffer(String hash) {
byte[] hashBytes = convertBase64StringToBytes(hash);
return hashBytes == null ? null : ByteBuffer.wrap(hashBytes);
}
private byte[] convertBase64StringToBytes(String str) {
try {
return Base64.getDecoder().decode(str);
} catch (IllegalArgumentException e) {
LOGGER.error("Invalid base64 string: {}", str);
return null;
}
}
private static Integer getSiteIdFromKey(String key) {
Matcher matcher = KEY_PATTERN.matcher(key);
if (matcher.find()) {
return Integer.valueOf(matcher.group(1));
} else {
return null;
}
}
private class AuthorizableStoreSnapshot {
private final Map hashToAuthorizableMap;
private final Map> siteIdToSaltsMap;
private final List salts;
public AuthorizableStoreSnapshot(Collection authorizables) {
this.hashToAuthorizableMap = authorizables.stream()
.collect(Collectors.toMap(
a -> wrapHashToByteBuffer(a.getKeyHash()),
a -> a
));
this.siteIdToSaltsMap = authorizables.stream()
.filter(a -> a.getSiteId() != null)
.collect(Collectors.groupingBy(
IAuthorizable::getSiteId,
Collectors.mapping(a -> convertBase64StringToBytes(a.getKeySalt()), Collectors.toList())
));
this.salts = authorizables.stream()
.map(a -> convertBase64StringToBytes(a.getKeySalt()))
.collect(Collectors.toList());
}
public T getAuthorizableByHash(ByteBuffer hashBytes) {
return hashToAuthorizableMap.get(hashBytes);
}
public List getSaltsBySiteId(int siteId) {
return siteIdToSaltsMap.getOrDefault(siteId, List.of());
}
public List getSalts() {
return salts;
}
}
public Collection getAuthorizables() {
return authorizables.get().hashToAuthorizableMap.values();
}
}