org.snf4j.tls.engine.ClientHelloConsumer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of snf4j-tls Show documentation
Show all versions of snf4j-tls Show documentation
The SNF4J module for TLS protocol
The newest version!
/*
* -------------------------------- MIT License --------------------------------
*
* Copyright (c) 2022-2024 SNF4J contributors
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
* -----------------------------------------------------------------------------
*/
package org.snf4j.tls.engine;
import static org.snf4j.tls.IntConstant.findMatch;
import static org.snf4j.tls.IntConstant.find;
import static org.snf4j.tls.extension.ExtensionsUtil.find;
import java.nio.ByteBuffer;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.snf4j.core.session.ssl.ClientAuth;
import org.snf4j.tls.alert.Alert;
import org.snf4j.tls.alert.DecryptErrorAlert;
import org.snf4j.tls.alert.HandshakeFailureAlert;
import org.snf4j.tls.alert.IllegalParameterAlert;
import org.snf4j.tls.alert.InternalErrorAlert;
import org.snf4j.tls.alert.MissingExtensionAlert;
import org.snf4j.tls.alert.ProtocolVersionAlert;
import org.snf4j.tls.alert.UnexpectedMessageAlert;
import org.snf4j.tls.alert.UnrecognizedNameAlert;
import org.snf4j.tls.cipher.CipherSuite;
import org.snf4j.tls.crypto.Hkdf;
import org.snf4j.tls.crypto.IHash;
import org.snf4j.tls.crypto.IHkdf;
import org.snf4j.tls.crypto.IKeyExchange;
import org.snf4j.tls.crypto.ITranscriptHash;
import org.snf4j.tls.crypto.KeySchedule;
import org.snf4j.tls.crypto.TranscriptHash;
import org.snf4j.tls.extension.ALPNExtension;
import org.snf4j.tls.extension.EarlyDataExtension;
import org.snf4j.tls.extension.ExtensionType;
import org.snf4j.tls.extension.IExtension;
import org.snf4j.tls.extension.IKeyShareExtension;
import org.snf4j.tls.extension.IPreSharedKeyExtension;
import org.snf4j.tls.extension.IPskKeyExchangeModesExtension;
import org.snf4j.tls.extension.IServerNameExtension;
import org.snf4j.tls.extension.ISignatureAlgorithmsExtension;
import org.snf4j.tls.extension.ISupportedGroupsExtension;
import org.snf4j.tls.extension.ISupportedVersionsExtension;
import org.snf4j.tls.extension.KeyShareEntry;
import org.snf4j.tls.extension.KeyShareExtension;
import org.snf4j.tls.extension.NamedGroup;
import org.snf4j.tls.extension.ParsedKey;
import org.snf4j.tls.extension.PreSharedKeyExtension;
import org.snf4j.tls.extension.PskKeyExchangeMode;
import org.snf4j.tls.extension.ServerNameExtension;
import org.snf4j.tls.extension.SignatureAlgorithmsCertExtension;
import org.snf4j.tls.extension.SignatureAlgorithmsExtension;
import org.snf4j.tls.extension.SignatureScheme;
import org.snf4j.tls.extension.SupportedGroupsExtension;
import org.snf4j.tls.extension.SupportedVersionsExtension;
import org.snf4j.tls.handshake.Certificate;
import org.snf4j.tls.handshake.CertificateRequest;
import org.snf4j.tls.handshake.CertificateType;
import org.snf4j.tls.handshake.CertificateVerify;
import org.snf4j.tls.handshake.EncryptedExtensions;
import org.snf4j.tls.handshake.Finished;
import org.snf4j.tls.handshake.HandshakeType;
import org.snf4j.tls.handshake.IClientHello;
import org.snf4j.tls.handshake.IHandshake;
import org.snf4j.tls.handshake.ServerHello;
import org.snf4j.tls.handshake.ServerHelloRandom;
import org.snf4j.tls.record.RecordType;
import org.snf4j.tls.session.SessionTicket;
import org.snf4j.tls.session.UsedSession;
public class ClientHelloConsumer implements IHandshakeConsumer {
@Override
public HandshakeType getType() {
return HandshakeType.CLIENT_HELLO;
}
@Override
public void consume(EngineState state, IHandshake handshake, ByteBuffer[] data, boolean isHRR) throws Alert {
boolean secondClientHello;
switch (state.getState()) {
case SRV_WAIT_1_CH:
secondClientHello = false;
break;
case SRV_WAIT_2_CH:
secondClientHello = true;
break;
default:
throw new UnexpectedMessageAlert("Unexpected ClientHello");
}
IClientHello clientHello = (IClientHello) handshake;
if (clientHello.getLegacyVersion() != EngineDefaults.LEGACY_VERSION) {
throw new ProtocolVersionAlert("Invalid legacy version");
}
ISupportedVersionsExtension versions = find(handshake, ExtensionType.SUPPORTED_VERSIONS);
if (versions == null) {
throw new ProtocolVersionAlert("No support for TLS 1.2 or prior");
}
int negotiatedVersion = -1;
for (int version: versions.getVersions()) {
if (version == 0x0304) {
negotiatedVersion = version;
}
}
if (negotiatedVersion == -1) {
throw new ProtocolVersionAlert("No support for TLS 1.3 by peer");
}
byte[] compressions = clientHello.getLegacyCompressionMethods();
if (compressions.length != 1 || compressions[0] != 0) {
throw new IllegalParameterAlert("Invalid compression methods");
}
IEngineParameters params = state.getParameters();
CipherSuite cipherSuite = findMatch(
params.getCipherSuites(),
clientHello.getCipherSuites());
if (cipherSuite == null) {
throw new HandshakeFailureAlert("Failed to negotiate cipher suite");
}
else if (state.getCipherSuite() != null && !state.getCipherSuite().equals(cipherSuite)) {
throw new IllegalParameterAlert("Negotiated cipher suite mismatch");
}
String protocol = state.getHandler().selectApplicationProtocol(
find(handshake, ExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION),
params.getApplicationProtocols());
state.setApplicationProtocol(protocol);
state.getHandler().selectedApplicationProtocol(protocol);
IKeyShareExtension keyShare = find(handshake, ExtensionType.KEY_SHARE);
if (keyShare == null) {
throw new MissingExtensionAlert("Missing key_share extension in ClientHello");
}
ISupportedGroupsExtension supportedGroups = find(handshake, ExtensionType.SUPPORTED_GROUPS);
if (supportedGroups == null) {
throw new MissingExtensionAlert("Missing supported_groups extension in ClientHello");
}
ISignatureAlgorithmsExtension signAlgorithms = find(handshake, ExtensionType.SIGNATURE_ALGORITHMS);
if (signAlgorithms == null) {
throw new MissingExtensionAlert("Missing signature_algorithms extension in ClientHello");
}
IPskKeyExchangeModesExtension preSharedKeyModes = find(handshake, ExtensionType.PSK_KEY_EXCHANGE_MODES);
if (preSharedKeyModes != null) {
state.setPskModes(preSharedKeyModes.getModes());
}
boolean earlyData = find(handshake, ExtensionType.EARLY_DATA) != null;
if (secondClientHello && earlyData) {
throw new IllegalParameterAlert("Early data extension not permitted in second ClientHello");
}
IPreSharedKeyExtension preSharedKey = find(handshake, ExtensionType.PRE_SHARED_KEY);
if (preSharedKey != null) {
if (preSharedKey != handshake.getExtensions().get(handshake.getExtensions().size()-1)) {
throw new IllegalParameterAlert("PSK extension not the last extension");
}
if (preSharedKeyModes == null) {
throw new MissingExtensionAlert("Missing psk_key_exchange_modes extension in ClientHello");
}
if (!state.hasPskMode(PskKeyExchangeMode.PSK_DHE_KE)) {
preSharedKey = null;
if (earlyData) {
state.setEarlyDataContext(new EarlyDataContext(
cipherSuite,
true,
state.getHandler().getMaxEarlyDataSize()));
earlyData = false;
}
}
}
else if (earlyData) {
throw new MissingExtensionAlert("Missing pre_shared_key extension in ClientHello");
}
KeyShareEntry keyShareEntry;
if (state.getNamedGroup() != null) {
KeyShareEntry[] entries = keyShare.getEntries();
if (entries.length != 1 || !state.getNamedGroup().equals(entries[0].getNamedGroup())) {
throw new IllegalParameterAlert("Negotiated key share group mismatch");
}
keyShareEntry = entries[0];
}
else {
keyShareEntry = KeyShareEntry.findMatch(
keyShare.getEntries(),
params.getNamedGroups());
}
NamedGroup namedGroup = null;
if (keyShareEntry != null) {
namedGroup = find(supportedGroups.getGroups(), keyShareEntry.getNamedGroup());
if (namedGroup == null) {
throw new IllegalParameterAlert("KeyShareEntry not correspond with supported_groups extension");
}
}
else if (earlyData) {
state.setEarlyDataContext(new EarlyDataContext(
cipherSuite,
true,
state.getHandler().getMaxEarlyDataSize()));
earlyData = false;
}
if (namedGroup == null) {
namedGroup = findMatch(
params.getNamedGroups(),
supportedGroups.getGroups());
if (namedGroup == null) {
throw new HandshakeFailureAlert("Failed to negotiate supported group");
}
}
IServerNameExtension serverName = find(handshake, ExtensionType.SERVER_NAME);
if (serverName == null) {
if (params.isServerNameRequired()) {
throw new MissingExtensionAlert("Missing server_name extension in ClientHello");
}
}
else if (!state.getHandler().verifyServerName(serverName)) {
throw new UnrecognizedNameAlert("Host name '" + serverName.getHostName() + "' is unrecognized");
}
else {
state.setHostName(serverName.getHostName());
}
boolean partInitialization = false;
UsedSession resumed = null;
if (preSharedKey != null) {
if (keyShareEntry != null) {
resumed = state.getHandler().getSessionManager().useSession(
preSharedKey.getOfferedPsks(),
cipherSuite,
earlyData,
protocol);
}
else {
partInitialization = true;
}
}
if (!state.isInitialized()) {
try {
IHash hash = cipherSuite.spec().getHashSpec().getHash();
ITranscriptHash th = state.getTranscriptHash();
if (th == null) {
th = new TranscriptHash(hash.createMessageDigest());
}
if (partInitialization) {
state.setTranscriptHash(th);
state.getTranscriptHash().update(handshake.getType(), data);
}
else {
IHkdf hkdf = new Hkdf(hash.createMac());
state.initialize(new KeySchedule(hkdf, th, cipherSuite.spec()), cipherSuite);
if (resumed != null) {
state.getKeySchedule().deriveEarlySecret(resumed.getTicket().getPsk(), false);
state.getKeySchedule().deriveBinderKey();
int bindersLen = PreSharedKeyExtension.bindersLength(preSharedKey.getOfferedPsks());
int i = data.length;
ByteBuffer[] dupData = new ByteBuffer[i];
ByteBuffer dupBuf;
while(--i >= 0) {
dupBuf = data[i].duplicate();
dupData[i] = dupBuf;
if (bindersLen > 0) {
int len = Math.min(dupBuf.remaining(), bindersLen);
if (len > 0) {
bindersLen -= len;
dupBuf.limit(dupBuf.limit()-len);
}
}
}
byte[] expectedBinder = state.getKeySchedule().computePskBinder(dupData);
byte[] binder = preSharedKey.getOfferedPsks()[resumed.getSelectedIdentity()].getBinder();
state.getKeySchedule().eraseBinderKey();
if (!Arrays.equals(expectedBinder, binder)) {
state.getHandler().getSessionManager().putTicket(
resumed.getSession(),
resumed.getTicket());
throw new DecryptErrorAlert("Invalid PSK binder");
}
state.setSession(resumed.getSession());
}
else {
if (earlyData) {
state.setEarlyDataContext(new EarlyDataContext(
cipherSuite,
true,
state.getHandler().getMaxEarlyDataSize()));
earlyData = false;
}
state.getKeySchedule().deriveEarlySecret();
}
state.getTranscriptHash().update(handshake.getType(), data);
if (earlyData) {
SessionTicket ticket = resumed.getTicket();
if (resumed.getSelectedIdentity() == 0
&& ticket.getCipherSuite().equals(state.getCipherSuite())
&& ticket.forEarlyData(protocol)) {
state.getKeySchedule().deriveEarlyTrafficSecret();
state.getListener().onNewTrafficSecrets(state, RecordType.ZERO_RTT);
state.getKeySchedule().eraseEarlyTrafficSecret();
state.getListener().onNewReceivingTraficKey(state, RecordType.ZERO_RTT);
state.setEarlyDataContext(new EarlyDataContext(
ticket.getCipherSuite(),
ticket.getMaxEarlyDataSize()));
}
else {
state.setEarlyDataContext(new EarlyDataContext(
cipherSuite,
true,
state.getHandler().getMaxEarlyDataSize()));
earlyData = false;
}
}
}
} catch (Alert e) {
throw e;
} catch (Exception e) {
throw new InternalErrorAlert("Failed to create key schedule", e);
}
}
else {
state.getTranscriptHash().update(handshake.getType(), data);
}
if (keyShareEntry == null) {
if (state.getState() == MachineState.SRV_WAIT_2_CH) {
throw new InternalErrorAlert("Unexpected second HelloRetryRquest");
}
List extensions = new ArrayList();
state.setNamedGroup(namedGroup);
extensions.add(new SupportedVersionsExtension(
ISupportedVersionsExtension.Mode.SERVER_HELLO,
negotiatedVersion));
extensions.add(new KeyShareExtension(namedGroup));
ServerHello helloRetryRequest = new ServerHello(
EngineDefaults.LEGACY_VERSION,
ServerHelloRandom.getHelloRetryRequestRandom(),
clientHello.getLegacySessionId(),
cipherSuite,
(byte)0,
extensions);
ConsumerUtil.produceHRR(state, helloRetryRequest, RecordType.INITIAL);
if (clientHello.getLegacySessionId().length > 0) {
state.getListener().produceChangeCipherSpec(state);
}
state.changeState(MachineState.SRV_WAIT_2_CH);
return;
}
state.setVersion(negotiatedVersion);
DelegatedTaskMode taskMode = params.getDelegatedTaskMode();
AbstractEngineTask task = new KeyExchangeTask(
namedGroup,
keyShareEntry.getParsedKey(),
clientHello.getLegacySessionId(),
resumed != null ? resumed.getSelectedIdentity() : -1,
state.getHandler().getSecureRandom());
if (taskMode.all()) {
state.addTask(task);
}
else {
task.run(state);
}
if (resumed != null) {
task = new CertificateTask();
if (!taskMode.all()) {
taskMode = DelegatedTaskMode.NONE;
}
}
else {
ISignatureAlgorithmsExtension signAlgorithmsCert = find(handshake, ExtensionType.SIGNATURE_ALGORITHMS_CERT);
task = new CertificateTask(
state.getHandler().getCertificateSelector(),
new CertificateCriteria(
true,
CertificateType.X509,
state.getHostName(),
signAlgorithms.getSchemes(),
signAlgorithmsCert == null ? null : signAlgorithmsCert.getSchemes(),
params.getSignatureSchemes().clone()
));
}
if (taskMode.certificates()) {
state.changeState(MachineState.SRV_WAIT_TASK);
state.addTask(task);
}
else {
task.run(state);
}
}
static class KeyExchangeTask extends AbstractEngineTask {
private final NamedGroup namedGroup;
private final ParsedKey parsedKey;
private final byte[] legacySessionId;
private final int selectedIdentity;
private final SecureRandom secureRandom;
private volatile byte[] secret;
private volatile byte[] random;
private volatile PublicKey publicKey;
KeyExchangeTask(NamedGroup namedGroup, ParsedKey parsedKey, byte[] legacySessionId, int selectedIdentity, SecureRandom secureRandom) {
this.namedGroup = namedGroup;
this.parsedKey = parsedKey;
this.legacySessionId = legacySessionId;
this.selectedIdentity = selectedIdentity;
this.secureRandom = secureRandom;
}
@Override
public boolean isProducing() {
return true;
}
@Override
public String name() {
return "Key exchange";
}
@Override
void execute() throws Exception {
IKeyExchange keyExchange = namedGroup.spec().getKeyExchange();
KeyPair pair = keyExchange.generateKeyPair(secureRandom);
publicKey = pair.getPublic();
secret = keyExchange.generateSecret(
pair.getPrivate(),
namedGroup.spec().generateKey(parsedKey),
secureRandom);
byte[] random = new byte[32];
secureRandom.nextBytes(random);
this.random = random;
}
@Override
public void finish(EngineState state) throws Alert {
List extensions = new ArrayList();
extensions.add(new SupportedVersionsExtension(
ISupportedVersionsExtension.Mode.SERVER_HELLO,
state.getVersion()));
extensions.add(new KeyShareExtension(
IKeyShareExtension.Mode.SERVER_HELLO,
new KeyShareEntry(namedGroup, publicKey)));
if (selectedIdentity >= 0) {
extensions.add(new PreSharedKeyExtension(selectedIdentity));
}
ServerHello serverHello = new ServerHello(
EngineDefaults.LEGACY_VERSION,
random,
legacySessionId,
state.getCipherSuite(),
(byte)0,
extensions);
ConsumerUtil.prepare(state, serverHello, RecordType.INITIAL, RecordType.HANDSHAKE);
if (legacySessionId.length > 0 && !state.hadState(MachineState.SRV_WAIT_2_CH)) {
state.getListener().prepareChangeCipherSpec(state);
}
try {
state.getKeySchedule().deriveHandshakeSecret(secret);
state.getKeySchedule().eraseEarlySecret();
state.getKeySchedule().deriveHandshakeTrafficSecrets();
}
catch (Exception e) {
throw new InternalErrorAlert("Failed to derive handshake secret", e);
}
state.getListener().onNewTrafficSecrets(state, RecordType.HANDSHAKE);
Arrays.fill(secret, (byte)0);
String hostName = state.getHostName();
extensions = new ArrayList();
if (hostName != null) {
extensions.add(new ServerNameExtension());
}
String protocol = state.getApplicationProtocol();
if (protocol != null) {
extensions.add(new ALPNExtension(protocol));
}
if (state.getEarlyDataContext().getState() == EarlyDataState.PROCESSING) {
extensions.add(new EarlyDataExtension());
}
else {
state.getListener().onNewReceivingTraficKey(state, RecordType.HANDSHAKE);
}
extensions.add(new SupportedGroupsExtension(state.getParameters().getNamedGroups()));
EncryptedExtensions encryptedExtensions = new EncryptedExtensions(extensions);
ConsumerUtil.prepare(state, encryptedExtensions, RecordType.HANDSHAKE);
}
}
static class CertificateTask extends AbstractCertificateTask {
CertificateTask(ICertificateSelector selector, CertificateCriteria criteria) {
super(selector, criteria);
}
CertificateTask() {
}
@Override
public void finish(EngineState state) throws Alert {
MachineState nextState = state.getEarlyDataContext().getState() == EarlyDataState.PROCESSING
? MachineState.SRV_WAIT_EOED
: MachineState.SRV_WAIT_FINISHED;
if (certificates != null) {
IEngineParameters params = state.getParameters();
if (params.getClientAuth() != ClientAuth.NONE) {
List extensions = new ArrayList();
SignatureScheme[] signSchemes;
extensions.add(new SignatureAlgorithmsExtension(params.getSignatureSchemes()));
signSchemes = params.getCertSignatureSchemes();
if (signSchemes != null) {
extensions.add(new SignatureAlgorithmsCertExtension(signSchemes));
}
CertificateRequest certificateRequest = new CertificateRequest(extensions);
ConsumerUtil.prepare(state, certificateRequest, RecordType.HANDSHAKE);
nextState = MachineState.SRV_WAIT_CERT;
}
state.getSessionInfo().localCerts(certificates.getCertificates());
Certificate certificate = new Certificate(new byte[0], certificates.getEntries());
ConsumerUtil.prepare(state, certificate, RecordType.HANDSHAKE);
byte[] signature = ConsumerUtil.sign(state.getTranscriptHash().getHash(HandshakeType.CERTIFICATE, false),
certificates.getAlgorithm(),
certificates.getPrivateKey(),
false,
state.getHandler().getSecureRandom());
CertificateVerify certificateVerify = new CertificateVerify(certificates.getAlgorithm(), signature);
ConsumerUtil.prepare(state, certificateVerify, RecordType.HANDSHAKE);
}
try {
Finished finished = new Finished(state.getKeySchedule().computeServerVerifyData());
ConsumerUtil.prepare(state, finished, RecordType.HANDSHAKE, RecordType.APPLICATION);
state.getKeySchedule().deriveMasterSecret();
state.getKeySchedule().eraseHandshakeSecret();
state.getKeySchedule().deriveApplicationTrafficSecrets();
} catch (Exception e) {
throw new InternalErrorAlert("Failed to compute server verify data", e);
}
state.getListener().onNewTrafficSecrets(state, RecordType.APPLICATION);
state.changeState(nextState);
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy