All downloads are FREE. Search and download functionality uses the official Maven repository.

net.luminis.tls.handshake.HandshakeMessage Maven / Gradle / Ivy

Go to download

A (partial) TLS 1.3 implementation in Java, suitable and intended for use in a QUIC implementation.

There is a newer version: 2.3
Show newest version
/*
 * Copyright © 2019, 2020, 2021, 2022, 2023 Peter Doornbosch
 *
 * This file is part of Agent15, an implementation of TLS 1.3 in Java.
 *
 * Agent15 is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or (at your option)
 * any later version.
 *
 * Agent15 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
 * more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
package net.luminis.tls.handshake;

import net.luminis.tls.*;
import net.luminis.tls.alert.DecodeErrorException;
import net.luminis.tls.alert.IllegalParameterAlert;
import net.luminis.tls.extension.*;

import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Base class for TLS handshake messages, with shared helpers for parsing the handshake header and
 * extension blocks.
 */
public abstract class HandshakeMessage extends Message {

    /**
     * Returns the handshake type of this message.
     * @return the handshake type
     */
    public abstract TlsConstants.HandshakeType getType();

    /**
     * Parses and validates the 4-byte handshake header (1-byte type, 3-byte length) at the buffer's current position.
     * On successful return, the buffer position is right after the header, i.e. at the start of the message body.
     * @param buffer  data to parse
     * @param expectedType  the handshake type the message must have
     * @param minimumMessageSize  minimum size in bytes of the complete message, header included
     * @return the length of the message body (excluding the 4-byte header)
     * @throws DecodeErrorException  if the buffer holds fewer than 4 bytes or fewer bytes than the announced body
     *                               length, or if the message is smaller than {@code minimumMessageSize}
     * @throws IllegalStateException  if the type byte does not match {@code expectedType}
     */
    protected int parseHandshakeHeader(ByteBuffer buffer, TlsConstants.HandshakeType expectedType, int minimumMessageSize) throws DecodeErrorException {
        if (buffer.remaining() < 4) {
            throw new DecodeErrorException("handshake message underflow");
        }
        int handshakeType = buffer.get() & 0xff;
        if (handshakeType != expectedType.value) {
            // i.e. programming error: this parser was invoked for a message of a different type
            throw new IllegalStateException("Expected handshake type " + expectedType.value + ", but found " + handshakeType);
        }
        // 24-bit big-endian body length
        int messageDataLength = ((buffer.get() & 0xff) << 16) | ((buffer.get() & 0xff) << 8) | (buffer.get() & 0xff);
        if (4 + messageDataLength < minimumMessageSize) {
            throw new DecodeErrorException(getClass().getSimpleName() + " can't be less than " + minimumMessageSize + " bytes");
        }
        if (buffer.remaining() < messageDataLength) {
            throw new DecodeErrorException("handshake message underflow");
        }
        return messageDataLength;
    }

    /**
     * Returns the encoded bytes of this message (presumably its wire format — see the implementing subclass).
     * @return the message bytes
     */
    public abstract byte[] getBytes();

    /**
     * Parses an extensions block without a custom extension parser; unrecognized extensions are returned as
     * {@code UnknownExtension}.
     * @param buffer  data to parse, positioned at the 2-byte extensions length field
     * @param context  the handshake message type in which the extensions appear
     * @return the parsed extensions, in wire order
     * @throws TlsProtocolException  if the extensions block is malformed
     */
    static List<Extension> parseExtensions(ByteBuffer buffer, TlsConstants.HandshakeType context) throws TlsProtocolException {
        return parseExtensions(buffer, context, null);
    }

    /**
     * Parses an extensions block: a 2-byte total length followed by a sequence of extensions, each consisting of
     * a 2-byte type, a 2-byte length and the extension data.
     * On successful return, the buffer position is right after the extensions block.
     * @param buffer  data to parse, positioned at the 2-byte extensions length field
     * @param context  the handshake message type in which the extensions appear; some extension parsers depend on it
     * @param customExtensionParser  optional (may be {@code null}) parser consulted for extension types not handled
     *                               here; when it returns {@code null}, the extension is parsed as {@code UnknownExtension}
     * @return the parsed extensions, in wire order
     * @throws DecodeErrorException  if the announced lengths are inconsistent with the data
     * @throws IllegalParameterAlert  if a pre_shared_key extension appears in a message other than ClientHello or ServerHello
     * @throws TlsProtocolException  if an individual extension fails to parse
     */
    static List<Extension> parseExtensions(ByteBuffer buffer, TlsConstants.HandshakeType context, ExtensionParser customExtensionParser) throws TlsProtocolException {
        if (buffer.remaining() < 2) {
            throw new DecodeErrorException("Extension field must be at least 2 bytes long");
        }
        List<Extension> extensions = new ArrayList<>();

        int remainingExtensionsLength = buffer.getShort() & 0xffff;
        if (buffer.remaining() < remainingExtensionsLength) {
            throw new DecodeErrorException("Extensions too short");
        }

        while (remainingExtensionsLength >= 4) {
            // Peek at type and length without consuming them: the extension parsers read the extension header too
            // (the check after parsing expects 4 + extensionLength bytes to have been consumed).
            buffer.mark();
            int extensionType = buffer.getShort() & 0xffff;
            int extensionLength = buffer.getShort() & 0xffff;
            remainingExtensionsLength -= 4;
            buffer.reset();
            if (extensionLength > remainingExtensionsLength) {
                throw new DecodeErrorException("Extension length exceeds extensions length");
            }
            int extensionStartPosition = buffer.position();

            if (extensionType == TlsConstants.ExtensionType.server_name.value) {
                extensions.add(new ServerNameExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.supported_groups.value) {
                extensions.add(new SupportedGroupsExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.signature_algorithms.value) {
                extensions.add(new SignatureAlgorithmsExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.application_layer_protocol_negotiation.value) {
                extensions.add(new ApplicationLayerProtocolNegotiationExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.pre_shared_key.value) {
                // The pre_shared_key extension has a different format in ClientHello and ServerHello.
                if (context == TlsConstants.HandshakeType.server_hello) {
                    extensions.add(new ServerPreSharedKeyExtension().parse(buffer));
                }
                else if (context == TlsConstants.HandshakeType.client_hello) {
                    extensions.add(new ClientHelloPreSharedKeyExtension().parse(buffer));
                }
                else {
                    throw new IllegalParameterAlert("Extension not allowed in " + context);
                }
            }
            else if (extensionType == TlsConstants.ExtensionType.early_data.value) {
                extensions.add(new EarlyDataExtension(buffer, context));
            }
            else if (extensionType == TlsConstants.ExtensionType.supported_versions.value) {
                extensions.add(new SupportedVersionsExtension(buffer, context));
            }
            else if (extensionType == TlsConstants.ExtensionType.psk_key_exchange_modes.value) {
                extensions.add(new PskKeyExchangeModesExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.certificate_authorities.value) {
                extensions.add(new CertificateAuthoritiesExtension(buffer));
            }
            else if (extensionType == TlsConstants.ExtensionType.key_share.value) {
                extensions.add(new KeyShareExtension(buffer, context));
            }
            else {
                Extension extension = null;
                if (customExtensionParser != null) {
                    extension = customExtensionParser.apply(buffer, context);
                }
                if (extension != null) {
                    extensions.add(extension);
                }
                else {
                    Logger.debug("Unsupported extension, type is: " + extensionType);
                    extensions.add(new UnknownExtension().parse(buffer));
                }
            }
            // Guard against extension parsers that consume more or less than the announced extension length.
            if (buffer.position() - extensionStartPosition != 4 + extensionLength) {
                throw new DecodeErrorException("Incorrect extension length");
            }
            remainingExtensionsLength -= extensionLength;
        }
        if (remainingExtensionsLength != 0) {
            // 1-3 bytes left: too short for even an extension header, so the announced extensions length is wrong.
            throw new DecodeErrorException("Incorrect extensions length");
        }
        return extensions;
    }

    /**
     * Finds the start of the last extension in an extensions block, without parsing the extensions themselves.
     * @param buffer  data to parse; the buffer position must be at the 2-byte extensions length field that precedes
     *                the extensions. On return, the buffer position is right after the last complete extension.
     * @return the buffer index (as reported by {@link ByteBuffer#position()}) at which the last extension, i.e. its
     *         type field, starts; this is an absolute index in {@code buffer}, not an offset from the start of the
     *         extensions. Returns 0 when the block contains no extension.
     * @throws BufferUnderflowException  if the buffer holds fewer bytes than the length fields announce
     */
    public static int findPositionLastExtension(ByteBuffer buffer) {
        int remaining = buffer.getShort() & 0xffff;
        int lastExtensionStart = 0;
        // ">= 4": an extension with empty data is exactly 4 bytes (type + length) and must still be visited.
        while (remaining >= 4) {
            lastExtensionStart = buffer.position();
            buffer.getShort();  // extension type, not needed here
            int length = buffer.getShort() & 0xffff;
            if (length > buffer.remaining()) {
                throw new BufferUnderflowException();
            }
            buffer.position(buffer.position() + length);
            remaining -= 4 + length;
        }
        return lastExtensionStart;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy