All downloads are free. Search and download functionality uses the official Maven repository.

com.peterphi.std.guice.web.rest.auth.userprovider.HttpCallJWTUser Maven / Gradle / Ivy

package com.peterphi.std.guice.web.rest.auth.userprovider;

import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.JWTVerifyException;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.peterphi.std.guice.apploader.GuiceConstants;
import com.peterphi.std.guice.common.auth.iface.AccessRefuser;
import com.peterphi.std.guice.common.auth.iface.CurrentUser;
import com.peterphi.std.guice.restclient.exception.RestException;
import com.peterphi.std.guice.web.HttpCallContext;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.jboss.resteasy.util.HttpHeaderNames;
import org.joda.time.DateTime;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SignatureException;
import java.util.Collections;
import java.util.List;
import java.util.Map;

class HttpCallJWTUser implements CurrentUser
{
	private static final Logger log = Logger.getLogger(HttpCallJWTUser.class);


	/**
	 * {@link HttpServletRequest} attribute to store the decoded token under
	 */
	private static final String DECODED_JWT = "decoded-jwt";

	/**
	 * Token cache; lets us skip parsing (N.B. need to revalidate "exp" if set)
	 */
	private static final Cache> TOKEN_CACHE = CacheBuilder.newBuilder().maximumSize(256).build();

	private final String headerName;
	private final String cookieName;
	private final boolean requireSecure;
	private final JWTVerifier verifier;


	public HttpCallJWTUser(final String headerName,
	                       final String cookieName,
	                       final boolean requireSecure,
	                       final JWTVerifier verifier)
	{
		this.headerName = headerName;
		this.cookieName = cookieName;
		this.requireSecure = requireSecure;
		this.verifier = verifier;
	}


	@Override
	public String getAuthType()
	{
		return GuiceConstants.JAXRS_SERVER_WEBAUTH_JWT_PROVIDER;
	}


	@SuppressWarnings("unchecked")
	private Map get()
	{
		final HttpServletRequest request = HttpCallContext.get().getRequest();

		// Allow the parsed token to be used without revalidation for the duration of an HTTP request
		{
			final Object obj = request.getAttribute(DECODED_JWT);

			if (obj != null && obj instanceof Map)
				return (Map) (Map) obj;
		}

		final String token = getToken(request);

		if (token != null)
		{
			// Parse and validate token
			try
			{
				final Map data = parseToken(token);

				if (requireSecure && !request.isSecure())
				{
					log.fatal("JWT received over insecure channel (but secure channel mandated)! Token is probably compromised:" +
					          data);

					return null;
				}

				// Cache without revalidation for the remainder of this HTTP request
				request.setAttribute(DECODED_JWT, data);

				return data;
			}
			catch (Exception e)
			{
				throw new RuntimeException("JWT Verification failed!", e);
			}
		}
		else
		{
			throw new RuntimeException("User is authenticated by JWT but no JWT found!");
		}
	}


	private Map parseToken(final String token) throws NoSuchAlgorithmException, InvalidKeyException, IOException, SignatureException, JWTVerifyException
	{
		final Map cached = TOKEN_CACHE.getIfPresent(token);

		// N.B. If token is present in cache then we need to revalidate expire time
		if (cached != null)
		{
			// If the token has an expire time then we must check it has not passed yet
			// If the token has no expire time then the cache entry can be returned immediately
			if (cached.containsKey("exp"))
			{
				// Token has an expire time, must revalidate before returning

				// exp holds seconds since 1970 timestamp for the expiry time
				final long expireTimestamp = 1000L * (long) cached.get("exp");

				// N.B. if we don't return the cached value we'll
				if (expireTimestamp < System.currentTimeMillis())
					return cached;
				else
					TOKEN_CACHE.invalidate(token); // invalidate this cache entry, it has expired. The fallback code at the end of this method will run and provide a consistent token expired error message
			}
			else
			{
				// Token has no expire time, can be returned as-is
				return cached;
			}
		}

		// Fall back on parsing the token
		Map parsed = verifier.verify(token);

		// Store the token in the cache
		TOKEN_CACHE.put(token, parsed);

		return parsed;
	}


	/**
	 * Returns true if and only if a JWT has been provided, but does not validate the token
	 *
	 * @return
	 */
	public boolean hasToken()
	{
		final HttpServletRequest request = HttpCallContext.get().getRequest();

		return getToken(request) != null;
	}


	String getToken(final HttpServletRequest request)
	{
		// Try a bare HTTP Header
		if (headerName != null)
		{
			final String val = request.getHeader(this.headerName);

			if (val != null)
				return val;
		}

		// Try the HTTP Authorization header
		{
			final String header = request.getHeader(HttpHeaderNames.AUTHORIZATION);

			if (header != null)
			{
				return parseBasicAuth(header);
			}
		}

		// Try reading from a cookie
		if (cookieName != null)
		{
			final Cookie[] cookies = request.getCookies();

			if (cookies != null)
				for (Cookie cookie : cookies)
					if (StringUtils.equals(cookie.getName(), cookieName))
						return cookie.getValue();
		}

		// No token found
		return null;
	}


	static String parseBasicAuth(final String header)
	{
		try
		{
			if (header.length() < 6)
				return null;

			if (StringUtils.startsWithIgnoreCase(header, "Basic"))
			{
				// JWT bundled into HTTP Basic auth
				String val = header.substring("Basic".length() + 1);

				val = new String(org.apache.commons.codec.binary.Base64.decodeBase64(val.getBytes()), "UTF-8");

				String[] split = val.split(":", 2);

				if (split.length != 2)
					return null;
				else if (StringUtils.equals("jwt", split[0])) // Username jwt
					return split[1];
				else
					return null;
			}
			else if (StringUtils.startsWithIgnoreCase(header, "Bearer"))
			{
				// Bearer token
				return header.substring("Bearer".length() + 1);
			}
			else
			{
				return null; // unrecognised auth type
			}
		}
		catch (Exception e)
		{
			log.warn("Error extracting JWT from HTTP Auth header", e);
			return null;
		}
	}


	@Override
	public boolean isAnonymous()
	{
		return (get() != null);
	}


	@Override
	public String getName()
	{
		final Map data = get();

		if (data.containsKey("name"))
			return String.valueOf(data.get("name"));
		else
			return getUsername();
	}


	@Override
	public String getUsername()
	{
		final Map data = get();

		if (data.containsKey("sub"))
			return String.valueOf(data.get("sub"));
		else
			return getName();
	}


	@Override
	@SuppressWarnings("unchecked")
	public boolean hasRole(final String role)
	{
		final Map data = get();

		// Special case the "authenticated" role - need only have a valid token, no role is required
		if (StringUtils.equals(WebappAuthenticationModule.ROLE_SPECIAL_AUTHENTICATED, role))
		{
			return data != null;
		}
		else
		{
			final List roles = (List) data.get("roles");

			if (roles == null)
				return false; // No roles specified!
			else
				return roles.contains(role);
		}
	}


	@Override
	public AccessRefuser getAccessRefuser()
	{
		return (scope, constraint, user) ->
		{
			if (user.isAnonymous())
				return new RestException(401,
				                         "You must log in to access this resource. Required role: " + scope.getRole(constraint));
			else
				return new RestException(403,
				                         "Access denied for your JWT by rule: " +
				                         ((constraint != null) ?
				                          constraint.comment() :
				                          "(default)" + ". Required role: " + scope.getRole(constraint)));
		};
	}


	@Override
	public DateTime getExpires()
	{
		final Number epochSeconds = (Number) get().get("exp");

		if (epochSeconds != null)
		{
			final long millis = epochSeconds.longValue() * 1000L;

			return new DateTime(millis);
		}
		else
		{
			return null; // No expire time set
		}
	}


	@Override
	public Map getClaims()
	{
		return Collections.unmodifiableMap(get());
	}


	@Override
	public String toString()
	{
		return Objects.toStringHelper(this).add("claims", get()).toString();
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy