All downloads are free. Search and download functionality uses the official Maven repository.

uk.gov.di.ipv.cri.common.library.service.JWTVerifier Maven / Gradle / Ivy

package uk.gov.di.ipv.cri.common.library.service;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.crypto.ECDSAVerifier;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.crypto.impl.ECDSA;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimNames;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.oauth2.sdk.id.ClientID;
import uk.gov.di.ipv.cri.common.library.exception.ClientConfigurationException;
import uk.gov.di.ipv.cri.common.library.exception.SessionValidationException;

import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Map;
import java.util.Set;

import static com.nimbusds.jose.JWSAlgorithm.ES256;

/**
 * Verifies signed JWTs presented by a client against that client's authentication configuration.
 *
 * <p>The configuration map is read with these keys:
 *
 * <ul>
 *   <li>{@code issuer} and {@code audience}: expected claim values
 *   <li>{@code authenticationAlg}: name of the JWS algorithm the JWT header must declare
 *   <li>{@code publicSigningJwkBase64}: Base64-encoded public key material (an X.509 certificate
 *       for RSA algorithms, a JWK JSON document for EC algorithms)
 * </ul>
 */
public class JWTVerifier {

    /**
     * Verifies an authorization request JWT.
     *
     * <p>The {@code exp}, {@code sub} and {@code nbf} claims must be present, and the issuer and
     * audience must equal the configured {@code issuer} and {@code audience} values.
     *
     * @param clientAuthenticationConfig the client's authentication configuration
     * @param signedJWT the JWT to verify
     * @throws SessionValidationException if the header, claims or signature do not verify
     * @throws ClientConfigurationException if the configured key material cannot be used
     */
    public void verifyAuthorizationJWT(
            Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT)
            throws SessionValidationException, ClientConfigurationException {
        Set<String> requiredClaims =
                Set.of(
                        JWTClaimNames.EXPIRATION_TIME,
                        JWTClaimNames.SUBJECT,
                        JWTClaimNames.NOT_BEFORE);
        JWTClaimsSet expectedClaimValues =
                new JWTClaimsSet.Builder()
                        .issuer(clientAuthenticationConfig.get("issuer"))
                        .audience(clientAuthenticationConfig.get("audience"))
                        .build();
        verifyJWT(clientAuthenticationConfig, signedJWT, requiredClaims, expectedClaimValues);
    }

    /**
     * Verifies a client assertion JWT presented with an access token request.
     *
     * <p>The {@code exp}, {@code sub}, {@code iss}, {@code aud} and {@code jti} claims must be
     * present; issuer and subject must both equal the client id, and the audience must equal the
     * configured {@code audience} value.
     *
     * @param clientAuthenticationConfig the client's authentication configuration
     * @param signedJWT the JWT to verify
     * @param clientID the client id expected as both issuer and subject
     * @throws SessionValidationException if the header, claims or signature do not verify
     * @throws ClientConfigurationException if the configured key material cannot be used
     */
    public void verifyAccessTokenJWT(
            Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT, ClientID clientID)
            throws SessionValidationException, ClientConfigurationException {
        Set<String> requiredClaims =
                Set.of(
                        JWTClaimNames.EXPIRATION_TIME,
                        JWTClaimNames.SUBJECT,
                        JWTClaimNames.ISSUER,
                        JWTClaimNames.AUDIENCE,
                        JWTClaimNames.JWT_ID);
        JWTClaimsSet expectedClaimValues =
                new JWTClaimsSet.Builder()
                        .issuer(clientID.getValue())
                        .subject(clientID.getValue())
                        .audience(clientAuthenticationConfig.get("audience"))
                        .build();
        verifyJWT(clientAuthenticationConfig, signedJWT, requiredClaims, expectedClaimValues);
    }

    /**
     * Rejects a JWT whose expiry lies further in the future than the permitted time-to-live.
     *
     * @param jwtExpirationTime the expiry instant taken from the JWT
     * @param maxAllowedTtl the maximum permitted time-to-live, in seconds from now
     * @throws SessionValidationException if the expiry is later than now plus {@code
     *     maxAllowedTtl} seconds
     */
    public void validateMaxAllowedJarTtl(Instant jwtExpirationTime, long maxAllowedTtl)
            throws SessionValidationException {
        // Instants are directly comparable; no calendar conversion is needed.
        Instant latestPermittedExpiry = Instant.now().plusSeconds(maxAllowedTtl);

        if (jwtExpirationTime.isAfter(latestPermittedExpiry)) {
            throw new SessionValidationException(
                    "The client JWT expiry date has surpassed the maximum allowed ttl value");
        }
    }

    /** Runs the header, claims and signature checks, in that order. */
    private void verifyJWT(
            Map<String, String> clientAuthenticationConfig,
            SignedJWT signedJWT,
            Set<String> requiredClaims,
            JWTClaimsSet expectedClaimValues)
            throws SessionValidationException, ClientConfigurationException {
        this.verifyJWTHeader(clientAuthenticationConfig, signedJWT);
        this.verifyJWTClaimsSet(signedJWT, requiredClaims, expectedClaimValues);
        this.verifyJWTSignature(clientAuthenticationConfig, signedJWT);
    }

    /**
     * Checks that the algorithm declared in the JWT header equals the configured {@code
     * authenticationAlg}.
     *
     * @throws SessionValidationException if the two algorithms differ
     */
    private void verifyJWTHeader(
            Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT)
            throws SessionValidationException {
        JWSAlgorithm configuredAlgorithm =
                JWSAlgorithm.parse(clientAuthenticationConfig.get("authenticationAlg"));
        JWSAlgorithm jwtAlgorithm = signedJWT.getHeader().getAlgorithm();
        // Compare by value: two JWSAlgorithm instances with the same name are not guaranteed to
        // be the same object.
        if (!configuredAlgorithm.equals(jwtAlgorithm)) {
            throw new SessionValidationException(
                    String.format(
                            "jwt signing algorithm %s does not match signing algorithm configured for client: %s",
                            jwtAlgorithm, configuredAlgorithm));
        }
    }

    /**
     * Verifies the JWT signature with the public key held under {@code publicSigningJwkBase64}.
     *
     * <p>EC signatures that are not in the fixed-length concatenated form are transcoded before
     * verification; signatures for other algorithm families are verified as received.
     *
     * @throws SessionValidationException if the signature is invalid or cannot be processed
     * @throws ClientConfigurationException if the configured certificate cannot be read or the
     *     resulting key type is unsupported
     */
    private void verifyJWTSignature(
            Map<String, String> clientAuthenticationConfig, SignedJWT signedJWT)
            throws SessionValidationException, ClientConfigurationException {
        String publicCertificateToVerify = clientAuthenticationConfig.get("publicSigningJwkBase64");
        try {
            JWSAlgorithm signingAlgorithm = signedJWT.getHeader().getAlgorithm();

            SignedJWT concatSignatureJwt;
            if (signatureIsDerFormat(signedJWT, signingAlgorithm)) {
                concatSignatureJwt = transcodeSignature(signedJWT, signingAlgorithm);
            } else {
                concatSignatureJwt = signedJWT;
            }

            PublicKey publicKeyFromConfig =
                    getPublicKeyFromConfig(publicCertificateToVerify, signingAlgorithm);
            if (!verifySignature(concatSignatureJwt, publicKeyFromConfig)) {
                throw new SessionValidationException("JWT signature verification failed");
            }
        } catch (JOSEException | ParseException e) {
            throw new SessionValidationException("JWT signature verification failed", e);
        } catch (CertificateException e) {
            throw new ClientConfigurationException("Certificate problem encountered", e);
        }
    }

    /**
     * Reports whether the JWT carries an EC signature that is not already in concatenated form.
     *
     * <p>Only EC-family algorithms are considered: RSA (and any other) signatures have no DER /
     * concatenated distinction and must never be passed to the ECDSA transcoder. The expected
     * concatenated length is looked up for the JWT's own algorithm rather than assuming ES256.
     *
     * @throws JOSEException if the expected signature length cannot be determined
     */
    private boolean signatureIsDerFormat(SignedJWT signedJWT, JWSAlgorithm signingAlgorithm)
            throws JOSEException {
        if (!JWSAlgorithm.Family.EC.contains(signingAlgorithm)) {
            return false;
        }
        int concatSignatureLength = ECDSA.getSignatureByteArrayLength(signingAlgorithm);
        return signedJWT.getSignature().decode().length != concatSignatureLength;
    }

    /**
     * Returns a copy of the JWT whose signature part has been transcoded to the concatenated form
     * expected for {@code signingAlgorithm}; header and payload parts are reused unchanged.
     *
     * @throws JOSEException if the signature cannot be transcoded
     * @throws ParseException if the rebuilt JWT cannot be parsed
     */
    private SignedJWT transcodeSignature(SignedJWT signedJWT, JWSAlgorithm signingAlgorithm)
            throws JOSEException, ParseException {
        Base64URL transcodedSignatureBase64 =
                Base64URL.encode(
                        ECDSA.transcodeSignatureToConcat(
                                signedJWT.getSignature().decode(),
                                ECDSA.getSignatureByteArrayLength(signingAlgorithm)));
        String[] jwtParts = signedJWT.serialize().split("\\.");
        return SignedJWT.parse(
                String.format("%s.%s.%s", jwtParts[0], jwtParts[1], transcodedSignatureBase64));
    }

    /**
     * Verifies the JWT claims with a {@link DefaultJWTClaimsVerifier} built from the expected
     * claim values and the set of required claim names.
     *
     * @throws SessionValidationException if the claims are rejected or cannot be parsed
     */
    private void verifyJWTClaimsSet(
            SignedJWT signedJWT, Set<String> requiredClaims, JWTClaimsSet expectedClaimValues)
            throws SessionValidationException {

        try {
            new DefaultJWTClaimsVerifier<>(expectedClaimValues, requiredClaims)
                    .verify(signedJWT.getJWTClaimsSet(), null);

        } catch (BadJWTException | ParseException e) {
            throw new SessionValidationException(e.getMessage(), e);
        }
    }

    /**
     * Decodes the configured public key for the given signing algorithm.
     *
     * <p>For RSA algorithms the value is treated as a Base64-encoded X.509 certificate; for EC
     * algorithms it is treated as a Base64-encoded JWK JSON document.
     *
     * @param serialisedPublicKey the Base64-encoded key material from configuration
     * @param signingAlgorithm the algorithm declared by the JWT
     * @return the decoded public key
     * @throws CertificateException if the X.509 certificate cannot be read
     * @throws ParseException if the JWK cannot be parsed
     * @throws JOSEException if the JWK cannot be converted to an EC public key
     * @throws IllegalArgumentException if the algorithm is neither RSA nor EC family
     */
    private PublicKey getPublicKeyFromConfig(
            String serialisedPublicKey, JWSAlgorithm signingAlgorithm)
            throws CertificateException, ParseException, JOSEException {
        if (JWSAlgorithm.Family.RSA.contains(signingAlgorithm)) {
            byte[] binaryCertificate = Base64.getDecoder().decode(serialisedPublicKey);
            CertificateFactory factory = CertificateFactory.getInstance("X.509");
            Certificate certificate =
                    factory.generateCertificate(new ByteArrayInputStream(binaryCertificate));
            return certificate.getPublicKey();
        } else if (JWSAlgorithm.Family.EC.contains(signingAlgorithm)) {
            // Decode with an explicit charset so the result does not depend on the platform
            // default.
            String jwkJson =
                    new String(
                            Base64.getDecoder().decode(serialisedPublicKey),
                            StandardCharsets.UTF_8);
            return ECKey.parse(jwkJson).toECPublicKey();
        } else {
            throw new IllegalArgumentException(
                    "Unexpected signing algorithm encountered: " + signingAlgorithm.getName());
        }
    }

    /**
     * Verifies the JWT signature with a verifier chosen by the runtime type of the public key.
     *
     * @return {@code true} if the signature is valid
     * @throws JOSEException if verification cannot be performed
     * @throws ClientConfigurationException if the key is neither an RSA nor an EC public key
     */
    private boolean verifySignature(SignedJWT signedJWT, PublicKey clientPublicKey)
            throws JOSEException, ClientConfigurationException {
        if (clientPublicKey instanceof RSAPublicKey) {
            RSASSAVerifier rsassaVerifier = new RSASSAVerifier((RSAPublicKey) clientPublicKey);
            return signedJWT.verify(rsassaVerifier);
        } else if (clientPublicKey instanceof ECPublicKey) {
            ECDSAVerifier ecdsaVerifier = new ECDSAVerifier((ECPublicKey) clientPublicKey);
            return signedJWT.verify(ecdsaVerifier);
        } else {
            throw new ClientConfigurationException(
                    new IllegalStateException(
                            "unknown public signing key: " + clientPublicKey.getAlgorithm()));
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy