// com.microsoft.aad.msal4j.TokenCache (msal4j artifact; Maven / Gradle / Ivy)
// Source listing page text: "Show all versions of msal4j", "Show documentation"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.microsoft.aad.msal4j;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.nimbusds.jwt.JWTParser;
import java.io.IOException;
import java.text.ParseException;
import java.util.*;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
* Cache used for storing tokens. For more details, see https://aka.ms/msal4j-token-cache
*
* Conditionally thread-safe
*/
public class TokenCache implements ITokenCache {
/**
 * Minimum remaining lifetime, in seconds (5 minutes), a cached access token must have
 * for the access token lookups in this class to return it.
 */
protected static final int MIN_ACCESS_TOKEN_EXPIRE_IN_SEC = 5 * 60;
/**
 * Guards the entity maps of this cache: methods in this class take the read lock for
 * lookups and serialization and the write lock for mutations.
 */
transient private ReadWriteLock lock = new ReentrantReadWriteLock();
/**
 * Creates a token cache whose accesses are reported to the given aspect.
 * The aspect's {@code beforeCacheAccess}/{@code afterCacheAccess} callbacks are invoked
 * around each cache operation performed through this class.
 *
 * @param tokenCacheAccessAspect {@link ITokenCacheAccessAspect} to notify; may be
 *                               {@code null}, in which case no callbacks are made
 */
public TokenCache(ITokenCacheAccessAspect tokenCacheAccessAspect) {
this();
this.tokenCacheAccessAspect = tokenCacheAccessAspect;
}
/**
 * Creates an empty token cache with no access aspect.
 */
public TokenCache() {
}
/** Cached access tokens keyed by entity key; serialized under "AccessToken". */
@JsonProperty("AccessToken")
Map<String, AccessTokenCacheEntity> accessTokens = new LinkedHashMap<>();

/** Cached refresh tokens keyed by entity key; serialized under "RefreshToken". */
@JsonProperty("RefreshToken")
Map<String, RefreshTokenCacheEntity> refreshTokens = new LinkedHashMap<>();

/** Cached ID tokens keyed by entity key; serialized under "IdToken". */
@JsonProperty("IdToken")
Map<String, IdTokenCacheEntity> idTokens = new LinkedHashMap<>();

/** Cached accounts keyed by entity key; serialized under "Account". */
@JsonProperty("Account")
Map<String, AccountCacheEntity> accounts = new LinkedHashMap<>();

/** Cached application metadata keyed by entity key; serialized under "AppMetadata". */
@JsonProperty("AppMetadata")
Map<String, AppMetadataCacheEntity> appMetadata = new LinkedHashMap<>();

/** Optional callback notified before and after each cache access; {@code null} when unset. */
transient ITokenCacheAccessAspect tokenCacheAccessAspect;

/** Last JSON payload passed to {@link #deserialize(String)}; used as the merge base in {@link #serialize()}. */
private transient String serializedCachedSnapshot;
/**
 * Replaces the in-memory cache contents with the entities parsed from {@code data}
 * and remembers {@code data} as the snapshot that {@link #serialize()} merges into.
 * Blank input is ignored.
 *
 * @param data serialized cache JSON
 */
@Override
public void deserialize(String data) {
    if (StringHelper.isBlank(data)) {
        return;
    }
    // Parse before touching any state: should parsing throw, the snapshot and the
    // entity maps stay consistent with each other.
    TokenCache deserializedCache = JsonHelper.convertJsonToObject(data, TokenCache.class);
    lock.writeLock().lock();
    try {
        // serialize() reads the snapshot under the read lock, so it has to be swapped
        // under the write lock together with the maps it corresponds to.
        this.serializedCachedSnapshot = data;
        this.accounts = deserializedCache.accounts;
        this.accessTokens = deserializedCache.accessTokens;
        this.refreshTokens = deserializedCache.refreshTokens;
        this.idTokens = deserializedCache.idTokens;
        this.appMetadata = deserializedCache.appMetadata;
    } finally {
        lock.writeLock().unlock();
    }
}
/**
 * Brings the {@code old} JSON tree in line with {@code update}, mutating {@code old} in place.
 *
 * @param old    tree to modify
 * @param update tree holding the current state
 */
private static void mergeJsonObjects(JsonNode old, JsonNode update) {
// Drop entries that no longer exist in the update first, then add/overwrite the rest.
mergeRemovals(old, update);
mergeUpdates(old, update);
}
/**
 * Recursively copies the properties of {@code update} into {@code old}, mutating {@code old}.
 * Properties missing from {@code old} are added unless their value is JSON null or an
 * empty object. Properties present in both are merged recursively when both values are
 * objects, and overwritten by the update's value otherwise.
 *
 * @param old    object node to modify
 * @param update object node holding the new values
 */
private static void mergeUpdates(JsonNode old, JsonNode update) {
    Iterator<String> fieldNames = update.fieldNames();
    while (fieldNames.hasNext()) {
        String uKey = fieldNames.next();
        JsonNode uValue = update.get(uKey);

        // add new property
        if (!old.has(uKey)) {
            boolean isEmptyObject = uValue.isObject() && uValue.size() == 0;
            if (!uValue.isNull() && !isEmptyObject) {
                ((ObjectNode) old).set(uKey, uValue);
            }
            continue;
        }

        // merge old and new property
        JsonNode oValue = old.get(uKey);
        if (uValue.isObject() && oValue.isObject()) {
            mergeUpdates(oValue, uValue);
        } else {
            // Either a scalar update, or the old value is not an object and therefore
            // cannot be merged into (recursing would fail on the ObjectNode cast).
            ((ObjectNode) old).set(uKey, uValue);
        }
    }
}
/**
 * Removes from each MSAL section of {@code old} ("Account", "AccessToken", "RefreshToken",
 * "IdToken", "AppMetadata") every entry whose key is absent from the same section of
 * {@code update}. Sections missing from {@code old} are skipped; a section missing from
 * {@code update} causes all of the old section's entries to be removed.
 *
 * @param old    tree to prune in place
 * @param update tree holding the current state
 */
private static void mergeRemovals(JsonNode old, JsonNode update) {
    List<String> msalEntities =
            Arrays.asList("Account", "AccessToken", "RefreshToken", "IdToken", "AppMetadata");

    for (String msalEntity : msalEntities) {
        JsonNode oldEntries = old.get(msalEntity);
        if (oldEntries == null) {
            continue;
        }
        JsonNode newEntries = update.get(msalEntity);

        Iterator<Map.Entry<String, JsonNode>> iterator = oldEntries.fields();
        while (iterator.hasNext()) {
            String key = iterator.next().getKey();
            if (newEntries == null || !newEntries.has(key)) {
                // Remove through the iterator so the underlying node is modified safely.
                iterator.remove();
            }
        }
    }
}
/**
 * Serializes the cache to JSON. When a snapshot from a previous
 * {@link #deserialize(String)} call exists, the current state is merged into that
 * snapshot and the merged tree is written; otherwise this object is written directly.
 *
 * @return JSON representation of the cache
 * @throws MsalClientException wrapping any {@link IOException} raised while reading or writing JSON
 */
@Override
public String serialize() {
    lock.readLock().lock();
    try {
        if (StringHelper.isBlank(serializedCachedSnapshot)) {
            // Nothing to merge into: write the in-memory state as is.
            return JsonHelper.mapper.writeValueAsString(this);
        }

        JsonNode snapshotTree = JsonHelper.mapper.readTree(serializedCachedSnapshot);
        JsonNode currentTree = JsonHelper.mapper.valueToTree(this);
        mergeJsonObjects(snapshotTree, currentTree);

        return JsonHelper.mapper.writeValueAsString(snapshotTree);
    } catch (IOException e) {
        throw new MsalClientException(e);
    } finally {
        lock.readLock().unlock();
    }
}
/**
 * Try-with-resources helper that brackets a cache operation with the configured
 * aspect's {@code beforeCacheAccess} (on construction) and {@code afterCacheAccess}
 * (on close) callbacks. Does nothing when no aspect is configured.
 */
private class CacheAspect implements AutoCloseable {
    ITokenCacheAccessContext context;

    CacheAspect(ITokenCacheAccessContext context) {
        if (tokenCacheAccessAspect == null) {
            return;
        }
        this.context = context;
        tokenCacheAccessAspect.beforeCacheAccess(context);
    }

    @Override
    public void close() {
        if (tokenCacheAccessAspect == null) {
            return;
        }
        tokenCacheAccessAspect.afterCacheAccess(context);
    }
}
/**
 * Stores the tokens carried by {@code authenticationResult} in the cache. Each of the
 * access token, refresh token and ID token is saved only when it is non-blank; app
 * metadata is saved when the result has a non-blank family id; the account entity is
 * saved alongside the ID token when the result provides one.
 * The access aspect is notified with {@code hasCacheChanged(true)}.
 *
 * @param tokenRequestExecutor executor of the request that produced the result
 * @param authenticationResult result whose tokens are cached
 * @param environment          environment to record on the cached entities
 */
void saveTokens(TokenRequestExecutor tokenRequestExecutor, AuthenticationResult authenticationResult, String environment) {
    try (CacheAspect cacheAspect = new CacheAspect(
            TokenCacheAccessContext.builder().
                    clientId(tokenRequestExecutor.getMsalRequest().application().clientId()).
                    tokenCache(this).
                    hasCacheChanged(true).build())) {
        // Acquire before entering try: if lock() fails, finally must not call unlock()
        // on a lock this thread does not hold.
        lock.writeLock().lock();
        try {
            if (!StringHelper.isBlank(authenticationResult.accessToken())) {
                AccessTokenCacheEntity atEntity = createAccessTokenCacheEntity
                        (tokenRequestExecutor, authenticationResult, environment);
                accessTokens.put(atEntity.getKey(), atEntity);
            }
            if (!StringHelper.isBlank(authenticationResult.familyId())) {
                AppMetadataCacheEntity appMetadataCacheEntity =
                        createAppMetadataCacheEntity(tokenRequestExecutor, authenticationResult, environment);
                appMetadata.put(appMetadataCacheEntity.getKey(), appMetadataCacheEntity);
            }
            if (!StringHelper.isBlank(authenticationResult.refreshToken())) {
                RefreshTokenCacheEntity rtEntity = createRefreshTokenCacheEntity
                        (tokenRequestExecutor, authenticationResult, environment);
                rtEntity.family_id(authenticationResult.familyId());
                refreshTokens.put(rtEntity.getKey(), rtEntity);
            }
            if (!StringHelper.isBlank(authenticationResult.idToken())) {
                IdTokenCacheEntity idTokenEntity = createIdTokenCacheEntity
                        (tokenRequestExecutor, authenticationResult, environment);
                idTokens.put(idTokenEntity.getKey(), idTokenEntity);

                AccountCacheEntity accountCacheEntity = authenticationResult.accountCacheEntity();
                if (accountCacheEntity != null) {
                    accountCacheEntity.environment(environment);
                    accounts.put(accountCacheEntity.getKey(), accountCacheEntity);
                }
            }
        } finally {
            lock.writeLock().unlock();
        }
    }
}
/**
 * Builds the refresh token cache entity for an authentication result.
 * The home account id is set only when the result has an account, and the user
 * assertion hash only when the request is an {@link OnBehalfOfRequest}.
 *
 * @param tokenRequestExecutor executor of the request that produced the result
 * @param authenticationResult result carrying the refresh token
 * @param environmentAlias     environment to record on the entity
 * @return populated refresh token entity
 */
private static RefreshTokenCacheEntity createRefreshTokenCacheEntity(TokenRequestExecutor tokenRequestExecutor,
                                                                     AuthenticationResult authenticationResult,
                                                                     String environmentAlias) {
    RefreshTokenCacheEntity entity = new RefreshTokenCacheEntity();
    entity.credentialType(CredentialTypeEnum.REFRESH_TOKEN.value());

    if (authenticationResult.account() != null) {
        entity.homeAccountId(authenticationResult.account().homeAccountId());
    }

    entity.environment(environmentAlias);
    entity.clientId(tokenRequestExecutor.getMsalRequest().application().clientId());
    entity.secret(authenticationResult.refreshToken());

    if (tokenRequestExecutor.getMsalRequest() instanceof OnBehalfOfRequest) {
        OnBehalfOfRequest oboRequest = (OnBehalfOfRequest) tokenRequestExecutor.getMsalRequest();
        entity.userAssertionHash(oboRequest.parameters.userAssertion().getAssertionHash());
    }

    return entity;
}
/**
 * Builds the access token cache entity for an authentication result.
 * The target scopes come from the result when it reports any, otherwise from the
 * request's authorization grant. {@code cachedAt} is the current time in seconds;
 * {@code refreshOn} and {@code extExpiresOn} are recorded only when positive.
 *
 * @param tokenRequestExecutor executor of the request that produced the result
 * @param authenticationResult result carrying the access token
 * @param environmentAlias     environment to record on the entity
 * @return populated access token entity
 */
private static AccessTokenCacheEntity createAccessTokenCacheEntity(TokenRequestExecutor tokenRequestExecutor,
                                                                   AuthenticationResult authenticationResult,
                                                                   String environmentAlias) {
    AccessTokenCacheEntity entity = new AccessTokenCacheEntity();
    entity.credentialType(CredentialTypeEnum.ACCESS_TOKEN.value());

    if (authenticationResult.account() != null) {
        entity.homeAccountId(authenticationResult.account().homeAccountId());
    }

    entity.environment(environmentAlias);
    entity.clientId(tokenRequestExecutor.getMsalRequest().application().clientId());
    entity.secret(authenticationResult.accessToken());
    entity.realm(tokenRequestExecutor.tenant);

    // Prefer the scopes reported by the result; fall back to the requested ones.
    String target;
    if (StringHelper.isBlank(authenticationResult.scopes())) {
        target = tokenRequestExecutor.getMsalRequest().msalAuthorizationGrant().getScopes();
    } else {
        target = authenticationResult.scopes();
    }
    entity.target(target);

    if (tokenRequestExecutor.getMsalRequest() instanceof OnBehalfOfRequest) {
        OnBehalfOfRequest oboRequest = (OnBehalfOfRequest) tokenRequestExecutor.getMsalRequest();
        entity.userAssertionHash(oboRequest.parameters.userAssertion().getAssertionHash());
    }

    long nowSec = System.currentTimeMillis() / 1000;
    entity.cachedAt(Long.toString(nowSec));
    entity.expiresOn(Long.toString(authenticationResult.expiresOn()));

    if (authenticationResult.refreshOn() > 0) {
        entity.refreshOn(Long.toString(authenticationResult.refreshOn()));
    }
    if (authenticationResult.extExpiresOn() > 0) {
        entity.extExpiresOn(Long.toString(authenticationResult.extExpiresOn()));
    }

    return entity;
}
/**
 * Builds the ID token cache entity for an authentication result.
 * The home account id is set only when the result has an account, and the user
 * assertion hash only when the request is an {@link OnBehalfOfRequest}.
 *
 * @param tokenRequestExecutor executor of the request that produced the result
 * @param authenticationResult result carrying the ID token
 * @param environmentAlias     environment to record on the entity
 * @return populated ID token entity
 */
private static IdTokenCacheEntity createIdTokenCacheEntity(TokenRequestExecutor tokenRequestExecutor,
                                                           AuthenticationResult authenticationResult,
                                                           String environmentAlias) {
    IdTokenCacheEntity entity = new IdTokenCacheEntity();
    entity.credentialType(CredentialTypeEnum.ID_TOKEN.value());

    if (authenticationResult.account() != null) {
        entity.homeAccountId(authenticationResult.account().homeAccountId());
    }

    entity.environment(environmentAlias);
    entity.clientId(tokenRequestExecutor.getMsalRequest().application().clientId());
    entity.secret(authenticationResult.idToken());
    entity.realm(tokenRequestExecutor.tenant);

    if (tokenRequestExecutor.getMsalRequest() instanceof OnBehalfOfRequest) {
        OnBehalfOfRequest oboRequest = (OnBehalfOfRequest) tokenRequestExecutor.getMsalRequest();
        entity.userAssertionHash(oboRequest.parameters.userAssertion().getAssertionHash());
    }

    return entity;
}
/**
 * Builds the application metadata cache entity (client id, environment, family id)
 * for an authentication result.
 *
 * @param tokenRequestExecutor executor of the request that produced the result
 * @param authenticationResult result carrying the family id
 * @param environmentAlias     environment to record on the entity
 * @return populated app metadata entity
 */
private static AppMetadataCacheEntity createAppMetadataCacheEntity(TokenRequestExecutor tokenRequestExecutor,
                                                                   AuthenticationResult authenticationResult,
                                                                   String environmentAlias) {
    AppMetadataCacheEntity metadata = new AppMetadataCacheEntity();

    metadata.clientId(tokenRequestExecutor.getMsalRequest().application().clientId());
    metadata.environment(environmentAlias);
    metadata.familyId(authenticationResult.familyId());

    return metadata;
}
/**
 * Returns the cached accounts, grouped by home account id. Every cached account entity
 * with the same home account id contributes to one {@link IAccount}: when a matching
 * cached ID token exists for the given client, its claims become a tenant profile keyed
 * by the entity's realm; the username is taken from the entity whose local account id
 * is contained in the home account id.
 *
 * @param clientId client id used to look up the ID tokens backing the tenant profiles
 * @return set of accounts found in the cache, possibly empty
 * @throws MsalClientException if a cached ID token cannot be parsed as a JWT
 */
Set<IAccount> getAccounts(String clientId) {
    try (CacheAspect cacheAspect = new CacheAspect(
            TokenCacheAccessContext.builder().
                    clientId(clientId).
                    tokenCache(this).
                    build())) {
        // Acquire before entering try so a failed lock() never reaches unlock().
        lock.readLock().lock();
        try {
            Map<String, IAccount> rootAccounts = new HashMap<>();

            for (AccountCacheEntity accCached : accounts.values()) {
                IdTokenCacheEntity idToken = idTokens.get(getIdTokenKey(
                        accCached.homeAccountId(),
                        accCached.environment(),
                        clientId,
                        accCached.realm()));

                ITenantProfile profile = null;
                if (idToken != null) {
                    Map<String, Object> idTokenClaims =
                            JWTParser.parse(idToken.secret()).getJWTClaimsSet().getClaims();
                    profile = new TenantProfile(idTokenClaims, accCached.environment());
                }

                // One root account per home account id; created on first sight.
                Account rootAccount = (Account) rootAccounts.get(accCached.homeAccountId());
                if (rootAccount == null) {
                    rootAccount = (Account) accCached.toAccount();
                    rootAccount.tenantProfiles = new HashMap<>();
                    rootAccounts.put(accCached.homeAccountId(), rootAccount);
                }

                if (profile != null) {
                    rootAccount.tenantProfiles.put(accCached.realm(), profile);
                }

                if (accCached.localAccountId() != null
                        && accCached.homeAccountId().contains(accCached.localAccountId())) {
                    rootAccount.username(accCached.username());
                }
            }

            return new HashSet<>(rootAccounts.values());
        } catch (ParseException e) {
            throw new MsalClientException("Cached JWT could not be parsed: " + e.getMessage(), AuthenticationErrorCode.INVALID_JWT);
        } finally {
            lock.readLock().unlock();
        }
    }
}
/**
 * Builds the lookup key of a cached ID token, formatted in the same way as {@link IdTokenCacheEntity#getKey}:
 * home account id, environment, the literal "idtoken", client id, realm and an empty
 * trailing segment, joined by {@link Constants#CACHE_KEY_SEPARATOR} and lower-cased.
 *
 * @return String representing a possible key of a cached ID token
 */
private String getIdTokenKey(String homeAccountId, String environment, String clientId, String realm) {
    List<String> keySegments = Arrays.asList(
            homeAccountId,
            environment,
            "idtoken",
            clientId,
            realm,
            "");
    String joinedKey = String.join(Constants.CACHE_KEY_SEPARATOR, keySegments);
    return joinedKey.toLowerCase();
}
/**
 * Looks up the family id recorded in the cached app metadata for a client.
 *
 * @param clientId           client id of the application
 * @param environmentAliases environments considered equivalent for the lookup
 * @return the first non-blank family id found for the client in one of the given
 * environments, or {@code null} when there is none
 */
private String getApplicationFamilyId(String clientId, Set<String> environmentAliases) {
    for (AppMetadataCacheEntity data : appMetadata.values()) {
        if (data.clientId().equals(clientId) &&
                environmentAliases.contains(data.environment()) &&
                !StringHelper.isBlank(data.familyId())) {
            return data.familyId();
        }
    }
    return null;
}
/**
 * Collects the client ids of all cached app metadata entries that belong to a family.
 *
 * @param familyId           family id to match; must not be {@code null}
 * @param environmentAliases environments considered equivalent for the lookup
 * @return set of client ids which belong to the family, possibly empty
 */
private Set<String> getFamilyClientIds(String familyId, Set<String> environmentAliases) {
    return appMetadata.values().stream()
            .filter(metadata -> environmentAliases.contains(metadata.environment()) &&
                    familyId.equals(metadata.familyId()))
            .map(AppMetadataCacheEntity::clientId)
            .collect(Collectors.toSet());
}
/**
 * Remove all cache entities related to account, including account cache entity.
 * The access aspect is notified with {@code hasCacheChanged(true)} and the removal
 * runs under the write lock.
 *
 * @param clientId client id
 * @param account  account
 */
void removeAccount(String clientId, IAccount account) {
    try (CacheAspect cacheAspect = new CacheAspect(
            TokenCacheAccessContext.builder().
                    clientId(clientId).
                    tokenCache(this).
                    hasCacheChanged(true).
                    build())) {
        // Acquire before entering try so a failed lock() never reaches unlock().
        lock.writeLock().lock();
        try {
            removeAccount(account);
        } finally {
            lock.writeLock().unlock();
        }
    }
}
/**
 * Removes every access token, refresh token, ID token and account entity whose home
 * account id equals that of {@code account}. Entities with a blank home account id or
 * blank environment are never removed. The caller is expected to hold the write lock.
 *
 * @param account account whose cached entities are removed
 */
private void removeAccount(IAccount account) {
    // Shared by the three credential maps; their value types are all Credential subtypes.
    Predicate<Map.Entry<String, ? extends Credential>> credentialToRemovePredicate =
            e -> !StringHelper.isBlank(e.getValue().homeAccountId()) &&
                    !StringHelper.isBlank(e.getValue().environment()) &&
                    e.getValue().homeAccountId().equals(account.homeAccountId());

    accessTokens.entrySet().removeIf(credentialToRemovePredicate);
    refreshTokens.entrySet().removeIf(credentialToRemovePredicate);
    idTokens.entrySet().removeIf(credentialToRemovePredicate);

    // Account entities are not credentials, so they need their own predicate.
    accounts.entrySet().removeIf(
            e -> !StringHelper.isBlank(e.getValue().homeAccountId()) &&
                    !StringHelper.isBlank(e.getValue().environment()) &&
                    e.getValue().homeAccountId().equals(account.homeAccountId()));
}
/**
 * Checks whether a cached access token covers all requested scopes. The token's target
 * is split on {@link Constants#SCOPES_SEPARATOR} and compared case-insensitively.
 *
 * @param accessTokenCacheEntity cached access token
 * @param scopes                 requested scopes
 * @return {@code true} when every requested scope is among the token's scopes
 */
private boolean isMatchingScopes(AccessTokenCacheEntity accessTokenCacheEntity, Set<String> scopes) {
    Set<String> accessTokenCacheEntityScopes = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
    accessTokenCacheEntityScopes.addAll(
            Arrays.asList(accessTokenCacheEntity.target().split(Constants.SCOPES_SEPARATOR)));

    return accessTokenCacheEntityScopes.containsAll(scopes);
}
/**
 * Tells whether a cached credential belongs to the given user assertion.
 *
 * @param credential        cached credential
 * @param userAssertionHash hash to match, or {@code null} to accept any credential
 * @return {@code true} when no hash is requested, or when the credential has a hash
 * equal to it ignoring case
 */
private boolean userAssertionHashMatches(Credential credential, String userAssertionHash) {
    if (userAssertionHash == null) {
        return true;
    }
    String cachedHash = credential.userAssertionHash();
    if (cachedHash == null) {
        return false;
    }
    return cachedHash.equalsIgnoreCase(userAssertionHash);
}
/**
 * Tells whether a cached account belongs to the given user assertion.
 *
 * @param accountCacheEntity cached account
 * @param userAssertionHash  hash to match, or {@code null} to accept any account
 * @return {@code true} when no hash is requested, or when the account has a hash
 * equal to it ignoring case
 */
private boolean userAssertionHashMatches(AccountCacheEntity accountCacheEntity, String userAssertionHash) {
    if (userAssertionHash == null) {
        return true;
    }
    String cachedHash = accountCacheEntity.userAssertionHash();
    if (cachedHash == null) {
        return false;
    }
    return cachedHash.equalsIgnoreCase(userAssertionHash);
}
/**
 * Finds a cached access token for a user account. A token matches when it has the
 * account's home account id, one of the given environments, the authority's tenant as
 * realm, the given client id, all requested scopes, and more than
 * {@link #MIN_ACCESS_TOKEN_EXPIRE_IN_SEC} seconds of lifetime left.
 * Tokens cached without a home account id never match.
 *
 * @param account            account the token must belong to
 * @param authority          authority whose tenant the token must be issued for
 * @param scopes             scopes the token must cover
 * @param clientId           client id the token must be issued to
 * @param environmentAliases environments considered equivalent for the lookup
 * @return any matching access token, or empty when there is none
 */
private Optional<AccessTokenCacheEntity> getAccessTokenCacheEntity(
        IAccount account,
        Authority authority,
        Set<String> scopes,
        String clientId,
        Set<String> environmentAliases) {
    long currTimeStampSec = System.currentTimeMillis() / 1000;

    return accessTokens.values().stream().filter(
            accessToken ->
                    // Tokens saved without an account carry no home account id.
                    accessToken.homeAccountId != null &&
                            accessToken.homeAccountId.equals(account.homeAccountId()) &&
                            environmentAliases.contains(accessToken.environment) &&
                            Long.parseLong(accessToken.expiresOn()) > currTimeStampSec + MIN_ACCESS_TOKEN_EXPIRE_IN_SEC &&
                            accessToken.realm.equals(authority.tenant()) &&
                            accessToken.clientId.equals(clientId) &&
                            isMatchingScopes(accessToken, scopes)
    ).findAny();
}
/**
 * Finds a cached access token without reference to a user account. A token matches
 * when its user assertion hash matches (a {@code null} hash accepts any token), it has
 * one of the given environments, the authority's tenant as realm, the given client id,
 * all requested scopes, and more than {@link #MIN_ACCESS_TOKEN_EXPIRE_IN_SEC} seconds
 * of lifetime left.
 *
 * @param authority          authority whose tenant the token must be issued for
 * @param scopes             scopes the token must cover
 * @param clientId           client id the token must be issued to
 * @param environmentAliases environments considered equivalent for the lookup
 * @param userAssertionHash  assertion hash to match, or {@code null} for any
 * @return any matching access token, or empty when there is none
 */
private Optional<AccessTokenCacheEntity> getApplicationAccessTokenCacheEntity(
        Authority authority,
        Set<String> scopes,
        String clientId,
        Set<String> environmentAliases,
        String userAssertionHash) {
    long currTimeStampSec = System.currentTimeMillis() / 1000;

    return accessTokens.values().stream().filter(
            accessToken ->
                    userAssertionHashMatches(accessToken, userAssertionHash) &&
                            environmentAliases.contains(accessToken.environment) &&
                            Long.parseLong(accessToken.expiresOn()) > currTimeStampSec + MIN_ACCESS_TOKEN_EXPIRE_IN_SEC &&
                            accessToken.realm.equals(authority.tenant()) &&
                            accessToken.clientId.equals(clientId) &&
                            isMatchingScopes(accessToken, scopes))
            .findAny();
}
/**
 * Finds a cached ID token for a user account: same home account id, one of the given
 * environments, the authority's tenant as realm, and the given client id.
 * Tokens cached without a home account id never match.
 *
 * @param account            account the token must belong to
 * @param authority          authority whose tenant the token must be issued for
 * @param clientId           client id the token must be issued to
 * @param environmentAliases environments considered equivalent for the lookup
 * @return any matching ID token, or empty when there is none
 */
private Optional<IdTokenCacheEntity> getIdTokenCacheEntity(
        IAccount account,
        Authority authority,
        String clientId,
        Set<String> environmentAliases) {
    return idTokens.values().stream().filter(
            idToken ->
                    // Tokens saved without an account carry no home account id.
                    idToken.homeAccountId != null &&
                            idToken.homeAccountId.equals(account.homeAccountId()) &&
                            environmentAliases.contains(idToken.environment) &&
                            idToken.realm.equals(authority.tenant()) &&
                            idToken.clientId.equals(clientId)
    ).findAny();
}
/**
 * Finds a cached ID token by user assertion: matching assertion hash (a {@code null}
 * hash accepts any token), one of the given environments, the authority's tenant as
 * realm, and the given client id.
 *
 * @param authority          authority whose tenant the token must be issued for
 * @param clientId           client id the token must be issued to
 * @param environmentAliases environments considered equivalent for the lookup
 * @param userAssertionHash  assertion hash to match, or {@code null} for any
 * @return any matching ID token, or empty when there is none
 */
private Optional<IdTokenCacheEntity> getIdTokenCacheEntity(
        Authority authority,
        String clientId,
        Set<String> environmentAliases,
        String userAssertionHash) {
    return idTokens.values().stream().filter(
            idToken ->
                    userAssertionHashMatches(idToken, userAssertionHash) &&
                            environmentAliases.contains(idToken.environment) &&
                            idToken.realm.equals(authority.tenant()) &&
                            idToken.clientId.equals(clientId)
    ).findAny();
}
/**
 * Finds a cached refresh token by user assertion: matching assertion hash (a
 * {@code null} hash accepts any token), one of the given environments, and the given
 * client id.
 *
 * @param clientId           client id the token must be issued to
 * @param environmentAliases environments considered equivalent for the lookup
 * @param userAssertionHash  assertion hash to match, or {@code null} for any
 * @return any matching refresh token, or empty when there is none
 */
private Optional<RefreshTokenCacheEntity> getRefreshTokenCacheEntity(
        String clientId,
        Set<String> environmentAliases,
        String userAssertionHash) {
    return refreshTokens.values().stream().filter(
            refreshToken ->
                    userAssertionHashMatches(refreshToken, userAssertionHash) &&
                            environmentAliases.contains(refreshToken.environment) &&
                            refreshToken.clientId.equals(clientId)
    ).findAny();
}
/**
 * Finds a cached refresh token for a user account: same home account id, one of the
 * given environments, and the given client id.
 * Tokens cached without a home account id never match.
 *
 * @param account            account the token must belong to
 * @param clientId           client id the token must be issued to
 * @param environmentAliases environments considered equivalent for the lookup
 * @return any matching refresh token, or empty when there is none
 */
private Optional<RefreshTokenCacheEntity> getRefreshTokenCacheEntity(
        IAccount account,
        String clientId,
        Set<String> environmentAliases) {
    return refreshTokens.values().stream().filter(
            refreshToken ->
                    // Tokens saved without an account carry no home account id.
                    refreshToken.homeAccountId != null &&
                            refreshToken.homeAccountId.equals(account.homeAccountId()) &&
                            environmentAliases.contains(refreshToken.environment) &&
                            refreshToken.clientId.equals(clientId)
    ).findAny();
}
/**
 * Finds the cached account entity for a user account: same home account id and one of
 * the given environments. Entities without a home account id never match.
 *
 * @param account            account to look up
 * @param environmentAliases environments considered equivalent for the lookup
 * @return any matching account entity, or empty when there is none
 */
private Optional<AccountCacheEntity> getAccountCacheEntity(
        IAccount account,
        Set<String> environmentAliases) {
    return accounts.values().stream().filter(
            acc ->
                    acc.homeAccountId != null &&
                            acc.homeAccountId.equals(account.homeAccountId()) &&
                            environmentAliases.contains(acc.environment)
    ).findAny();
}
/**
 * Finds a cached account entity by user assertion: matching assertion hash (a
 * {@code null} hash accepts any account) and one of the given environments.
 *
 * @param environmentAliases environments considered equivalent for the lookup
 * @param userAssertionHash  assertion hash to match, or {@code null} for any
 * @return any matching account entity, or empty when there is none
 */
private Optional<AccountCacheEntity> getAccountCacheEntity(
        Set<String> environmentAliases,
        String userAssertionHash) {
    return accounts.values().stream().filter(
            acc -> userAssertionHashMatches(acc, userAssertionHash) &&
                    environmentAliases.contains(acc.environment)
    ).findAny();
}
/**
 * Finds any family refresh token cached for a user account, regardless of which client
 * it was issued to: same home account id, one of the given environments, and
 * {@code isFamilyRT()} true. Tokens cached without a home account id never match.
 *
 * @param account            account the token must belong to
 * @param environmentAliases environments considered equivalent for the lookup
 * @return any matching family refresh token, or empty when there is none
 */
private Optional<RefreshTokenCacheEntity> getAnyFamilyRefreshTokenCacheEntity(
        IAccount account, Set<String> environmentAliases) {
    return refreshTokens.values().stream().filter(
            refreshToken ->
                    // Tokens saved without an account carry no home account id.
                    refreshToken.homeAccountId != null &&
                            refreshToken.homeAccountId.equals(account.homeAccountId()) &&
                            environmentAliases.contains(refreshToken.environment) &&
                            refreshToken.isFamilyRT()
    ).findAny();
}
/**
 * Assembles an authentication result for a user account from whatever the cache holds:
 * access token (with expiry and optional refresh-on time), ID token, refresh token and
 * account entity. Missing pieces are simply left unset on the result.
 * <p>
 * Refresh token selection: when the cache knows the client to be part of a family, a
 * family refresh token is preferred with the client's own token as fallback; otherwise
 * the client's own token is preferred with a family token as fallback.
 * <p>
 * The result's environment is that of the cached access token when one is found, and
 * the authority host otherwise.
 *
 * @param account   account to look up
 * @param authority authority whose tenant the tokens must be issued for
 * @param scopes    scopes the access token must cover
 * @param clientId  client id the tokens must be issued to
 * @return authentication result built from the cached entities
 */
AuthenticationResult getCachedAuthenticationResult(
        IAccount account,
        Authority authority,
        Set<String> scopes,
        String clientId) {
    AuthenticationResult.AuthenticationResultBuilder builder = AuthenticationResult.builder();
    Set<String> environmentAliases = AadInstanceDiscoveryProvider.getAliases(account.environment());

    try (CacheAspect cacheAspect = new CacheAspect(
            TokenCacheAccessContext.builder().
                    clientId(clientId).
                    tokenCache(this).
                    account(account).
                    build())) {
        // Acquire before entering try so a failed lock() never reaches unlock().
        lock.readLock().lock();
        try {
            Optional<AccountCacheEntity> accountCacheEntity =
                    getAccountCacheEntity(account, environmentAliases);
            Optional<AccessTokenCacheEntity> atCacheEntity =
                    getAccessTokenCacheEntity(account, authority, scopes, clientId, environmentAliases);
            Optional<IdTokenCacheEntity> idTokenCacheEntity =
                    getIdTokenCacheEntity(account, authority, clientId, environmentAliases);

            Optional<RefreshTokenCacheEntity> rtCacheEntity;
            if (!StringHelper.isBlank(getApplicationFamilyId(clientId, environmentAliases))) {
                // Known family member: family token first, own token as fallback.
                rtCacheEntity = getAnyFamilyRefreshTokenCacheEntity(account, environmentAliases);
                if (!rtCacheEntity.isPresent()) {
                    rtCacheEntity = getRefreshTokenCacheEntity(account, clientId, environmentAliases);
                }
            } else {
                // Family membership unknown: own token first, family token as fallback.
                rtCacheEntity = getRefreshTokenCacheEntity(account, clientId, environmentAliases);
                if (!rtCacheEntity.isPresent()) {
                    rtCacheEntity = getAnyFamilyRefreshTokenCacheEntity(account, environmentAliases);
                }
            }

            if (atCacheEntity.isPresent()) {
                AccessTokenCacheEntity accessToken = atCacheEntity.get();
                builder.
                        environment(accessToken.environment).
                        accessToken(accessToken.secret).
                        expiresOn(Long.parseLong(accessToken.expiresOn()));
                if (accessToken.refreshOn() != null) {
                    builder.refreshOn(Long.parseLong(accessToken.refreshOn()));
                }
            } else {
                builder.environment(authority.host());
            }

            idTokenCacheEntity.ifPresent(tokenCacheEntity -> builder.idToken(tokenCacheEntity.secret));
            rtCacheEntity.ifPresent(refreshTokenCacheEntity ->
                    builder.refreshToken(refreshTokenCacheEntity.secret));
            accountCacheEntity.ifPresent(builder::accountCacheEntity);
        } finally {
            lock.readLock().unlock();
        }
    }
    return builder.build();
}
/**
 * Assembles an authentication result without reference to a user account, optionally
 * scoped to a user assertion: account entity, access token (with expiry and optional
 * refresh-on time), ID token and refresh token are each looked up by environment,
 * client id and assertion hash. Missing pieces are simply left unset on the result.
 * The result's environment is always the authority host.
 *
 * @param authority authority whose host and tenant are used for the lookup
 * @param scopes    scopes the access token must cover
 * @param clientId  client id the tokens must be issued to
 * @param assertion user assertion whose hash the entities must carry, or {@code null}
 *                  to accept entities regardless of assertion hash
 * @return authentication result built from the cached entities
 */
AuthenticationResult getCachedAuthenticationResult(
        Authority authority,
        Set<String> scopes,
        String clientId,
        IUserAssertion assertion) {
    AuthenticationResult.AuthenticationResultBuilder builder = AuthenticationResult.builder();
    Set<String> environmentAliases = AadInstanceDiscoveryProvider.getAliases(authority.host);
    builder.environment(authority.host());

    try (CacheAspect cacheAspect = new CacheAspect(
            TokenCacheAccessContext.builder().
                    clientId(clientId).
                    tokenCache(this).
                    build())) {
        // Acquire before entering try so a failed lock() never reaches unlock().
        lock.readLock().lock();
        try {
            String userAssertionHash = assertion == null ? null : assertion.getAssertionHash();

            Optional<AccountCacheEntity> accountCacheEntity =
                    getAccountCacheEntity(environmentAliases, userAssertionHash);
            accountCacheEntity.ifPresent(builder::accountCacheEntity);

            Optional<AccessTokenCacheEntity> atCacheEntity =
                    getApplicationAccessTokenCacheEntity(authority, scopes, clientId, environmentAliases, userAssertionHash);
            if (atCacheEntity.isPresent()) {
                AccessTokenCacheEntity accessToken = atCacheEntity.get();
                builder.
                        accessToken(accessToken.secret).
                        expiresOn(Long.parseLong(accessToken.expiresOn()));
                if (accessToken.refreshOn() != null) {
                    builder.refreshOn(Long.parseLong(accessToken.refreshOn()));
                }
            }

            Optional<IdTokenCacheEntity> idTokenCacheEntity =
                    getIdTokenCacheEntity(authority, clientId, environmentAliases, userAssertionHash);
            idTokenCacheEntity.ifPresent(tokenCacheEntity -> builder.idToken(tokenCacheEntity.secret));

            Optional<RefreshTokenCacheEntity> rtCacheEntity =
                    getRefreshTokenCacheEntity(clientId, environmentAliases, userAssertionHash);
            rtCacheEntity.ifPresent(refreshTokenCacheEntity ->
                    builder.refreshToken(refreshTokenCacheEntity.secret));
        } finally {
            lock.readLock().unlock();
        }
        return builder.build();
    }
}
}