
com.peterphi.std.guice.web.rest.auth.userprovider.HttpCallJWTUser Maven / Gradle / Ivy
package com.peterphi.std.guice.web.rest.auth.userprovider;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.JWTVerifyException;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.peterphi.std.guice.apploader.GuiceConstants;
import com.peterphi.std.guice.common.auth.iface.AccessRefuser;
import com.peterphi.std.guice.common.auth.iface.CurrentUser;
import com.peterphi.std.guice.restclient.exception.RestException;
import com.peterphi.std.guice.web.HttpCallContext;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.jboss.resteasy.util.HttpHeaderNames;
import org.joda.time.DateTime;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SignatureException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
class HttpCallJWTUser implements CurrentUser
{
private static final Logger log = Logger.getLogger(HttpCallJWTUser.class);
/**
* {@link HttpServletRequest} attribute to store the decoded token under
*/
private static final String DECODED_JWT = "decoded-jwt";
/**
* Token cache; lets us skip parsing (N.B. need to revalidate "exp" if set)
*/
private static final Cache> TOKEN_CACHE = CacheBuilder.newBuilder().maximumSize(256).build();
private final String headerName;
private final String cookieName;
private final boolean requireSecure;
private final JWTVerifier verifier;
public HttpCallJWTUser(final String headerName,
final String cookieName,
final boolean requireSecure,
final JWTVerifier verifier)
{
this.headerName = headerName;
this.cookieName = cookieName;
this.requireSecure = requireSecure;
this.verifier = verifier;
}
@Override
public String getAuthType()
{
return GuiceConstants.JAXRS_SERVER_WEBAUTH_JWT_PROVIDER;
}
@SuppressWarnings("unchecked")
private Map get()
{
final HttpServletRequest request = HttpCallContext.get().getRequest();
// Allow the parsed token to be used without revalidation for the duration of an HTTP request
{
final Object obj = request.getAttribute(DECODED_JWT);
if (obj != null && obj instanceof Map)
return (Map) (Map) obj;
}
final String token = getToken(request);
if (token != null)
{
// Parse and validate token
try
{
final Map data = parseToken(token);
if (requireSecure && !request.isSecure())
{
log.fatal("JWT received over insecure channel (but secure channel mandated)! Token is probably compromised:" +
data);
return null;
}
// Cache without revalidation for the remainder of this HTTP request
request.setAttribute(DECODED_JWT, data);
return data;
}
catch (Exception e)
{
throw new RuntimeException("JWT Verification failed!", e);
}
}
else
{
throw new RuntimeException("User is authenticated by JWT but no JWT found!");
}
}
private Map parseToken(final String token) throws NoSuchAlgorithmException, InvalidKeyException, IOException, SignatureException, JWTVerifyException
{
final Map cached = TOKEN_CACHE.getIfPresent(token);
// N.B. If token is present in cache then we need to revalidate expire time
if (cached != null)
{
// If the token has an expire time then we must check it has not passed yet
// If the token has no expire time then the cache entry can be returned immediately
if (cached.containsKey("exp"))
{
// Token has an expire time, must revalidate before returning
// exp holds seconds since 1970 timestamp for the expiry time
final long expireTimestamp = 1000L * (long) cached.get("exp");
// N.B. if we don't return the cached value we'll
if (expireTimestamp < System.currentTimeMillis())
return cached;
else
TOKEN_CACHE.invalidate(token); // invalidate this cache entry, it has expired. The fallback code at the end of this method will run and provide a consistent token expired error message
}
else
{
// Token has no expire time, can be returned as-is
return cached;
}
}
// Fall back on parsing the token
Map parsed = verifier.verify(token);
// Store the token in the cache
TOKEN_CACHE.put(token, parsed);
return parsed;
}
/**
* Returns true if and only if a JWT has been provided, but does not validate the token
*
* @return
*/
public boolean hasToken()
{
final HttpServletRequest request = HttpCallContext.get().getRequest();
return getToken(request) != null;
}
String getToken(final HttpServletRequest request)
{
// Try a bare HTTP Header
if (headerName != null)
{
final String val = request.getHeader(this.headerName);
if (val != null)
return val;
}
// Try the HTTP Authorization header
{
final String header = request.getHeader(HttpHeaderNames.AUTHORIZATION);
if (header != null)
{
return parseBasicAuth(header);
}
}
// Try reading from a cookie
if (cookieName != null)
{
final Cookie[] cookies = request.getCookies();
if (cookies != null)
for (Cookie cookie : cookies)
if (StringUtils.equals(cookie.getName(), cookieName))
return cookie.getValue();
}
// No token found
return null;
}
static String parseBasicAuth(final String header)
{
try
{
if (header.length() < 6)
return null;
if (StringUtils.startsWithIgnoreCase(header, "Basic"))
{
// JWT bundled into HTTP Basic auth
String val = header.substring("Basic".length() + 1);
val = new String(org.apache.commons.codec.binary.Base64.decodeBase64(val.getBytes()), "UTF-8");
String[] split = val.split(":", 2);
if (split.length != 2)
return null;
else if (StringUtils.equals("jwt", split[0])) // Username jwt
return split[1];
else
return null;
}
else if (StringUtils.startsWithIgnoreCase(header, "Bearer"))
{
// Bearer token
return header.substring("Bearer".length() + 1);
}
else
{
return null; // unrecognised auth type
}
}
catch (Exception e)
{
log.warn("Error extracting JWT from HTTP Auth header", e);
return null;
}
}
@Override
public boolean isAnonymous()
{
return (get() != null);
}
@Override
public String getName()
{
final Map data = get();
if (data.containsKey("name"))
return String.valueOf(data.get("name"));
else
return getUsername();
}
@Override
public String getUsername()
{
final Map data = get();
if (data.containsKey("sub"))
return String.valueOf(data.get("sub"));
else
return getName();
}
@Override
@SuppressWarnings("unchecked")
public boolean hasRole(final String role)
{
final Map data = get();
// Special case the "authenticated" role - need only have a valid token, no role is required
if (StringUtils.equals(WebappAuthenticationModule.ROLE_SPECIAL_AUTHENTICATED, role))
{
return data != null;
}
else
{
final List roles = (List) data.get("roles");
if (roles == null)
return false; // No roles specified!
else
return roles.contains(role);
}
}
@Override
public AccessRefuser getAccessRefuser()
{
return (scope, constraint, user) ->
{
if (user.isAnonymous())
return new RestException(401,
"You must log in to access this resource. Required role: " + scope.getRole(constraint));
else
return new RestException(403,
"Access denied for your JWT by rule: " +
((constraint != null) ?
constraint.comment() :
"(default)" + ". Required role: " + scope.getRole(constraint)));
};
}
@Override
public DateTime getExpires()
{
final Number epochSeconds = (Number) get().get("exp");
if (epochSeconds != null)
{
final long millis = epochSeconds.longValue() * 1000L;
return new DateTime(millis);
}
else
{
return null; // No expire time set
}
}
@Override
public Map getClaims()
{
return Collections.unmodifiableMap(get());
}
@Override
public String toString()
{
return Objects.toStringHelper(this).add("claims", get()).toString();
}
}