All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.aad.msal4j.TokenCache Maven / Gradle / Ivy

There is a newer version: 1.0.15
Show newest version
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.nimbusds.jwt.JWTParser;

import java.io.IOException;
import java.text.ParseException;
import java.util.*;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Cache used for storing tokens. For more details, see https://aka.ms/msal4j-token-cache
 * 

* Conditionally thread-safe */ public class TokenCache implements ITokenCache { protected static final int MIN_ACCESS_TOKEN_EXPIRE_IN_SEC = 5 * 60; transient private ReadWriteLock lock = new ReentrantReadWriteLock(); /** * Constructor for token cache * * @param tokenCacheAccessAspect {@link ITokenCacheAccessAspect} */ public TokenCache(ITokenCacheAccessAspect tokenCacheAccessAspect) { this(); this.tokenCacheAccessAspect = tokenCacheAccessAspect; } /** * Constructor for token cache */ public TokenCache() { } @JsonProperty("AccessToken") Map accessTokens = new LinkedHashMap<>(); @JsonProperty("RefreshToken") Map refreshTokens = new LinkedHashMap<>(); @JsonProperty("IdToken") Map idTokens = new LinkedHashMap<>(); @JsonProperty("Account") Map accounts = new LinkedHashMap<>(); @JsonProperty("AppMetadata") Map appMetadata = new LinkedHashMap<>(); transient ITokenCacheAccessAspect tokenCacheAccessAspect; private transient String serializedCachedSnapshot; @Override public void deserialize(String data) { if (StringHelper.isBlank(data)) { return; } serializedCachedSnapshot = data; TokenCache deserializedCache = JsonHelper.convertJsonToObject(data, TokenCache.class); lock.writeLock().lock(); try { this.accounts = deserializedCache.accounts; this.accessTokens = deserializedCache.accessTokens; this.refreshTokens = deserializedCache.refreshTokens; this.idTokens = deserializedCache.idTokens; this.appMetadata = deserializedCache.appMetadata; } finally { lock.writeLock().unlock(); } } private static void mergeJsonObjects(JsonNode old, JsonNode update) { mergeRemovals(old, update); mergeUpdates(old, update); } private static void mergeUpdates(JsonNode old, JsonNode update) { Iterator fieldNames = update.fieldNames(); while (fieldNames.hasNext()) { String uKey = fieldNames.next(); JsonNode uValue = update.get(uKey); // add new property if (!old.has(uKey)) { if (!uValue.isNull() && !(uValue.isObject() && uValue.size() == 0)) { ((ObjectNode) old).set(uKey, uValue); } } // merge old and new property else { JsonNode oValue = old.get(uKey); if (uValue.isObject()) { mergeUpdates(oValue, uValue); } else { ((ObjectNode) old).set(uKey, uValue); } } } } private static void mergeRemovals(JsonNode old, JsonNode update) { Set msalEntities = new HashSet<>(Arrays.asList("Account", "AccessToken", "RefreshToken", "IdToken", "AppMetadata")); for (String msalEntity : msalEntities) { JsonNode oldEntries = old.get(msalEntity); JsonNode newEntries = update.get(msalEntity); if (oldEntries != null) { Iterator> iterator = oldEntries.fields(); while (iterator.hasNext()) { Map.Entry oEntry = iterator.next(); String key = oEntry.getKey(); if (newEntries == null || !newEntries.has(key)) { iterator.remove(); } } } } } @Override public String serialize() { lock.readLock().lock(); try { if (!StringHelper.isBlank(serializedCachedSnapshot)) { JsonNode cache = JsonHelper.mapper.readTree(serializedCachedSnapshot); JsonNode update = JsonHelper.mapper.valueToTree(this); mergeJsonObjects(cache, update); return JsonHelper.mapper.writeValueAsString(cache); } return JsonHelper.mapper.writeValueAsString(this); } catch (IOException e) { throw new MsalClientException(e); } finally { lock.readLock().unlock(); } } private class CacheAspect implements AutoCloseable { ITokenCacheAccessContext context; CacheAspect(ITokenCacheAccessContext context) { if (tokenCacheAccessAspect != null) { this.context = context; tokenCacheAccessAspect.beforeCacheAccess(context); } } @Override public void close() { if (tokenCacheAccessAspect != null) { tokenCacheAccessAspect.afterCacheAccess(context); } } } void saveTokens(TokenRequestExecutor tokenRequestExecutor, AuthenticationResult authenticationResult, String environment) { try (CacheAspect cacheAspect = new CacheAspect( TokenCacheAccessContext.builder(). clientId(tokenRequestExecutor.getMsalRequest().application().clientId()). tokenCache(this). hasCacheChanged(true).build())) { try { lock.writeLock().lock(); if (!StringHelper.isBlank(authenticationResult.accessToken())) { AccessTokenCacheEntity atEntity = createAccessTokenCacheEntity (tokenRequestExecutor, authenticationResult, environment); accessTokens.put(atEntity.getKey(), atEntity); } if (!StringHelper.isBlank(authenticationResult.familyId())) { AppMetadataCacheEntity appMetadataCacheEntity = createAppMetadataCacheEntity(tokenRequestExecutor, authenticationResult, environment); appMetadata.put(appMetadataCacheEntity.getKey(), appMetadataCacheEntity); } if (!StringHelper.isBlank(authenticationResult.refreshToken())) { RefreshTokenCacheEntity rtEntity = createRefreshTokenCacheEntity (tokenRequestExecutor, authenticationResult, environment); rtEntity.family_id(authenticationResult.familyId()); refreshTokens.put(rtEntity.getKey(), rtEntity); } if (!StringHelper.isBlank(authenticationResult.idToken())) { IdTokenCacheEntity idTokenEntity = createIdTokenCacheEntity (tokenRequestExecutor, authenticationResult, environment); idTokens.put(idTokenEntity.getKey(), idTokenEntity); AccountCacheEntity accountCacheEntity = authenticationResult.accountCacheEntity(); if(accountCacheEntity!=null) { accountCacheEntity.environment(environment); accounts.put(accountCacheEntity.getKey(), accountCacheEntity); } } } finally { lock.writeLock().unlock(); } } } private static RefreshTokenCacheEntity createRefreshTokenCacheEntity(TokenRequestExecutor tokenRequestExecutor, AuthenticationResult authenticationResult, String environmentAlias) { RefreshTokenCacheEntity rt = new RefreshTokenCacheEntity(); rt.credentialType(CredentialTypeEnum.REFRESH_TOKEN.value()); if (authenticationResult.account() != null) { rt.homeAccountId(authenticationResult.account().homeAccountId()); } rt.environment(environmentAlias); rt.clientId(tokenRequestExecutor.getMsalRequest().application().clientId()); rt.secret(authenticationResult.refreshToken()); if (tokenRequestExecutor.getMsalRequest() instanceof OnBehalfOfRequest) { OnBehalfOfRequest onBehalfOfRequest = (OnBehalfOfRequest) tokenRequestExecutor.getMsalRequest(); rt.userAssertionHash(onBehalfOfRequest.parameters.userAssertion().getAssertionHash()); } return rt; } private static AccessTokenCacheEntity createAccessTokenCacheEntity(TokenRequestExecutor tokenRequestExecutor, AuthenticationResult authenticationResult, String environmentAlias) { AccessTokenCacheEntity at = new AccessTokenCacheEntity(); at.credentialType(CredentialTypeEnum.ACCESS_TOKEN.value()); if (authenticationResult.account() != null) { at.homeAccountId(authenticationResult.account().homeAccountId()); } at.environment(environmentAlias); at.clientId(tokenRequestExecutor.getMsalRequest().application().clientId()); at.secret(authenticationResult.accessToken()); at.realm(tokenRequestExecutor.requestAuthority.tenant()); String scopes = !StringHelper.isBlank(authenticationResult.scopes()) ? authenticationResult.scopes() : tokenRequestExecutor.getMsalRequest().msalAuthorizationGrant().getScopes(); at.target(scopes); if (tokenRequestExecutor.getMsalRequest() instanceof OnBehalfOfRequest) { OnBehalfOfRequest onBehalfOfRequest = (OnBehalfOfRequest) tokenRequestExecutor.getMsalRequest(); at.userAssertionHash(onBehalfOfRequest.parameters.userAssertion().getAssertionHash()); } long currTimestampSec = System.currentTimeMillis() / 1000; at.cachedAt(Long.toString(currTimestampSec)); at.expiresOn(Long.toString(authenticationResult.expiresOn())); if (authenticationResult.refreshOn() > 0) { at.refreshOn(Long.toString(authenticationResult.refreshOn())); } if (authenticationResult.extExpiresOn() > 0) { at.extExpiresOn(Long.toString(authenticationResult.extExpiresOn())); } return at; } private static IdTokenCacheEntity createIdTokenCacheEntity(TokenRequestExecutor tokenRequestExecutor, AuthenticationResult authenticationResult, String environmentAlias) { IdTokenCacheEntity idToken = new IdTokenCacheEntity(); idToken.credentialType(CredentialTypeEnum.ID_TOKEN.value()); if (authenticationResult.account() != null) { idToken.homeAccountId(authenticationResult.account().homeAccountId()); } idToken.environment(environmentAlias); idToken.clientId(tokenRequestExecutor.getMsalRequest().application().clientId()); idToken.secret(authenticationResult.idToken()); idToken.realm(tokenRequestExecutor.requestAuthority.tenant()); if (tokenRequestExecutor.getMsalRequest() instanceof OnBehalfOfRequest) { OnBehalfOfRequest onBehalfOfRequest = (OnBehalfOfRequest) tokenRequestExecutor.getMsalRequest(); idToken.userAssertionHash(onBehalfOfRequest.parameters.userAssertion().getAssertionHash()); } return idToken; } private static AppMetadataCacheEntity createAppMetadataCacheEntity(TokenRequestExecutor tokenRequestExecutor, AuthenticationResult authenticationResult, String environmentAlias) { AppMetadataCacheEntity appMetadataCacheEntity = new AppMetadataCacheEntity(); appMetadataCacheEntity.clientId(tokenRequestExecutor.getMsalRequest().application().clientId()); appMetadataCacheEntity.environment(environmentAlias); appMetadataCacheEntity.familyId(authenticationResult.familyId()); return appMetadataCacheEntity; } Set getAccounts(String clientId) { try (CacheAspect cacheAspect = new CacheAspect( TokenCacheAccessContext.builder(). clientId(clientId). tokenCache(this). build())) { try { lock.readLock().lock(); Map rootAccounts = new HashMap<>(); for (AccountCacheEntity accCached : accounts.values()) { IdTokenCacheEntity idToken = idTokens.get(getIdTokenKey( accCached.homeAccountId(), accCached.environment(), clientId, accCached.realm())); ITenantProfile profile = null; if (idToken != null) { Map idTokenClaims = JWTParser.parse(idToken.secret()).getJWTClaimsSet().getClaims(); profile = new TenantProfile(idTokenClaims, accCached.environment()); } if (rootAccounts.get(accCached.homeAccountId()) == null) { IAccount acc = accCached.toAccount(); ((Account) acc).tenantProfiles = new HashMap<>(); rootAccounts.put(accCached.homeAccountId(), acc); } if (profile != null) { ((Account) rootAccounts.get(accCached.homeAccountId())).tenantProfiles.put(accCached.realm(), profile); } if (accCached.homeAccountId().contains(accCached.localAccountId())) { ((Account) rootAccounts.get(accCached.homeAccountId())).username(accCached.username()); } } return new HashSet<>(rootAccounts.values()); } catch (ParseException e) { throw new MsalClientException("Cached JWT could not be parsed: " + e.getMessage(), AuthenticationErrorCode.INVALID_JWT); } finally { lock.readLock().unlock(); } } } /** * Returns a String representing a key of a cached ID token, formatted in the same way as {@link IdTokenCacheEntity#getKey} * * @return String representing a possible key of a cached ID token */ private String getIdTokenKey(String homeAccountId, String environment, String clientId, String realm) { return String.join(Constants.CACHE_KEY_SEPARATOR, Arrays.asList(homeAccountId, environment, "idtoken", clientId, realm, "")).toLowerCase(); } /** * @return familyId status of application */ private String getApplicationFamilyId(String clientId, Set environmentAliases) { for (AppMetadataCacheEntity data : appMetadata.values()) { if (data.clientId().equals(clientId) && environmentAliases.contains(data.environment()) && !StringHelper.isBlank(data.familyId())) { return data.familyId(); } } return null; } /** * @return set of client ids which belong to the family */ private Set getFamilyClientIds(String familyId, Set environmentAliases) { return appMetadata.values().stream().filter (appMetadata -> environmentAliases.contains(appMetadata.environment()) && familyId.equals(appMetadata.familyId()) ).map(AppMetadataCacheEntity::clientId).collect(Collectors.toSet()); } /** * Remove all cache entities related to account, including account cache entity * * @param clientId client id * @param account account */ void removeAccount(String clientId, IAccount account) { try (CacheAspect cacheAspect = new CacheAspect( TokenCacheAccessContext.builder(). clientId(clientId). tokenCache(this). hasCacheChanged(true). build())) { try { lock.writeLock().lock(); removeAccount(account); } finally { lock.writeLock().unlock(); } } } private void removeAccount(IAccount account) { Predicate> credentialToRemovePredicate = e -> !StringHelper.isBlank(e.getValue().homeAccountId()) && !StringHelper.isBlank(e.getValue().environment()) && e.getValue().homeAccountId().equals(account.homeAccountId()); accessTokens.entrySet().removeIf(credentialToRemovePredicate); refreshTokens.entrySet().removeIf(credentialToRemovePredicate); idTokens.entrySet().removeIf(credentialToRemovePredicate); accounts.entrySet().removeIf( e -> !StringHelper.isBlank(e.getValue().homeAccountId()) && !StringHelper.isBlank(e.getValue().environment()) && e.getValue().homeAccountId().equals(account.homeAccountId())); } private boolean isMatchingScopes(AccessTokenCacheEntity accessTokenCacheEntity, Set scopes) { Set accessTokenCacheEntityScopes = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); accessTokenCacheEntityScopes.addAll (Arrays.asList(accessTokenCacheEntity.target().split(Constants.SCOPES_SEPARATOR))); return accessTokenCacheEntityScopes.containsAll(scopes); } private boolean userAssertionHashMatches(Credential credential, String userAssertionHash) { if (userAssertionHash == null) { return true; } return credential.userAssertionHash() != null && credential.userAssertionHash().equalsIgnoreCase(userAssertionHash); } private boolean userAssertionHashMatches(AccountCacheEntity accountCacheEntity, String userAssertionHash) { if (userAssertionHash == null) { return true; } return accountCacheEntity.userAssertionHash() != null && accountCacheEntity.userAssertionHash().equalsIgnoreCase(userAssertionHash); } private Optional getAccessTokenCacheEntity( IAccount account, Authority authority, Set scopes, String clientId, Set environmentAliases) { long currTimeStampSec = new Date().getTime() / 1000; return accessTokens.values().stream().filter( accessToken -> accessToken.homeAccountId.equals(account.homeAccountId()) && environmentAliases.contains(accessToken.environment) && Long.parseLong(accessToken.expiresOn()) > currTimeStampSec + MIN_ACCESS_TOKEN_EXPIRE_IN_SEC && accessToken.realm.equals(authority.tenant()) && accessToken.clientId.equals(clientId) && isMatchingScopes(accessToken, scopes) ).findAny(); } private Optional getApplicationAccessTokenCacheEntity( Authority authority, Set scopes, String clientId, Set environmentAliases, String userAssertionHash) { long currTimeStampSec = new Date().getTime() / 1000; return accessTokens.values().stream().filter( accessToken -> userAssertionHashMatches(accessToken, userAssertionHash) && environmentAliases.contains(accessToken.environment) && Long.parseLong(accessToken.expiresOn()) > currTimeStampSec + MIN_ACCESS_TOKEN_EXPIRE_IN_SEC && accessToken.realm.equals(authority.tenant()) && accessToken.clientId.equals(clientId) && isMatchingScopes(accessToken, scopes)) .findAny(); } private Optional getIdTokenCacheEntity( IAccount account, Authority authority, String clientId, Set environmentAliases) { return idTokens.values().stream().filter( idToken -> idToken.homeAccountId.equals(account.homeAccountId()) && environmentAliases.contains(idToken.environment) && idToken.realm.equals(authority.tenant()) && idToken.clientId.equals(clientId) ).findAny(); } private Optional getIdTokenCacheEntity( Authority authority, String clientId, Set environmentAliases, String userAssertionHash) { return idTokens.values().stream().filter( idToken -> userAssertionHashMatches(idToken, userAssertionHash) && environmentAliases.contains(idToken.environment) && idToken.realm.equals(authority.tenant()) && idToken.clientId.equals(clientId) ).findAny(); } private Optional getRefreshTokenCacheEntity( String clientId, Set environmentAliases, String userAssertionHash) { return refreshTokens.values().stream().filter( refreshToken -> userAssertionHashMatches(refreshToken, userAssertionHash) && environmentAliases.contains(refreshToken.environment) && refreshToken.clientId.equals(clientId) ).findAny(); } private Optional getRefreshTokenCacheEntity( IAccount account, String clientId, Set environmentAliases) { return refreshTokens.values().stream().filter( refreshToken -> refreshToken.homeAccountId.equals(account.homeAccountId()) && environmentAliases.contains(refreshToken.environment) && refreshToken.clientId.equals(clientId) ).findAny(); } private Optional getAccountCacheEntity( IAccount account, Set environmentAliases) { return accounts.values().stream().filter( acc -> acc.homeAccountId.equals(account.homeAccountId()) && environmentAliases.contains(acc.environment) ).findAny(); } private Optional getAccountCacheEntity( Set environmentAliases, String userAssertionHash) { return accounts.values().stream().filter( acc -> userAssertionHashMatches(acc, userAssertionHash) && environmentAliases.contains(acc.environment) ).findAny(); } private Optional getAnyFamilyRefreshTokenCacheEntity (IAccount account, Set environmentAliases) { return refreshTokens.values().stream().filter (refreshToken -> refreshToken.homeAccountId.equals(account.homeAccountId()) && environmentAliases.contains(refreshToken.environment) && refreshToken.isFamilyRT() ).findAny(); } AuthenticationResult getCachedAuthenticationResult( IAccount account, Authority authority, Set scopes, String clientId) { AuthenticationResult.AuthenticationResultBuilder builder = AuthenticationResult.builder(); Set environmentAliases = AadInstanceDiscoveryProvider.getAliases(account.environment()); try (CacheAspect cacheAspect = new CacheAspect( TokenCacheAccessContext.builder(). clientId(clientId). tokenCache(this). account(account). build())) { try { lock.readLock().lock(); Optional accountCacheEntity = getAccountCacheEntity(account, environmentAliases); Optional atCacheEntity = getAccessTokenCacheEntity(account, authority, scopes, clientId, environmentAliases); Optional idTokenCacheEntity = getIdTokenCacheEntity(account, authority, clientId, environmentAliases); Optional rtCacheEntity; if (!StringHelper.isBlank(getApplicationFamilyId(clientId, environmentAliases))) { rtCacheEntity = getAnyFamilyRefreshTokenCacheEntity(account, environmentAliases); if (!rtCacheEntity.isPresent()) { rtCacheEntity = getRefreshTokenCacheEntity(account, clientId, environmentAliases); } } else { rtCacheEntity = getRefreshTokenCacheEntity(account, clientId, environmentAliases); if (!rtCacheEntity.isPresent()) { rtCacheEntity = getAnyFamilyRefreshTokenCacheEntity(account, environmentAliases); } } if (atCacheEntity.isPresent()) { builder. environment(atCacheEntity.get().environment). accessToken(atCacheEntity.get().secret). expiresOn(Long.parseLong(atCacheEntity.get().expiresOn())); if (atCacheEntity.get().refreshOn() != null) { builder.refreshOn(Long.parseLong(atCacheEntity.get().refreshOn())); } } else { builder.environment(authority.host()); } idTokenCacheEntity.ifPresent(tokenCacheEntity -> builder.idToken(tokenCacheEntity.secret)); rtCacheEntity.ifPresent(refreshTokenCacheEntity -> builder.refreshToken(refreshTokenCacheEntity.secret)); accountCacheEntity.ifPresent(builder::accountCacheEntity); } finally { lock.readLock().unlock(); } } return builder.build(); } AuthenticationResult getCachedAuthenticationResult( Authority authority, Set scopes, String clientId, IUserAssertion assertion) { AuthenticationResult.AuthenticationResultBuilder builder = AuthenticationResult.builder(); Set environmentAliases = AadInstanceDiscoveryProvider.getAliases(authority.host); builder.environment(authority.host()); try (CacheAspect cacheAspect = new CacheAspect( TokenCacheAccessContext.builder(). clientId(clientId). tokenCache(this). build())) { try { lock.readLock().lock(); String userAssertionHash = assertion == null ? null : assertion.getAssertionHash(); Optional accountCacheEntity = getAccountCacheEntity(environmentAliases, userAssertionHash); accountCacheEntity.ifPresent(builder::accountCacheEntity); Optional atCacheEntity = getApplicationAccessTokenCacheEntity(authority, scopes, clientId, environmentAliases, userAssertionHash); if (atCacheEntity.isPresent()) { builder. accessToken(atCacheEntity.get().secret). expiresOn(Long.parseLong(atCacheEntity.get().expiresOn())); if (atCacheEntity.get().refreshOn() != null) { builder.refreshOn(Long.parseLong(atCacheEntity.get().refreshOn())); } } Optional idTokenCacheEntity = getIdTokenCacheEntity(authority, clientId, environmentAliases, userAssertionHash); idTokenCacheEntity.ifPresent(tokenCacheEntity -> builder.idToken(tokenCacheEntity.secret)); Optional rtCacheEntity = getRefreshTokenCacheEntity(clientId, environmentAliases, userAssertionHash); rtCacheEntity.ifPresent(refreshTokenCacheEntity -> builder.refreshToken(refreshTokenCacheEntity.secret)); } finally { lock.readLock().unlock(); } return builder.build(); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy