// net.luminis.tls.handshake.HandshakeMessage Maven / Gradle / Ivy
/*
* Copyright © 2019, 2020, 2021, 2022, 2023, 2024 Peter Doornbosch
*
* This file is part of Agent15, an implementation of TLS 1.3 in Java.
*
* Agent15 is free software: you can redistribute it and/or modify it under
* the terms of the GNU Lesser General Public License as published by the
* Free Software Foundation, either version 3 of the License, or (at your option)
* any later version.
*
* Agent15 is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
* more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package net.luminis.tls.handshake;
import net.luminis.tls.TlsConstants;
import net.luminis.tls.TlsProtocolException;
import net.luminis.tls.alert.DecodeErrorException;
import net.luminis.tls.alert.IllegalParameterAlert;
import net.luminis.tls.extension.*;
import net.luminis.tls.log.Logger;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* https://datatracker.ietf.org/doc/html/rfc8446#section-4
*/
public abstract class HandshakeMessage {

    /** Size in bytes of the handshake header: 1 byte message type followed by a 3 byte length. */
    private static final int HANDSHAKE_HEADER_SIZE = 4;

    /** Size in bytes of an extension header: 2 byte extension type followed by a 2 byte data length. */
    private static final int EXTENSION_HEADER_SIZE = 4;

    /**
     * Returns the handshake type of this message.
     * @return the handshake type
     */
    public abstract TlsConstants.HandshakeType getType();

    /**
     * Reads and validates the 4 byte handshake header (1 byte type, 3 byte length) from the buffer.
     * On successful return, the buffer position is right after the header.
     * @param buffer  data to parse, buffer position should be at the start of the handshake message
     * @param expectedType  the handshake type the caller expects to find in the header
     * @param minimumMessageSize  minimum size of the complete message (header included), in bytes
     * @return the length of the message data, i.e. the number of bytes that follow the header
     * @throws DecodeErrorException  when the buffer holds less than a complete header, when the message is smaller
     * than minimumMessageSize, or when the buffer holds fewer bytes than the header announces
     * @throws IllegalStateException  when the type in the header is not the expected type
     */
    protected int parseHandshakeHeader(ByteBuffer buffer, TlsConstants.HandshakeType expectedType, int minimumMessageSize) throws DecodeErrorException {
        if (buffer.remaining() < HANDSHAKE_HEADER_SIZE) {
            throw new DecodeErrorException("handshake message underflow");
        }
        int handshakeType = buffer.get() & 0xff;
        if (handshakeType != expectedType.value) {
            // i.e. programming error
            throw new IllegalStateException("Expected handshake type " + expectedType + " (" + expectedType.value + "), but got " + handshakeType);
        }
        // The length is a 24 bit unsigned integer in network byte order.
        int messageDataLength = ((buffer.get() & 0xff) << 16) | ((buffer.get() & 0xff) << 8) | (buffer.get() & 0xff);
        if (HANDSHAKE_HEADER_SIZE + messageDataLength < minimumMessageSize) {
            throw new DecodeErrorException(getClass().getSimpleName() + " can't be less than " + minimumMessageSize + " bytes");
        }
        if (buffer.remaining() < messageDataLength) {
            throw new DecodeErrorException("handshake message underflow");
        }
        return messageDataLength;
    }

    /**
     * Returns the bytes of this message.
     * @return the message bytes (presumably the serialized wire format; the encoding is up to the implementing subclass)
     */
    public abstract byte[] getBytes();

    /**
     * Parses a list of extensions without a custom extension parser, so extensions that are not recognized
     * end up as UnknownExtension.
     * @param buffer  data to parse, buffer position should be at the 2 byte extensions length field
     * @param context  the type of the handshake message that contains the extensions
     * @return the parsed extensions, in the order in which they appear in the buffer
     * @throws TlsProtocolException  when the extensions cannot be parsed or are not allowed in the given context
     */
    static List<Extension> parseExtensions(ByteBuffer buffer, TlsConstants.HandshakeType context) throws TlsProtocolException {
        return parseExtensions(buffer, context, null);
    }

    /**
     * Parses a list of extensions: a 2 byte total length, followed by zero or more extensions that each consist of
     * a 2 byte type, a 2 byte data length and the extension data.
     * The extension parsers that are called are expected to consume the complete extension, header included; this
     * is verified after each extension.
     * @param buffer  data to parse, buffer position should be at the 2 byte extensions length field
     * @param context  the type of the handshake message that contains the extensions
     * @param customExtensionParser  optional (may be null) parser that gets the chance to parse extensions that are
     * not recognized by this method; when it returns null, the extension is parsed as UnknownExtension
     * @return the parsed extensions, in the order in which they appear in the buffer
     * @throws DecodeErrorException  when any of the length fields is inconsistent with the available data
     * @throws TlsProtocolException  when an extension cannot be parsed or is not allowed in the given context
     */
    static List<Extension> parseExtensions(ByteBuffer buffer, TlsConstants.HandshakeType context, ExtensionParser customExtensionParser) throws TlsProtocolException {
        if (buffer.remaining() < 2) {
            throw new DecodeErrorException("Extension field must be at least 2 bytes long");
        }
        List<Extension> extensions = new ArrayList<>();

        int remainingExtensionsLength = buffer.getShort() & 0xffff;
        if (buffer.remaining() < remainingExtensionsLength) {
            throw new DecodeErrorException("Extensions too short");
        }

        while (remainingExtensionsLength >= EXTENSION_HEADER_SIZE) {
            // Peek at the extension header; the extension parsers read the header themselves, hence the reset.
            buffer.mark();
            int extensionType = buffer.getShort() & 0xffff;
            int extensionLength = buffer.getShort() & 0xffff;
            remainingExtensionsLength -= EXTENSION_HEADER_SIZE;
            buffer.reset();

            if (extensionLength > remainingExtensionsLength) {
                throw new DecodeErrorException("Extension length exceeds extensions length");
            }

            int extensionStartPosition = buffer.position();
            if (extensionType == TlsConstants.ExtensionType.server_name.value) {
                extensions.add(new ServerNameExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.supported_groups.value) {
                extensions.add(new SupportedGroupsExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.signature_algorithms.value) {
                extensions.add(new SignatureAlgorithmsExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.application_layer_protocol_negotiation.value) {
                extensions.add(new ApplicationLayerProtocolNegotiationExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.pre_shared_key.value) {
                if (context == TlsConstants.HandshakeType.server_hello) {
                    extensions.add(new ServerPreSharedKeyExtension().parse(buffer));
                }
                else if (context == TlsConstants.HandshakeType.client_hello) {
                    extensions.add(new ClientHelloPreSharedKeyExtension().parse(buffer));
                }
                else {
                    // https://datatracker.ietf.org/doc/html/rfc8446#section-4.2
                    // "If an implementation receives an extension which it recognizes and which is not specified for
                    // the message in which it appears, it MUST abort the handshake with an "illegal_parameter" alert."
                    throw new IllegalParameterAlert("Extension not allowed in " + context);
                }
            }
            else if (extensionType == TlsConstants.ExtensionType.early_data.value) {
                extensions.add(new EarlyDataExtension(buffer, context));
            }
            else if (extensionType == TlsConstants.ExtensionType.supported_versions.value) {
                extensions.add(new SupportedVersionsExtension(buffer, context));
            }
            else if (extensionType == TlsConstants.ExtensionType.psk_key_exchange_modes.value) {
                extensions.add(new PskKeyExchangeModesExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.certificate_authorities.value) {
                extensions.add(new CertificateAuthoritiesExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.key_share.value) {
                extensions.add(new KeyShareExtension(buffer, context));
            }
            else {
                Extension extension = null;
                if (customExtensionParser != null) {
                    extension = customExtensionParser.apply(buffer, context);
                }
                if (extension != null) {
                    extensions.add(extension);
                }
                else {
                    Logger.debug("Unsupported extension, type is: " + extensionType);
                    extensions.add(new UnknownExtension().parse(buffer));
                }
            }

            // Whatever parser handled the extension, it must have consumed exactly header + data.
            if (buffer.position() - extensionStartPosition != EXTENSION_HEADER_SIZE + extensionLength) {
                throw new DecodeErrorException("Incorrect extension length");
            }
            remainingExtensionsLength -= extensionLength;
        }

        // Anything left at this point is 1 to 3 stray bytes, too few to hold an extension header. Accepting them
        // silently would leave the buffer position in the middle of the extensions block.
        if (remainingExtensionsLength != 0) {
            throw new DecodeErrorException("Extensions length does not match the length of the contained extensions");
        }
        return extensions;
    }

    /**
     * Returns the buffer position of the last extension.
     * Only the length fields are evaluated; the extension data itself is skipped without being parsed.
     * @param buffer data to parse, buffer position should be at the point where extensions start (i.e. at the
     *               2 byte extensions length field).
     *               on return, the buffer position will be right after the last extension.
     * @return the buffer position at which the last extension (its type field) starts, or 0 when there are no extensions.
     * @throws java.nio.BufferUnderflowException when the buffer holds fewer bytes than the length fields announce
     */
    public static int findPositionLastExtension(ByteBuffer buffer) {
        int remaining = buffer.getShort() & 0xffff;
        int lastExtensionStart = 0;
        // Must be >= (not >): an extension without data occupies exactly the 4 header bytes and still counts.
        while (remaining >= EXTENSION_HEADER_SIZE) {
            lastExtensionStart = buffer.position();
            buffer.getShort();  // extension type, not needed here
            int length = buffer.getShort() & 0xffff;
            buffer.get(new byte[length]);  // skip extension data
            remaining -= EXTENSION_HEADER_SIZE + length;
        }
        return lastExtensionStart;
    }
}