be.atbash.ee.security.octopus.jwt.decoder.JWTDecoder Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2017-2022 Rudy De Busscher (https://www.atbash.be)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package be.atbash.ee.security.octopus.jwt.decoder;
import be.atbash.ee.security.octopus.jwt.InvalidJWTException;
import be.atbash.ee.security.octopus.jwt.JWTEncoding;
import be.atbash.ee.security.octopus.jwt.JWTValidationConstant;
import be.atbash.ee.security.octopus.keys.selector.KeySelector;
import be.atbash.ee.security.octopus.nimbus.jwt.EncryptedJWT;
import be.atbash.ee.security.octopus.nimbus.jwt.JWTClaimsSet;
import be.atbash.ee.security.octopus.nimbus.jwt.PlainJWT;
import be.atbash.ee.security.octopus.nimbus.jwt.SignedJWT;
import be.atbash.ee.security.octopus.nimbus.jwt.proc.DefaultJWTProcessor;
import be.atbash.ee.security.octopus.nimbus.jwt.proc.JWTProcessor;
import be.atbash.ee.security.octopus.util.JsonbUtil;
import be.atbash.util.PublicAPI;
import be.atbash.util.StringUtils;
import be.atbash.util.exception.AtbashIllegalActionException;
import org.slf4j.MDC;
import javax.enterprise.context.ApplicationScoped;
import javax.json.JsonObject;
import javax.json.bind.Jsonb;
import java.text.ParseException;
import java.util.*;
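/**
 * Decodes JWT related data, supplied as a String or as a JsonObject, into a JWTData instance
 * holding the payload converted to the requested class type.
 * The String variant handles plain JSON, plain JWT, JWS and JWE values; the JsonObject variant
 * handles plain JWT, JWS and JWE values.
 */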
@PublicAPI
@ApplicationScoped
public class JWTDecoder {
// Created on first use by getJwtProcessor() and then reused for every signed or encrypted token handled by this instance.
private JWTProcessor jwtProcessor;
/**
 * Decodes the data without a key selector and without an additional verifier.
 *
 * @param data the JSON or JWT value to decode.
 * @param classType the type the payload is converted to.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(String data, Class classType) {
    JWTVerifier noVerifier = null;
    return decode(data, classType, null, noVerifier);
}

/**
 * Decodes the data using the key selector but without an additional verifier.
 *
 * @param data the JSON or JWT value to decode.
 * @param classType the type the payload is converted to.
 * @param keySelector the key selector handed to the JWT processor.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(String data, Class classType, KeySelector keySelector) {
    JWTVerifier noVerifier = null;
    return decode(data, classType, keySelector, noVerifier);
}

/**
 * Decodes the data with an additional verifier but without a key selector.
 *
 * @param data the JSON or JWT value to decode.
 * @param classType the type the payload is converted to.
 * @param verifier the additional verifier.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(String data, Class classType, JWTVerifier verifier) {
    KeySelector noKeySelector = null;
    return decode(data, classType, noKeySelector, verifier);
}

/**
 * Decodes the data using the key selector and the deferred critical headers, without an additional verifier.
 *
 * @param data the JSON or JWT value to decode.
 * @param classType the type the payload is converted to.
 * @param keySelector the key selector handed to the JWT processor.
 * @param defCritHeaders names of the critical headers passed to the processor as deferred.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(String data, Class classType, KeySelector keySelector, String... defCritHeaders) {
    JWTVerifier noVerifier = null;
    return decode(data, classType, keySelector, noVerifier, defCritHeaders);
}
/**
 * Decodes the data into an instance of classType after detecting how the data is encoded.
 * <p>
 * Plain JSON and plain JWT values are converted directly; JWS and JWE values require a key selector.
 * The detected encoding, or the reason of a failure, is stored in the MDC under
 * {@code JWTValidationConstant.JWT_VERIFICATION_FAIL_REASON}.
 *
 * @param data the JSON or JWT value to decode.
 * @param classType the type the payload is converted to.
 * @param keySelector the key selector handed to the JWT processor; required for JWS and JWE values.
 * @param verifier optional additional verifier, passed on for JWS and JWE values.
 * @param defCritHeaders names of the critical headers passed to the processor as deferred.
 * @return the decoded payload together with the meta data of the token.
 * @throws IllegalArgumentException when the encoding cannot be determined.
 * @throws AtbashIllegalActionException when a JWS or JWE value is supplied without a key selector.
 * @throws InvalidJWTException when the token cannot be parsed.
 */
public JWTData decode(String data, Class classType, KeySelector keySelector, JWTVerifier verifier, String... defCritHeaders) {
    final JWTEncoding detected = determineEncoding(data);

    if (detected == null) {
        // These messages are in function of JWT validation by Atbash Runtime so have slightly narrow meaning of the provided parameters.
        MDC.put(JWTValidationConstant.JWT_VERIFICATION_FAIL_REASON, "Unable to determine the encoding of the provided token");
        throw new IllegalArgumentException("Unable to determine the encoding of the data");
    }

    MDC.put(JWTValidationConstant.JWT_VERIFICATION_FAIL_REASON, String.format("The encoding of the provided token : %s", detected));

    try {
        // Every branch returns or throws, so nothing follows the switch.
        switch (detected) {
            case NONE:
                return readJSONString(data, classType);
            case PLAIN:
                return readPlainJWT(data, classType);
            case JWS:
                if (keySelector == null) {
                    throw new AtbashIllegalActionException("(OCT-DEV-101) keySelector required for decoding a JWT encoded value");
                }
                return readSignedJWT(data, keySelector, classType, verifier, getDefCritHeaders(defCritHeaders));
            case JWE:
                if (keySelector == null) {
                    throw new AtbashIllegalActionException("(OCT-DEV-101) keySelector required for decoding a JWE encoded value");
                }
                return readEncryptedJWT(data, keySelector, classType, verifier, getDefCritHeaders(defCritHeaders));
            default:
                throw new IllegalArgumentException(String.format("JWTEncoding not supported %s", detected));
        }
    } catch (ParseException parseFailure) {
        // These messages are in function of JWT validation by Atbash Runtime so have slightly narrow meaning of the provided parameters.
        MDC.put(JWTValidationConstant.JWT_VERIFICATION_FAIL_REASON, "The structure of the provided token was not valid");
        throw new InvalidJWTException("Invalid JWT structure", parseFailure);
    }
}
/**
 * Converts the deferred critical header names into a set.
 *
 * @param defCritHeaders the header names, may be null.
 * @return a new set with the distinct header names; empty when the array is null.
 */
private HashSet getDefCritHeaders(String[] defCritHeaders) {
    HashSet<String> headerNames = new HashSet<>();
    if (defCritHeaders != null) {
        Collections.addAll(headerNames, defCritHeaders);
    }
    return headerNames;
}
/**
 * Parses the compact String as a plain JWT and converts it through handlePlainJWT.
 *
 * @param data the plain JWT in String form.
 * @param classType the type the claims are converted to.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException when the value cannot be parsed as a plain JWT.
 */
private JWTData readPlainJWT(String data, Class classType) throws ParseException {
    return handlePlainJWT(PlainJWT.parse(data), classType);
}
/**
 * Parses the compact String as an encrypted JWT and converts it through handleEncryptedJWT.
 *
 * @param data the JWE in String form.
 * @param keySelector the key selector handed to the JWT processor.
 * @param classType the type the claims are converted to.
 * @param verifier optional additional verifier.
 * @param defCritHeaders the deferred critical header names.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException when the value cannot be parsed as an encrypted JWT.
 */
private JWTData readEncryptedJWT(String data, KeySelector keySelector, Class classType, JWTVerifier verifier, Set defCritHeaders) throws ParseException {
    return handleEncryptedJWT(EncryptedJWT.parse(data), keySelector, classType, verifier, defCritHeaders);
}
/**
 * Parses the compact String as a signed JWT and converts it through handleSignedJWT.
 *
 * @param data the JWS in String form.
 * @param keySelector the key selector handed to the JWT processor.
 * @param classType the type the claims are converted to.
 * @param verifier optional additional verifier.
 * @param defCritHeaders the deferred critical header names.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException when the value cannot be parsed as a signed JWT.
 */
private JWTData readSignedJWT(String data, KeySelector keySelector, Class classType, JWTVerifier verifier, Set defCritHeaders) throws ParseException {
    return handleSignedJWT(SignedJWT.parse(data), keySelector, classType, verifier, defCritHeaders);
}
/**
 * Converts plain JSON into an instance of classType, combined with a default MetaJWTData instance.
 *
 * @param data the JSON document.
 * @param classType the type the JSON is converted to.
 * @return the converted value with freshly created meta data.
 */
private JWTData readJSONString(String data, Class classType) {
    MetaJWTData defaultMetaData = new MetaJWTData();
    return readJSONString(data, classType, defaultMetaData);
}
/**
 * Converts plain JSON into an instance of classType using JSON-B and wraps it with the supplied meta data.
 *
 * @param data the JSON document.
 * @param classType the type the JSON is converted to.
 * @param metaJWTData the meta data stored alongside the converted value.
 * @return the converted value with the supplied meta data.
 */
private JWTData readJSONString(String data, Class classType, MetaJWTData metaJWTData) {
    return new JWTData<>(JsonbUtil.getJsonb().fromJson(data, classType), metaJWTData);
}
/**
 * Determines the encoding of the data from its first characters and the number of dots it contains.
 * <ul>
 * <li>Data starting with <code>{</code> is plain JSON: {@code NONE}.</li>
 * <li>Data starting with <code>ey</code> and containing 1 dot, or 2 dots of which the last one
 * is the final character, is a plain JWT: {@code PLAIN}.</li>
 * <li>Data starting with <code>ey</code> and containing 2 dots with content after the last one
 * is a signed JWT: {@code JWS}.</li>
 * <li>Data starting with <code>ey</code> and containing 4 dots is an encrypted JWT: {@code JWE}.</li>
 * </ul>
 * Anything else results in null. The outcome is only an indication; the data is not parsed here.
 *
 * @param data the value to inspect, may be null.
 * @return The encoding or null.
 */
public JWTEncoding determineEncoding(String data) {
    if (data == null) {
        return null;
    }

    if (data.startsWith("ey")) {
        int dots = StringUtils.countOccurrences(data, '.');
        switch (dots) {
            case 1:
                return JWTEncoding.PLAIN;
            case 2:
                // A trailing dot means the signature part is empty.
                return data.endsWith(".") ? JWTEncoding.PLAIN : JWTEncoding.JWS;
            case 4:
                return JWTEncoding.JWE;
            default:
                return null;
        }
    }

    return data.startsWith("{") ? JWTEncoding.NONE : null;
}
/**
 * Decodes the JSON serialized token without a key selector and without an additional verifier.
 *
 * @param data the token as JSON object.
 * @param classType the type the payload is converted to.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(JsonObject data, Class classType) {
    JWTVerifier noVerifier = null;
    return decode(data, classType, null, noVerifier);
}

/**
 * Decodes the JSON serialized token using the key selector but without an additional verifier.
 *
 * @param data the token as JSON object.
 * @param classType the type the payload is converted to.
 * @param keySelector the key selector handed to the JWT processor.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(JsonObject data, Class classType, KeySelector keySelector) {
    JWTVerifier noVerifier = null;
    return decode(data, classType, keySelector, noVerifier);
}

/**
 * Decodes the JSON serialized token with an additional verifier but without a key selector.
 *
 * @param data the token as JSON object.
 * @param classType the type the payload is converted to.
 * @param verifier the additional verifier.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(JsonObject data, Class classType, JWTVerifier verifier) {
    KeySelector noKeySelector = null;
    return decode(data, classType, noKeySelector, verifier);
}

/**
 * Decodes the JSON serialized token using the key selector and the deferred critical headers,
 * without an additional verifier.
 *
 * @param data the token as JSON object.
 * @param classType the type the payload is converted to.
 * @param keySelector the key selector handed to the JWT processor.
 * @param defCritHeaders names of the critical headers passed to the processor as deferred.
 * @return the decoded payload together with the meta data of the token.
 */
public JWTData decode(JsonObject data, Class classType, KeySelector keySelector, String... defCritHeaders) {
    JWTVerifier noVerifier = null;
    return decode(data, classType, keySelector, noVerifier, defCritHeaders);
}
/**
 * Decodes the JSON serialized token into an instance of classType after detecting how it is encoded.
 * <p>
 * A plain JWT is converted directly; JWS and JWE values require a key selector.
 *
 * @param data the token as JSON object.
 * @param classType the type the payload is converted to.
 * @param keySelector the key selector handed to the JWT processor; required for JWS and JWE values.
 * @param verifier optional additional verifier, passed on for JWS and JWE values.
 * @param defCritHeaders names of the critical headers passed to the processor as deferred.
 * @return the decoded payload together with the meta data of the token.
 * @throws IllegalArgumentException when the encoding cannot be determined or is not supported.
 * @throws AtbashIllegalActionException when a JWS or JWE value is supplied without a key selector.
 * @throws InvalidJWTException when the token cannot be parsed.
 */
public JWTData decode(JsonObject data, Class classType, KeySelector keySelector, JWTVerifier verifier, String... defCritHeaders) {
    final JWTEncoding detected = determineEncoding(data);

    if (detected == null) {
        throw new IllegalArgumentException("Unable to determine the encoding of the data");
    }

    try {
        // Every branch returns or throws, so nothing follows the switch.
        switch (detected) {
            case PLAIN:
                return readPlainJWT(data, classType);
            case JWS:
                if (keySelector == null) {
                    throw new AtbashIllegalActionException("(OCT-DEV-101) keySelector required for decoding a JWT encoded value");
                }
                return readSignedJWT(data, keySelector, classType, verifier, getDefCritHeaders(defCritHeaders));
            case JWE:
                if (keySelector == null) {
                    throw new AtbashIllegalActionException("(OCT-DEV-101) keySelector required for decoding a JWE encoded value");
                }
                return readEncryptedJWT(data, keySelector, classType, verifier, getDefCritHeaders(defCritHeaders));
            default:
                throw new IllegalArgumentException(String.format("JWTEncoding not supported %s", detected));
        }
    } catch (ParseException parseFailure) {
        throw new InvalidJWTException("Invalid JWT structure", parseFailure);
    }
}
/**
 * Determines the encoding of a JSON serialized token from the members that are present.
 * <p>
 * A header member ("header" or "protected") and a content member ("payload", or "ciphertext"
 * for an encrypted token) are both required; when either is missing the encoding is unknown.
 * The previous check only rejected objects that missed both, so an object with just a header
 * or just a payload was reported as a plain JWT.
 *
 * @param data the token as JSON object, may be null.
 * @return {@code JWE} when "encrypted_key", "iv", "ciphertext" and "tag" are all present,
 * otherwise {@code JWS} when "signature" is present, otherwise {@code PLAIN}; null when the
 * data is null or a required member is missing.
 */
private JWTEncoding determineEncoding(JsonObject data) {
    if (data == null) {
        return null;
    }

    boolean hasHeader = data.containsKey("header") || data.containsKey("protected");
    // An encrypted token carries "ciphertext" instead of "payload", so either one counts as content.
    boolean hasContent = data.containsKey("payload") || data.containsKey("ciphertext");
    if (!hasHeader || !hasContent) {
        // content and (header or protected) is required
        return null;
    }

    JWTEncoding result = JWTEncoding.PLAIN;
    if (data.containsKey("signature")) {
        result = JWTEncoding.JWS;
    }
    if (data.containsKey("encrypted_key") && data.containsKey("iv") && data.containsKey("ciphertext") && data.containsKey("tag")) {
        result = JWTEncoding.JWE;
    }
    return result;
}
/**
 * Parses the JSON object as a plain JWT and converts it through handlePlainJWT.
 *
 * @param data the plain JWT as JSON object.
 * @param classType the type the claims are converted to.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException when the value cannot be parsed as a plain JWT.
 */
private JWTData readPlainJWT(JsonObject data, Class classType) throws ParseException {
    return handlePlainJWT(PlainJWT.parse(data), classType);
}
/**
 * Converts an already parsed plain JWT into a JWTData holding the claims as an instance of classType.
 * <p>
 * The meta data has no key id (a plain JWT is not bound to a key here) and carries the custom
 * parameters of the header.
 *
 * @param plainJWT the parsed plain JWT.
 * @param classType the type the claims are converted to; for JWTClaimsSet.class the claims set
 * of the token is returned as is, otherwise the claims are converted from their JSON form.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException declared for the access to the claims set of the token.
 */
private JWTData handlePlainJWT(PlainJWT plainJWT, Class classType) throws ParseException {
    MetaJWTData metaJWTData = new MetaJWTData(null, plainJWT.getHeader().getCustomParameters());
    JWTClaimsSet jwtClaimsSet = plainJWT.getJWTClaimsSet();

    if (classType.equals(JWTClaimsSet.class)) {
        // Class.cast needs no type variable in scope (the former "(T)" cast referenced an undeclared one)
        // and cannot fail here since classType is exactly JWTClaimsSet.
        return new JWTData<>(classType.cast(jwtClaimsSet), metaJWTData);
    }

    return readJSONString(jwtClaimsSet.toJSONObject().toString(), classType, metaJWTData);
}
/**
 * Parses the JSON object as a signed JWT and converts it through handleSignedJWT.
 *
 * @param data the JWS as JSON object.
 * @param keySelector the key selector handed to the JWT processor.
 * @param classType the type the claims are converted to.
 * @param verifier optional additional verifier.
 * @param defCritHeaders the deferred critical header names.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException when the value cannot be parsed as a signed JWT.
 */
private JWTData readSignedJWT(JsonObject data, KeySelector keySelector, Class classType, JWTVerifier verifier, Set defCritHeaders) throws ParseException {
    return handleSignedJWT(SignedJWT.parse(data), keySelector, classType, verifier, defCritHeaders);
}
/**
 * Processes an already parsed signed JWT and converts its content into an instance of classType.
 * <p>
 * The JWT processor receives the key selector and the deferred critical headers (the supplied
 * ones plus those the verifier reports as supported) before it processes the token. When a
 * verifier is supplied, it is consulted afterwards with the header and the claims of the token.
 *
 * @param signedJWT the parsed signed JWT.
 * @param keySelector the key selector handed to the JWT processor.
 * @param classType the type the content is converted to; for JWTClaimsSet.class the claims set
 * returned by the processor is used as is, otherwise the payload is converted from its String form.
 * @param verifier optional additional verifier, may be null.
 * @param defCritHeaders the deferred critical header names, may be null.
 * @return the converted content with the key id and custom header parameters as meta data.
 * @throws ParseException declared for the access to the claims set of the token.
 * @throws InvalidJWTException when the verifier rejects the token.
 */
private JWTData handleSignedJWT(SignedJWT signedJWT, KeySelector keySelector, Class classType, JWTVerifier verifier, Set defCritHeaders) throws ParseException {
    // NOTE(review): the processor instance is reused between calls and reconfigured here on each call;
    // confirm that concurrent decode calls with different key selectors cannot interfere with each other.
    JWTProcessor processor = getJwtProcessor();
    processor.setJWSKeySelector(keySelector);
    Set allCritHeaders = assembleAllCritHeaders(verifier, defCritHeaders);
    processor.setDeferredCritHeaders(allCritHeaders);

    JWTClaimsSet jwtClaimsSet = processor.process(signedJWT);

    if (verifier != null && !verifier.verify(signedJWT.getHeader(), signedJWT.getJWTClaimsSet())) {
        throw new InvalidJWTException("JWT verification failed");
    }

    String keyID = signedJWT.getHeader().getKeyID();
    MetaJWTData metaJWTData = new MetaJWTData(keyID, signedJWT.getHeader().getCustomParameters());

    if (classType.equals(JWTClaimsSet.class)) {
        // Class.cast needs no type variable in scope (the former "(T)" cast referenced an undeclared one)
        // and cannot fail here since classType is exactly JWTClaimsSet.
        return new JWTData<>(classType.cast(jwtClaimsSet), metaJWTData);
    }

    return readJSONString(signedJWT.getPayload().toString(), classType, metaJWTData);
}
/**
 * Combines the supplied deferred critical header names with the ones the verifier reports as supported.
 * <p>
 * The result is always a new set; the supplied set is only read, never modified.
 *
 * @param claimsVerifier optional verifier whose supported critical header values are added, may be null.
 * @param defCritHeaders the deferred critical header names, may be null.
 * @return a new set containing all header names.
 */
private Set assembleAllCritHeaders(JWTVerifier claimsVerifier, Set defCritHeaders) {
    Set result = new HashSet<>();
    if (defCritHeaders != null) {
        // Copy instead of reusing the argument so the caller's set stays untouched.
        result.addAll(defCritHeaders);
    }
    if (claimsVerifier != null) {
        result.addAll(claimsVerifier.getSupportedCritHeaderValues());
    }
    return result;
}
/**
 * Parses the JSON object as an encrypted JWT and converts it through handleEncryptedJWT.
 *
 * @param data the JWE as JSON object.
 * @param keySelector the key selector handed to the JWT processor.
 * @param classType the type the claims are converted to.
 * @param verifier optional additional verifier.
 * @param defCritHeaders the deferred critical header names.
 * @return the converted claims with the meta data of the token.
 * @throws ParseException when the value cannot be parsed as an encrypted JWT.
 */
private JWTData readEncryptedJWT(JsonObject data, KeySelector keySelector, Class classType, JWTVerifier verifier, Set defCritHeaders) throws ParseException {
    return handleEncryptedJWT(EncryptedJWT.parse(data), keySelector, classType, verifier, defCritHeaders);
}
/**
 * Processes an already parsed encrypted JWT and converts its claims into an instance of classType.
 * <p>
 * The JWT processor receives the key selector (for both JWS and JWE handling) and the deferred
 * critical headers before it processes the token. As for signed tokens, the deferred critical
 * headers are the supplied ones plus those the verifier reports as supported. When a verifier is
 * supplied, it is consulted afterwards with the header of the token and the processed claims.
 *
 * @param encryptedJWT the parsed encrypted JWT.
 * @param keySelector the key selector handed to the JWT processor.
 * @param classType the type the claims are converted to; for JWTClaimsSet.class the claims set
 * returned by the processor is used as is, otherwise the claims are converted from their JSON form.
 * @param verifier optional additional verifier, may be null.
 * @param defCritHeaders the deferred critical header names, may be null.
 * @return the converted claims with the key id and custom header parameters as meta data.
 * @throws InvalidJWTException when the verifier rejects the token.
 */
private JWTData handleEncryptedJWT(EncryptedJWT encryptedJWT, KeySelector keySelector, Class classType, JWTVerifier verifier, Set defCritHeaders) {
    String keyID = encryptedJWT.getHeader().getKeyID();

    // NOTE(review): the processor instance is reused between calls and reconfigured here on each call;
    // confirm that concurrent decode calls with different key selectors cannot interfere with each other.
    JWTProcessor processor = getJwtProcessor();
    processor.setJWSKeySelector(keySelector);
    processor.setJWEKeySelector(keySelector);
    // Include the critical headers supported by the verifier, consistent with the handling of signed tokens.
    processor.setDeferredCritHeaders(assembleAllCritHeaders(verifier, defCritHeaders));

    JWTClaimsSet jwtClaimsSet = processor.process(encryptedJWT);

    if (verifier != null && !verifier.verify(encryptedJWT.getHeader(), jwtClaimsSet)) {
        throw new InvalidJWTException("JWT verification failed");
    }

    MetaJWTData metaJWTData = new MetaJWTData(keyID, encryptedJWT.getHeader().getCustomParameters());

    if (classType.equals(JWTClaimsSet.class)) {
        // Class.cast needs no type variable in scope (the former "(T)" cast referenced an undeclared one)
        // and cannot fail here since classType is exactly JWTClaimsSet.
        return new JWTData<>(classType.cast(jwtClaimsSet), metaJWTData);
    }

    return readJSONString(jwtClaimsSet.toJSONObject().toString(), classType, metaJWTData);
}
/**
 * Returns the JWT processor, creating it on the first call.
 * <p>
 * The first JWTProcessor found through the ServiceLoader is used; when none is registered a
 * DefaultJWTProcessor is created. The method is synchronized so the lookup happens only once
 * per decoder instance.
 *
 * @return the processor used by this decoder, never null.
 */
private synchronized JWTProcessor getJwtProcessor() {
    if (jwtProcessor == null) {
        // The iterator must be typed: with a raw Iterator, next() yields Object, which is not assignable to the field.
        Iterator<JWTProcessor> candidates = ServiceLoader.load(JWTProcessor.class).iterator();
        if (candidates.hasNext()) {
            jwtProcessor = candidates.next();
        } else {
            jwtProcessor = new DefaultJWTProcessor();
        }
    }
    return jwtProcessor;
}
}