All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.officefloor.web.jwt.JwtHttpSecuritySource Maven / Gradle / Ivy

package net.officefloor.web.jwt;

import java.io.IOException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.impl.crypto.DefaultJwtSignatureValidator;
import io.jsonwebtoken.io.Decoder;
import net.officefloor.compile.properties.Property;
import net.officefloor.frame.api.build.None;
import net.officefloor.frame.api.clock.Clock;
import net.officefloor.frame.api.function.FlowCallback;
import net.officefloor.frame.api.managedobject.source.ManagedObjectStartupProcess;
import net.officefloor.frame.internal.structure.ManagedObjectScope;
import net.officefloor.plugin.managedobject.poll.StatePollContext;
import net.officefloor.plugin.managedobject.poll.StatePoller;
import net.officefloor.server.http.HttpException;
import net.officefloor.server.http.HttpRequest;
import net.officefloor.server.http.HttpResponse;
import net.officefloor.server.http.HttpStatus;
import net.officefloor.web.jwt.JwtClaimsManagedObjectSource.Dependencies;
import net.officefloor.web.jwt.role.JwtRoleCollector;
import net.officefloor.web.jwt.validate.JwtValidateKey;
import net.officefloor.web.jwt.validate.JwtValidateKeyCollector;
import net.officefloor.web.security.HttpAuthentication;
import net.officefloor.web.security.scheme.HttpAccessControlImpl;
import net.officefloor.web.security.scheme.HttpAuthenticationImpl;
import net.officefloor.web.security.scheme.HttpAuthenticationScheme;
import net.officefloor.web.spi.security.AuthenticateContext;
import net.officefloor.web.spi.security.AuthenticationContext;
import net.officefloor.web.spi.security.ChallengeContext;
import net.officefloor.web.spi.security.HttpSecurity;
import net.officefloor.web.spi.security.HttpSecurityContext;
import net.officefloor.web.spi.security.HttpSecurityExecuteContext;
import net.officefloor.web.spi.security.HttpSecuritySource;
import net.officefloor.web.spi.security.HttpSecuritySourceContext;
import net.officefloor.web.spi.security.HttpSecuritySupportingManagedObject;
import net.officefloor.web.spi.security.LogoutContext;
import net.officefloor.web.spi.security.RatifyContext;
import net.officefloor.web.spi.security.impl.AbstractHttpSecuritySource;
import net.officefloor.web.state.HttpRequestState;

/**
 * {@link HttpSecuritySource} for JWT.
 * 
 * @author Daniel Sagenschneider
 */
public class JwtHttpSecuritySource extends
		AbstractHttpSecuritySource, JwtHttpAccessControl, Void, None, JwtHttpSecuritySource.Flows>
		implements
		HttpSecurity, JwtHttpAccessControl, Void, None, JwtHttpSecuritySource.Flows> {

	/**
	 * Allows overriding the creation of {@link JwtValidateKey} instances.
	 */
	@FunctionalInterface
	public static interface JwtValidateKeysFactory {

		/**
		 * Obtains the {@link JwtValidateKey} instances to use.
		 * 
		 * @return {@link JwtValidateKey} instances to use.
		 * @throws Exception If fails to create the {@link JwtValidateKey} instances.
		 */
		JwtValidateKey[] createJwtValidateKeys() throws Exception;
	}

	/**
	 * {@link Runnable} within the context of the {@link JwtValidateKeysFactory}
	 */
	@FunctionalInterface
	public static interface ContextRunnable {

		/**
		 * Logic to run within the context of the {@link JwtValidateKeysFactory}.
		 */
		void run() throws T;
	}

	/**
	 * 

* Runs the {@link ContextRunnable} using the {@link JwtValidateKeysFactory}. *

* This is typically used for testing to allow overriding the * {@link JwtValidateKey} instances being used. * * @param validateKeysFactory {@link JwtValidateKeysFactory}. May be * null to not override. * @param runnable {@link ContextRunnable}. * @throws T If failure in {@link ContextRunnable}. */ public static void runWithKeys(JwtValidateKeysFactory validateKeysFactory, ContextRunnable runnable) throws T { // Initialise the overrides threadLocalKeysOverride.set(new JwtKeysFactoryOverride(validateKeysFactory)); try { // Undertake the logic runnable.run(); } finally { // Clear the keys override threadLocalKeysOverride.remove(); } } /** * {@link ThreadLocal} for the {@link JwtKeysFactoryOverride}. */ private static ThreadLocal threadLocalKeysOverride = new ThreadLocal<>(); /** * Possible override for creation of {@link JwtValidateKey} instances. */ private static class JwtKeysFactoryOverride { /** * {@link JwtValidateKeysFactory}. */ private final JwtValidateKeysFactory validateKeysFactory; /** * Instantiate. * * @param validateKeysFactory {@link JwtValidateKeysFactory}. */ private JwtKeysFactoryOverride(JwtValidateKeysFactory validateKeysFactory) { this.validateKeysFactory = validateKeysFactory; } } /** * Authentication scheme Bearer. */ public static final String AUTHENTICATION_SCHEME_BEARER = "Bearer"; /** * {@link Property} name for the claims {@link Class} to be loaded with claim * information of JWT. */ public static final String PROPERTY_CLAIMS_CLASS = "claims.class"; /** *

* {@link Property} name for the startup timeout in milliseconds. *

* This is the time that {@link HttpRequest} instances are held up waiting the * for the initial {@link JwtValidateKey} instances to be loaded. */ public static final String PROEPRTY_STARTUP_TIMEOUT = "startup.timeout"; /** * Default value for {@link #PROEPRTY_STARTUP_TIMEOUT}. */ public static final long DEFAULT_STARTUP_TIMEOUT = 1 * 1000; /** * {@link Property} name for the clock skew in seconds. */ public static final String PROPERTY_CLOCK_SKEW = "clock.skew"; /** * Default value for {@link #PROPERTY_CLOCK_SKEW}. */ public static final long DEFAULT_CLOCK_SKEW = 2; /** * Flow keys. */ public static enum Flows { RETRIEVE_KEYS, RETRIEVE_ROLES, NO_JWT, INVALID_JWT, EXPIRED_JWT } /** * {@link HttpRequestState} attribute name for the {@link ChallengeReason}. */ private static final String CHALLENGE_ATTRIBUTE_NAME = "challenge.reason"; /** * Base64 {@link Decoder}. */ private Decoder base64UrlDecoder = (text) -> Base64.getUrlDecoder().decode(text); /** * {@link ObjectMapper}. */ private static final ObjectMapper mapper = new ObjectMapper(); /** * {@link JwtClaims} {@link JavaType}. */ private static final JavaType jwtClaimsJavaType = mapper.constructType(JwtClaims.class); /** * {@link JwtHeader} {@link JavaType}. */ private static final JavaType jwtHeaderJavaType = mapper.constructType(JwtHeader.class); static { // Ensure JSON deserialising is valid if (!mapper.canDeserialize(jwtClaimsJavaType)) { throw new IllegalStateException("Unable to deserialize " + JwtClaims.class.getSimpleName()); } if (!mapper.canDeserialize(jwtHeaderJavaType)) { throw new IllegalStateException("Unable to deserialize " + JwtHeader.class.getSimpleName()); } // Ensure ignore unknown properties (avoid added "exp" causing problems) mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); } /** * {@link Clock}. */ private Clock clock; /** * Skew in seconds for {@link Clock} time to coordinate with JWT authority. */ private long clockSkew; /** * Claims {@link Class}. */ private Class claimsClass; /** * Claims {@link Class} {@link JavaType}. */ private JavaType claimsJavaType; /** * Start up timeout. */ private long startupTimeout; /** * {@link StatePoller} to keep the {@link JwtValidateKey} instances up to date. */ private StatePoller jwtValidateKeys; /** * {@link JwtKeysFactoryOverride}. Will be null if not overriding. */ private JwtKeysFactoryOverride keysOverride; /* * ==================== HttpSecuritySource ============================ */ @Override protected void loadSpecification(SpecificationContext context) { context.addProperty(PROPERTY_CLAIMS_CLASS, "Claims Class"); } @Override @SuppressWarnings({ "unchecked", "rawtypes" }) protected void loadMetaData( MetaDataContext, JwtHttpAccessControl, Void, None, Flows> context) throws Exception { HttpSecuritySourceContext securityContext = context.getHttpSecuritySourceContext(); // Load the configuration this.clock = securityContext.getClock((time) -> time); this.clockSkew = Long .parseLong(securityContext.getProperty(PROPERTY_CLOCK_SKEW, String.valueOf(DEFAULT_CLOCK_SKEW))); this.claimsClass = securityContext.loadClass(securityContext.getProperty(PROPERTY_CLAIMS_CLASS)); this.startupTimeout = Long.parseLong( securityContext.getProperty(PROEPRTY_STARTUP_TIMEOUT, String.valueOf(DEFAULT_STARTUP_TIMEOUT))); // Ensure claims class can be deserialised this.claimsJavaType = mapper.constructType(this.claimsClass); if (!mapper.canDeserialize(this.claimsJavaType)) { throw new IOException("Unable to deserialise " + this.claimsClass.getName() + " to load JWT claims"); } // Load the possible JWT keys override this.keysOverride = threadLocalKeysOverride.get(); if ((this.keysOverride == null) || (this.keysOverride.validateKeysFactory == null)) { // Only override if have key factory this.keysOverride = null; } // Load the meta-data context.setAuthenticationClass((Class) HttpAuthentication.class); context.setAccessControlClass((Class) JwtHttpAccessControl.class); // Provide flow to retrieve keys and obtain roles context.addFlow(Flows.RETRIEVE_KEYS, JwtValidateKeyCollector.class); context.addFlow(Flows.RETRIEVE_ROLES, JwtRoleCollector.class); // Provide challenge flows context.addFlow(Flows.NO_JWT, null); context.addFlow(Flows.INVALID_JWT, null); context.addFlow(Flows.EXPIRED_JWT, null); // Provide the JWT claims HttpSecuritySupportingManagedObject jwtClaims = securityContext.addSupportingManagedObject( "JWT_CLAIMS", new JwtClaimsManagedObjectSource(this.claimsClass), ManagedObjectScope.THREAD); jwtClaims.linkAccessControl(Dependencies.ACCESS_CONTROL); } @Override public HttpSecurity, JwtHttpAccessControl, Void, None, Flows> sourceHttpSecurity( HttpSecurityContext context) throws HttpException { return this; } @Override public void start(HttpSecurityExecuteContext context) throws Exception { // Only poll if not overriding JWT keys if (this.keysOverride != null) { return; // override so do not poll for keys } // Create poller for JWT decoder this.jwtValidateKeys = StatePoller.builder(JwtValidateKey[].class, (pollContext, callback) -> { ManagedObjectStartupProcess startup = context.registerStartupProcess(Flows.RETRIEVE_KEYS, new JwtValidateKeyCollectorImpl(pollContext), callback); // Run concurrently (may wait on local JWT Authority initialised later) startup.setConcurrent(true); }, (delay, pollContext, callback) -> { context.invokeProcess(Flows.RETRIEVE_KEYS, new JwtValidateKeyCollectorImpl(pollContext), delay, callback); }).identifier("JWT decode keys").build(); } /* * ====================== HttpSecurity ========================== */ @Override public HttpAuthentication createAuthentication(AuthenticationContext, Void> context) { return new HttpAuthenticationImpl<>(context, null); } @Override public boolean ratify(Void credentials, RatifyContext> context) { // Determine if bearer credentials on request HttpAuthenticationScheme scheme = HttpAuthenticationScheme .getHttpAuthenticationScheme(context.getConnection().getRequest()); if ((scheme == null) || (!(AUTHENTICATION_SCHEME_BEARER.equalsIgnoreCase(scheme.getAuthentiationScheme())))) { // Flag for potential challenge that no JWT context.getRequestState().setAttribute(context.getQualifiedAttributeName(CHALLENGE_ATTRIBUTE_NAME), ChallengeReason.NO_JWT); return false; // no JWT } // Has JWT so enough information to authenticate return true; } @Override public void authenticate(Void credentials, AuthenticateContext, None, Flows> context) throws HttpException { // Obtain the JWT validate keys (allowing time to intialise) JwtValidateKey[] validateKeys; if (this.keysOverride != null) { // Override the keys try { validateKeys = this.keysOverride.validateKeysFactory.createJwtValidateKeys(); } catch (Exception ex) { context.accessControlChange(null, new HttpException(HttpStatus.SERVICE_UNAVAILABLE, ex)); return; // must obtain validate keys } } else { // Use polled keys try { validateKeys = this.jwtValidateKeys.getState(this.startupTimeout, TimeUnit.MILLISECONDS); } catch (TimeoutException ex) { context.accessControlChange(null, new HttpException(HttpStatus.SERVICE_UNAVAILABLE, new TimeoutException("Server timed out loading JWT keys"))); return; // must obtain validate keys } } // Obtain the scheme HttpAuthenticationScheme scheme = HttpAuthenticationScheme .getHttpAuthenticationScheme(context.getConnection().getRequest()); String jwtToken = scheme.getParameters(); // Split out the JWT String[] jwtParts = jwtToken.split("\\."); if (jwtParts.length != 3) { // Must have header, claims and signature this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Obtain the parts String headerBase64 = jwtParts[0]; String claimsBase64 = jwtParts[1]; String signatureBase64 = jwtParts[2]; /* * Undertake parsing out JWT claims and validating the signature. * * Note: order of operations is least expensive to most expensive to reduce load * on the server. This ensures a server under load has best chance to handle CPU * processing in validating JWTs from many HTTP requests. */ // Obtain the claims byte[] claimsBytes = base64UrlDecoder.decode(claimsBase64); JwtClaims validateClaims; try { validateClaims = mapper.readValue(claimsBytes, jwtClaimsJavaType); } catch (IOException e) { // Must be able to parse claims this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Obtain the current time (in seconds) long currentTime = this.clock.getTime(); // Ensure valid window (taking into account clock skew) // Note: signature will only confirm not yet available if ((validateClaims.nbf != null) && (validateClaims.nbf > (currentTime + this.clockSkew))) { // JWT not yet active this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Ensure not expired (taking into account clock skew) // Note: signature will only confirm expired if ((validateClaims.exp != null) && (validateClaims.exp < (currentTime - this.clockSkew))) { // JWT expired this.challenge(ChallengeReason.EXPIRED_JWT, context); return; } // Obtain the signature algorithm byte[] headerBytes = base64UrlDecoder.decode(headerBase64); JwtHeader header; try { header = mapper.readValue(headerBytes, jwtHeaderJavaType); } catch (IOException ex) { // Must be able to parse header this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Ensure have algorithm if (header.alg == null) { this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Obtain the algorithm SignatureAlgorithm algorithm = SignatureAlgorithm.valueOf(header.alg); if ((algorithm == null) || (algorithm == SignatureAlgorithm.NONE)) { this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Obtain the JWT without signature String jwtWithoutSignature = jwtToken.substring(0, headerBase64.length() + ".".length() + claimsBase64.length()); // Loop over decode keys to determine if JWT valid boolean isValid = false; NEXT_DECODE_KEY: for (JwtValidateKey decodeKey : validateKeys) { // Ensure key is still within window (taking into account clock skew) if ((decodeKey.getStartTime() > (currentTime + this.clockSkew)) || (decodeKey.getExpireTime() < (currentTime - this.clockSkew))) { continue NEXT_DECODE_KEY; // decode key now outside window } // Attempt to validate the signature DefaultJwtSignatureValidator validator = new DefaultJwtSignatureValidator(algorithm, decodeKey.getKey(), base64UrlDecoder); try { if (validator.isValid(jwtWithoutSignature, signatureBase64)) { isValid = true; break NEXT_DECODE_KEY; // is valid, so no further processing } } catch (Exception ex) { // Ignore as signature not valid } } if (!isValid) { this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Load the claims object for application C claims; try { claims = mapper.readValue(claimsBytes, this.claimsJavaType); } catch (IOException ex) { // Must be able to parse claims this.challenge(ChallengeReason.INVALID_JWT, context); return; } // Retrieve the roles String authenticationScheme = scheme.getAuthentiationScheme(); String principalName = validateClaims.sub; JwtRoleCollectorImpl rolesCollector = new JwtRoleCollectorImpl(claims, authenticationScheme, principalName, context); context.doFlow(Flows.RETRIEVE_ROLES, rolesCollector, rolesCollector); } /** * Loads the challenge details. * * @param reason {@link ChallengeReason}. * @param context {@link AuthenticateContext}. */ private void challenge(ChallengeReason reason, AuthenticateContext, None, Flows> context) { context.getRequestState().setAttribute(context.getQualifiedAttributeName(CHALLENGE_ATTRIBUTE_NAME), reason); } @Override public void challenge(ChallengeContext context) throws HttpException { // Challenge, so unauthorised by default HttpResponse response = context.getConnection().getResponse(); response.setStatus(HttpStatus.UNAUTHORIZED); // Determine cause of challenge ChallengeReason reason = (ChallengeReason) context.getRequestState() .getAttribute(context.getQualifiedAttributeName(CHALLENGE_ATTRIBUTE_NAME)); switch (reason) { case NO_JWT: context.doFlow(Flows.NO_JWT, null, null); break; case INVALID_JWT: context.doFlow(Flows.INVALID_JWT, null, null); break; case EXPIRED_JWT: context.doFlow(Flows.EXPIRED_JWT, null, null); break; } } @Override public void logout(LogoutContext context) throws HttpException { // Not able to "logout" JWT token (as typically externally managed) } /** * Challenge reason. */ private static enum ChallengeReason { NO_JWT, INVALID_JWT, EXPIRED_JWT } /** * {@link JwtValidateKeyCollector} implementation. */ private class JwtValidateKeyCollectorImpl implements JwtValidateKeyCollector { /** * {@link StatePollContext} for the {@link JwtValidateKey} instances. */ private final StatePollContext context; /** * Instantiate. * * @param context {@link StatePollContext} for the {@link JwtValidateKey} * instances. */ private JwtValidateKeyCollectorImpl(StatePollContext context) { this.context = context; } /* * ============= JwtDecodeCollector =============== */ @Override public JwtValidateKey[] getCurrentKeys() { return JwtHttpSecuritySource.this.jwtValidateKeys.getStateNow(); } @Override public void setKeys(JwtValidateKey... keys) { // Filter the keys (also make copy so can not alter) List copy = new ArrayList<>(keys.length); NEXT_KEY: for (JwtValidateKey key : keys) { // Ignore if null if (key == null) { continue NEXT_KEY; } // As here valid, so include decode key copy.add(key); } // Load the JWT decode keys JwtValidateKey[] validKeys = copy.toArray(new JwtValidateKey[copy.size()]); this.context.setNextState(validKeys, -1, null); } @Override public void setFailure(Throwable cause, long timeToNextCheck, TimeUnit unit) { this.context.setFailure(cause, timeToNextCheck, unit); } } /** * {@link JwtRoleCollector} implementation. */ private class JwtRoleCollectorImpl implements JwtRoleCollector, FlowCallback { /** * Claims. */ private final C claims; /** * Authentication scheme. */ private final String authenticationScheme; /** * {@link Principal} name. */ private final String principalName; /** * {@link AuthenticateContext}. */ private final AuthenticateContext, None, Flows> authenticateContext; /** * Indicates if completed. */ private volatile boolean isComplete = false; /** * Instantiate. * * @param claims Claims. * @param authenticationScheme Authentication scheme. * @param principalName {@link Principal} name. * @param authenticateContext {@link AuthenticateContext}. */ private JwtRoleCollectorImpl(C claims, String authenticationScheme, String principalName, AuthenticateContext, None, Flows> authenticateContext) { this.claims = claims; this.authenticationScheme = authenticationScheme; this.principalName = principalName; this.authenticateContext = authenticateContext; } /* * =============== JwtRoleCollector ================= */ @Override public C getClaims() { return this.claims; } @Override public void setRoles(Collection roles) { // Determine if complete if (this.isComplete) { return; } this.isComplete = true; // Create copy of roles (to ensure serialisable) Set rolesSet = new HashSet<>(roles); // Create the Jwt HttpAccess JwtHttpAccessControl accessControl = new JwtHttpAccessControlImpl(this.authenticationScheme, this.principalName, this.claims, rolesSet); this.authenticateContext.accessControlChange(accessControl, null); } @Override public void setFailure(Throwable cause) { // Determine if complete if (this.isComplete) { return; } this.isComplete = true; // Flag the failure this.authenticateContext.accessControlChange(null, cause); } /* * ================ FlowCallback ==================== */ @Override public void run(Throwable escalation) throws Throwable { // Determine if already complete if (this.isComplete) { return; } // Ensure have escalation if (escalation == null) { escalation = new HttpException(HttpStatus.FORBIDDEN, new IllegalStateException("No roles loaded for JWT claims")); } // Indicate failure to load roles this.authenticateContext.accessControlChange(null, escalation); } } /** * {@link JwtHttpAccessControl} implementation. */ private class JwtHttpAccessControlImpl extends HttpAccessControlImpl implements JwtHttpAccessControl { /** * Serial version UID. */ private static final long serialVersionUID = 1L; /** * Claims. */ private final C claims; /** * Instantiate. * * @param authenticationScheme Authentication scheme. * @param principalName {@link Principal} name. * @param claims Claims. * @param roles Roles. */ public JwtHttpAccessControlImpl(String authenticationScheme, String principalName, C claims, Set roles) { super(authenticationScheme, principalName, roles); this.claims = claims; } /* * ================ JwtHttpAccessControl ================= */ @Override public C getClaims() { return this.claims; } } /** * JWT header. */ @JsonIgnoreProperties(ignoreUnknown = true) public static class JwtHeader { /** * Algorithm. */ private String alg; /** * Specifies the algorithm. * * @param alg Algorithm. */ public void setAlg(String alg) { this.alg = alg; } } /** * JWT claims. */ @JsonIgnoreProperties(ignoreUnknown = true) public static class JwtClaims { /** * Subject. */ private String sub; /** * Expiry time. */ private Long exp; /** * Not before time. */ private Long nbf; /** * Specifies the subject. * * @param sub Subject. */ public void setSub(String sub) { this.sub = sub; } /** * Specifies expiry time. * * @param exp Expiry time. */ public void setExp(Long exp) { this.exp = exp; } /** * Specifies not before time. * * @param nbf Not before time. */ public void setNbf(Long nbf) { this.nbf = nbf; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy