All downloads are free. Search and download functionality uses the official Maven repository.

org.springframework.security.jwt.crypto.sign.RsaKeyHelper Maven / Gradle / Ivy

/*
 * Copyright 2006-2011 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */
package org.springframework.security.jwt.crypto.sign;

import static org.springframework.security.jwt.codec.Codecs.b64Decode;
import static org.springframework.security.jwt.codec.Codecs.utf8Encode;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.*;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.*;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.bouncycastle.asn1.ASN1Sequence;

/**
 * Reads RSA key pairs using BC provider classes but without the
 * need to specify a crypto provider or have BC added as one.
 *
 * @author Luke Taylor
 */
class RsaKeyHelper {
	/**
	 * Leading marker of PEM-armoured data. {@link #parsePublicKey(String)} treats any input
	 * that neither looks like an {@code ssh-xxx} line nor starts with this marker as a bare
	 * Base64 SSH key blob.
	 */
	private static final String BEGIN = "-----BEGIN";

	/**
	 * Matches a complete PEM document. Group 1 is the type label after {@code BEGIN},
	 * group 2 is the body between the two armour lines and group 3 is the label after
	 * {@code END}. {@code DOTALL} lets the body span multiple lines.
	 */
	private static final Pattern PEM_DATA = Pattern.compile("-----BEGIN (.*)-----(.*)-----END (.*)-----", Pattern.DOTALL);

	/**
	 * Matches a one-line SSH public key of the form {@code ssh-<alg> <base64> <comment>}.
	 * Group 1 is the algorithm suffix, group 2 the Base64 blob, group 3 the trailing text.
	 */
	private static final Pattern SSH_PUB_KEY = Pattern.compile("ssh-(rsa|dsa) ([A-Za-z0-9/+]+=*) (.*)");

	/**
	 * Expected start of a decoded SSH RSA key blob: a 4-byte big-endian length (7) followed
	 * by the ASCII bytes of {@code ssh-rsa}. Never modified after initialisation.
	 */
	private static final byte[] SSH_RSA_PREFIX = {0, 0, 0, 7, 's', 's', 'h', '-', 'r', 's', 'a'};

	/**
	 * Parses PEM encoded RSA key material.
	 * <p>
	 * Supported PEM type labels are:
	 * <ul>
	 * <li>{@code RSA PRIVATE KEY} - read as a 9 element ASN.1 sequence; both the public and the
	 * private key are derived from it</li>
	 * <li>{@code PUBLIC KEY} - the body is handed to an {@link X509EncodedKeySpec}</li>
	 * <li>{@code RSA PUBLIC KEY} - read as an ASN.1 sequence holding modulus and public exponent</li>
	 * </ul>
	 * The label on the {@code END} line is not compared with the one on the {@code BEGIN} line.
	 *
	 * @param pemData the PEM text; surrounding whitespace is trimmed before matching
	 * @return a key pair whose private key is {@code null} when only a public key was supplied
	 * @throws IllegalArgumentException if the text is not PEM data, the type label is not
	 * supported, or an RSA private key sequence does not have 9 elements
	 * @throws IllegalStateException if no {@code RSA} {@link KeyFactory} is available
	 * @throws RuntimeException wrapping an {@link InvalidKeySpecException} if the key factory
	 * rejects the decoded key specification
	 */
	static KeyPair parseKeyPair(String pemData) {
		Matcher m = PEM_DATA.matcher(pemData.trim());

		if (!m.matches()) {
			throw new IllegalArgumentException("String is not PEM encoded data");
		}

		String type = m.group(1);
		final byte[] content = b64Decode(utf8Encode(m.group(2)));

		PublicKey publicKey;
		// Stays null for the public-key-only formats.
		PrivateKey privateKey = null;

		try {
			KeyFactory fact = KeyFactory.getInstance("RSA");
			if (type.equals("RSA PRIVATE KEY")) {
				ASN1Sequence seq = ASN1Sequence.getInstance(content);
				if (seq.size() != 9) {
					throw new IllegalArgumentException("Invalid RSA Private Key ASN1 sequence.");
				}
				org.bouncycastle.asn1.pkcs.RSAPrivateKey key = org.bouncycastle.asn1.pkcs.RSAPrivateKey.getInstance(seq);
				// The private key structure carries the public exponent too, so the
				// matching public key can be rebuilt without a separate PEM block.
				RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
				RSAPrivateCrtKeySpec privSpec = new RSAPrivateCrtKeySpec(key.getModulus(), key.getPublicExponent(),
					key.getPrivateExponent(), key.getPrime1(), key.getPrime2(), key.getExponent1(), key.getExponent2(),
					key.getCoefficient());
				publicKey = fact.generatePublic(pubSpec);
				privateKey = fact.generatePrivate(privSpec);
			} else if (type.equals("PUBLIC KEY")) {
				KeySpec keySpec = new X509EncodedKeySpec(content);
				publicKey = fact.generatePublic(keySpec);
			} else if (type.equals("RSA PUBLIC KEY")) {
				ASN1Sequence seq = ASN1Sequence.getInstance(content);
				org.bouncycastle.asn1.pkcs.RSAPublicKey key = org.bouncycastle.asn1.pkcs.RSAPublicKey.getInstance(seq);
				RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
				publicKey = fact.generatePublic(pubSpec);
			} else {
				throw new IllegalArgumentException(type + " is not a supported format");
			}

			return new KeyPair(publicKey, privateKey);
		}
		catch (InvalidKeySpecException e) {
			throw new RuntimeException(e);
		}
		catch (NoSuchAlgorithmException e) {
			throw new IllegalStateException(e);
		}
	}

	/**
	 * Parses an RSA public key from one of three textual forms, tried in this order:
	 * <ol>
	 * <li>a one-line SSH key matching {@code ssh-(rsa|dsa) <base64> <comment>}</li>
	 * <li>any text that does not start with {@code -----BEGIN}, which is assumed to be the
	 * bare Base64 SSH key blob</li>
	 * <li>PEM data, delegated to {@link #parseKeyPair(String)}</li>
	 * </ol>
	 *
	 * @param key the textual key; it is matched as given, without trimming
	 * @return the RSA public key
	 * @throws IllegalArgumentException if an SSH key names an algorithm other than RSA, if the
	 * SSH blob does not start with the {@code ssh-rsa} prefix, or if PEM parsing fails
	 * @throws RuntimeException if the SSH blob is truncated or the key cannot be constructed
	 */
	static RSAPublicKey parsePublicKey(String key) {
		Matcher m = SSH_PUB_KEY.matcher(key);

		if (m.matches()) {
			String alg = m.group(1);
			String encKey = m.group(2);
			// Group 3 (the text after the Base64 blob) is not needed to build the key.

			if (!"rsa".equalsIgnoreCase(alg)) {
				throw new IllegalArgumentException("Only RSA is currently supported, but algorithm was " + alg);
			}

			return parseSSHPublicKey(encKey);
		} else if (!key.startsWith(BEGIN)) {
			// Assume it's the plain Base64 encoded ssh key without the "ssh-rsa" at the start
			return parseSSHPublicKey(key);
		}

		KeyPair kp = parseKeyPair(key);

		if (kp.getPublic() == null) {
			throw new IllegalArgumentException("Key data does not contain a public key");
		}

		return (RSAPublicKey) kp.getPublic();
	}

	/**
	 * Decodes a Base64 SSH RSA public key blob. After the fixed {@code ssh-rsa} prefix the
	 * blob is read as two length-prefixed integers: the public exponent first, then the modulus.
	 *
	 * @param encKey the Base64 text of the key blob
	 * @return the RSA public key built from the decoded modulus and exponent
	 * @throws IllegalArgumentException if the decoded data does not start with the expected prefix
	 * @throws RuntimeException wrapping an {@link IOException} if a length-prefixed field is
	 * truncated or declares an invalid length
	 */
	private static RSAPublicKey parseSSHPublicKey(String encKey) {
		ByteArrayInputStream in = new ByteArrayInputStream(b64Decode(utf8Encode(encKey)));
		byte[] prefix = new byte[SSH_RSA_PREFIX.length];

		try {
			if (in.read(prefix) != prefix.length || !Arrays.equals(SSH_RSA_PREFIX, prefix)) {
				throw new IllegalArgumentException("SSH key prefix not found");
			}

			// Order matters: the exponent field precedes the modulus field in the blob.
			BigInteger e = new BigInteger(readBigInteger(in));
			BigInteger n = new BigInteger(readBigInteger(in));

			return createPublicKey(n, e);
		} catch (IOException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Builds an RSA public key from its modulus and public exponent using the default
	 * {@code RSA} {@link KeyFactory}.
	 *
	 * @param n the modulus
	 * @param e the public exponent
	 * @return the corresponding RSA public key
	 * @throws RuntimeException wrapping the underlying security exception if the key factory
	 * is unavailable or rejects the specification
	 */
	static RSAPublicKey createPublicKey(BigInteger n, BigInteger e) {
		try {
			return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(n, e));
		}
		catch (GeneralSecurityException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Reads one length-prefixed field from the stream: a 4-byte big-endian length followed by
	 * that many bytes, which the caller interprets as a two's-complement {@link BigInteger}.
	 *
	 * @param in the stream positioned at the start of the length prefix
	 * @return the field's bytes, without the length prefix
	 * @throws IOException if fewer than 4 length bytes remain, if the declared length is
	 * negative or exceeds the remaining data, or if the field cannot be read in full
	 */
	private static byte[] readBigInteger(ByteArrayInputStream in) throws IOException {
		byte[] lengthBytes = new byte[4];

		if (in.read(lengthBytes) != 4) {
			throw new IOException("Expected length data as 4 bytes");
		}

		// Mask each byte before shifting: Java bytes are signed, so an unmasked value
		// of 0x80 or above would sign-extend and corrupt the length (e.g. 0x81 -> -127).
		int length = ((lengthBytes[0] & 0xFF) << 24)
			| ((lengthBytes[1] & 0xFF) << 16)
			| ((lengthBytes[2] & 0xFF) << 8)
			| (lengthBytes[3] & 0xFF);

		// Validate before allocating so that malformed input cannot trigger a
		// NegativeArraySizeException or an oversized allocation.
		if (length < 0 || length > in.available()) {
			throw new IOException("Invalid key field length " + length);
		}

		byte[] value = new byte[length];

		if (in.read(value) != length) {
			throw new IOException("Expected " + length + " key bytes");
		}

		return value;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy