net.luminis.quic.packet.LongHeaderPacket Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of kwik Show documentation
Show all versions of kwik Show documentation
A QUIC implementation in Java
/*
* Copyright © 2019, 2020, 2021, 2022, 2023 Peter Doornbosch
*
* This file is part of Kwik, an implementation of the QUIC protocol in Java.
*
* Kwik is free software: you can redistribute it and/or modify it under
* the terms of the GNU Lesser General Public License as published by the
* Free Software Foundation, either version 3 of the License, or (at your option)
* any later version.
*
* Kwik is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
* more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package net.luminis.quic.packet;
import net.luminis.quic.crypto.Aead;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.generic.InvalidIntegerEncodingException;
import net.luminis.quic.generic.VariableLengthInteger;
import net.luminis.quic.log.Logger;
import net.luminis.quic.core.*;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
// https://www.rfc-editor.org/rfc/rfc9000.html#name-long-header-packets
/**
 * Base class for QUIC long header packets: serializes and parses the header fields that all long header packets
 * share (flags, version, destination and source connection ids, length), and delegates the type-specific
 * fields to subclasses via {@link #generateAdditionalFields(ByteBuffer)} and {@link #parseAdditionalFields(ByteBuffer)}.
 */
public abstract class LongHeaderPacket extends QuicPacket {

    private static final int MAX_PACKET_SIZE = 1500;

    // Minimal length for a valid packet: type version dcid len dcid scid len scid length packet number payload
    private static final int MIN_PACKET_LENGTH = 1 + 4 + 1 + 0 + 1 + 0 + 1 + 1 + 1;

    // https://www.rfc-editor.org/rfc/rfc9000.html#name-long-header-packets
    // In QUIC version 1, connection id lengths must not exceed 20 bytes.
    private static final int MAX_CONNECTION_ID_LENGTH = 20;

    // https://www.rfc-editor.org/rfc/rfc9001.html#name-header-protection-sample
    // "The ciphersuites defined in [TLS13] - (...) - have 16-byte expansions..."
    private static final int AEAD_EXPANSION = 16;

    protected byte[] sourceConnectionId;

    /**
     * Determines whether the given first byte indicates a long header packet, i.e. whether both the header form bit
     * and the fixed bit are set.
     * @param flags  first byte of the packet
     * @param quicVersion  not used by this check
     * @return true if both the header form bit and the fixed bit are set
     */
    public static boolean isLongHeaderPacket(byte flags, Version quicVersion) {
        return (flags & 0b1100_0000) == 0b1100_0000;
    }

    /**
     * Determines the packet class that corresponds to the (version-specific) long packet type in the given flags.
     * @param flags  first byte of the packet
     * @param version  QUIC version, needed because the type encoding is version dependent
     * @return the packet class to use for parsing
     * @throws RuntimeException if none of the version-specific type checks matches
     */
    public static Class determineType(byte flags, Version version) {
        int type = (flags & 0x30) >> 4;
        if (InitialPacket.isInitial(type, version)) {
            return InitialPacket.class;
        }
        else if (HandshakePacket.isHandshake(type, version)) {
            return HandshakePacket.class;
        }
        else if (RetryPacket.isRetry(type, version)) {
            return RetryPacket.class;
        }
        else if (ZeroRttPacket.isZeroRTT(type, version)) {
            return ZeroRttPacket.class;
        }
        else {
            // Should be impossible: the two type bits have four values and each packet type claims one of them.
            throw new RuntimeException("Unknown long header packet type " + type);
        }
    }

    /**
     * Constructs an empty packet for parsing a received one
     * @param quicVersion  the version of the connection; received packets with another version are rejected
     */
    public LongHeaderPacket(Version quicVersion) {
        this.quicVersion = quicVersion;
    }

    /**
     * Constructs a long header packet for sending (client role).
     * @param quicVersion  QUIC version to put in the header
     * @param sourceConnectionId  source connection id to put in the header
     * @param destConnectionId  destination connection id to put in the header
     * @param frame  the (single) frame to include; may be null, in which case the packet starts without frames
     */
    public LongHeaderPacket(Version quicVersion, byte[] sourceConnectionId, byte[] destConnectionId, QuicFrame frame) {
        this.quicVersion = quicVersion;
        this.sourceConnectionId = sourceConnectionId;
        this.destinationConnectionId = destConnectionId;
        this.frames = new ArrayList<>();
        if (frame != null) {
            this.frames.add(frame);
        }
    }

    /**
     * Constructs a long header packet for sending (client role).
     * @param quicVersion  QUIC version to put in the header
     * @param sourceConnectionId  source connection id to put in the header
     * @param destConnectionId  destination connection id to put in the header
     * @param frames  the frames to include (the list is used as is, not copied)
     * @throws IllegalArgumentException if frames is null
     */
    public LongHeaderPacket(Version quicVersion, byte[] sourceConnectionId, byte[] destConnectionId, List<QuicFrame> frames) {
        if (frames == null) {
            throw new IllegalArgumentException();
        }
        this.quicVersion = quicVersion;
        this.sourceConnectionId = sourceConnectionId;
        this.destinationConnectionId = destConnectionId;
        this.frames = frames;
    }

    @Override
    public byte[] generatePacketBytes(Aead aead) {
        assert(packetNumber >= 0);

        ByteBuffer packetBuffer = ByteBuffer.allocate(MAX_PACKET_SIZE);
        generateFrameHeaderInvariant(packetBuffer);
        generateAdditionalFields(packetBuffer);
        byte[] encodedPacketNumber = encodePacketNumber(packetNumber);
        ByteBuffer frameBytes = generatePayloadBytes(encodedPacketNumber.length);
        addLength(packetBuffer, encodedPacketNumber.length, frameBytes.limit());
        packetBuffer.put(encodedPacketNumber);

        protectPacketNumberAndPayload(packetBuffer, encodedPacketNumber.length, frameBytes, 0, aead);

        // Copy everything written so far (from position 0) into the result.
        packetBuffer.flip();
        byte[] packetBytes = new byte[packetBuffer.remaining()];
        packetBuffer.get(packetBytes);
        packetSize = packetBytes.length;
        return packetBytes;
    }

    @Override
    public int estimateLength(int additionalPayload) {
        int packetNumberSize = computePacketNumberSize(packetNumber);
        int payloadSize = getFrames().stream().mapToInt(QuicFrame::getFrameLength).sum() + additionalPayload;
        // Packet number and payload together are padded to at least 4 bytes, so that a header protection sample
        // can be taken (https://www.rfc-editor.org/rfc/rfc9001.html#name-header-protection-sample).
        int padding = Integer.max(0, 4 - packetNumberSize - payloadSize);
        // The Length field covers packet number, payload (incl. padding) and AEAD expansion; this is exactly the
        // value that addLength encodes, so its variable-length encoding size must be based on this value.
        int lengthFieldValue = packetNumberSize + payloadSize + padding + AEAD_EXPANSION;
        return 1                                            // flags
                + 4                                         // version
                + 1 + destinationConnectionId.length
                + 1 + sourceConnectionId.length
                + estimateAdditionalFieldsLength()
                + variableLengthIntegerSize(lengthFieldValue)
                + lengthFieldValue;
    }

    /**
     * Returns the number of bytes needed to encode the given non-negative value as a QUIC variable-length integer
     * (https://www.rfc-editor.org/rfc/rfc9000.html#name-variable-length-integer-enc).
     */
    private static int variableLengthIntegerSize(long value) {
        if (value <= 63) {
            return 1;
        }
        else if (value <= 16383) {
            return 2;
        }
        else if (value <= 1073741823) {
            return 4;
        }
        else {
            return 8;
        }
    }

    /**
     * Writes the header fields that precede the type-specific fields: flags (including packet type and packet number
     * length), version, and the destination and source connection ids with their lengths.
     * @param packetBuffer  buffer to write to
     */
    protected void generateFrameHeaderInvariant(ByteBuffer packetBuffer) {
        // https://www.rfc-editor.org/rfc/rfc9000.html#name-long-header-packets
        // "Long Header Packet {
        //    Header Form (1) = 1,
        //    Fixed Bit (1) = 1,
        //    Long Packet Type (2),
        //    Type-Specific Bits (4),"
        //    Version (32),
        //    Destination Connection ID Length (8),
        //    Destination Connection ID (0..160),
        //    Source Connection ID Length (8),
        //    Source Connection ID (0..160),
        //    Type-Specific Payload (..),
        //  }

        // Packet type and packet number length
        byte flags = encodePacketNumberLength((byte) (0b1100_0000 | (getPacketType() << 4)), packetNumber);
        packetBuffer.put(flags);
        // Version
        packetBuffer.put(quicVersion.getBytes());
        // DCID Len
        packetBuffer.put((byte) destinationConnectionId.length);
        // Destination connection id
        packetBuffer.put(destinationConnectionId);
        // SCID Len
        packetBuffer.put((byte) sourceConnectionId.length);
        // Source connection id
        packetBuffer.put(sourceConnectionId);
    }

    /**
     * Returns the value of the (2-bit) Long Packet Type field for this kind of packet.
     */
    protected abstract byte getPacketType();

    /**
     * Writes the type-specific fields that follow the source connection id and precede the Length field.
     */
    protected abstract void generateAdditionalFields(ByteBuffer packetBuffer);

    /**
     * Returns the (estimated) number of bytes written by {@link #generateAdditionalFields(ByteBuffer)}.
     */
    protected abstract int estimateAdditionalFieldsLength();

    /**
     * Writes the Length field: the length of the remainder of the packet (the Packet Number and Payload fields).
     * The protected payload is larger than the plaintext payload by the AEAD expansion.
     */
    private void addLength(ByteBuffer packetBuffer, int packetNumberLength, int payloadSize) {
        int packetLength = payloadSize + AEAD_EXPANSION + packetNumberLength;
        VariableLengthInteger.encode(packetLength, packetBuffer);
    }

    /**
     * Parses a long header packet that starts at position 0 of the given buffer.
     * @param sourceConnectionIdLength  not used here: long headers carry explicit connection id lengths
     * @throws InvalidPacketException if the packet is malformed or has a different version than the connection
     * @throws IllegalStateException if the buffer position is not 0
     */
    @Override
    public void parse(ByteBuffer buffer, Aead aead, long largestPacketNumber, Logger log, int sourceConnectionIdLength) throws DecryptionException, InvalidPacketException {
        log.debug("Parsing " + this.getClass().getSimpleName());
        if (buffer.position() != 0) {
            // parsePacketNumberAndPayload method requires packet to start at 0.
            throw new IllegalStateException();
        }
        if (buffer.remaining() < MIN_PACKET_LENGTH) {
            throw new InvalidPacketException();
        }

        byte flags = buffer.get();
        checkPacketType((flags & 0x30) >> 4);

        boolean matchingVersion = Version.parse(buffer.getInt()).equals(this.quicVersion);
        if (! matchingVersion) {
            // https://tools.ietf.org/html/draft-ietf-quic-transport-27#section-5.2
            // "... packets are discarded if they indicate a different protocol version than that of the connection..."
            throw new InvalidPacketException("Version does not match version of the connection");
        }

        int dstConnIdLength = buffer.get();
        // https://tools.ietf.org/html/draft-ietf-quic-transport-27#section-17.2
        // "In QUIC version 1, this value MUST NOT exceed 20. Endpoints that receive a version 1 long header with a
        // value larger than 20 MUST drop the packet."
        // (A negative value results from a length byte >= 128, which is too large as well.)
        if (dstConnIdLength < 0 || dstConnIdLength > MAX_CONNECTION_ID_LENGTH) {
            throw new InvalidPacketException();
        }
        // Besides the destination connection id itself, the source connection id length byte must be present too;
        // otherwise reading it would throw a BufferUnderflowException instead of rejecting the packet.
        if (buffer.remaining() < dstConnIdLength + 1) {
            throw new InvalidPacketException();
        }
        destinationConnectionId = new byte[dstConnIdLength];
        buffer.get(destinationConnectionId);

        int srcConnIdLength = buffer.get();
        if (srcConnIdLength < 0 || srcConnIdLength > MAX_CONNECTION_ID_LENGTH) {
            throw new InvalidPacketException();
        }
        if (buffer.remaining() < srcConnIdLength) {
            throw new InvalidPacketException();
        }
        sourceConnectionId = new byte[srcConnIdLength];
        buffer.get(sourceConnectionId);
        log.debug("Destination connection id", destinationConnectionId);
        log.debug("Source connection id", sourceConnectionId);

        parseAdditionalFields(buffer);

        int length;
        try {
            // "The length of the remainder of the packet (that is, the Packet Number and Payload fields) in bytes"
            length = VariableLengthInteger.parse(buffer);
        }
        catch (IllegalArgumentException | InvalidIntegerEncodingException invalidInt) {
            throw new InvalidPacketException();
        }
        log.debug("Length (PN + payload): " + length);

        try {
            parsePacketNumberAndPayload(buffer, flags, length, aead, largestPacketNumber, log);
        }
        finally {
            // The packet starts at position 0 (checked above), so the position equals the number of bytes consumed.
            packetSize = buffer.position();
        }
    }

    @Override
    public String toString() {
        return "Packet "
                + (isProbe? "P": "")
                + getEncryptionLevel().name().charAt(0) + "|"
                + (packetNumber >= 0? packetNumber: ".") + "|"
                + "L" + "|"
                + (packetSize >= 0? packetSize: ".") + "|"
                + frames.size() + " "
                + frames.stream().map(f -> f.toString()).collect(Collectors.joining(" "));
    }

    public byte[] getSourceConnectionId() {
        return sourceConnectionId;
    }

    /**
     * Verifies that the long packet type parsed from the flags matches the type of this packet class.
     * @param type  the 2-bit long packet type value from the first byte
     * @throws RuntimeException if the type does not match
     */
    protected void checkPacketType(int type) {
        if (type != getPacketType()) {
            // Programming error: this method shouldn't have been called if the packet is not of this packet's type
            throw new RuntimeException("Packet type mismatch: expected " + getPacketType() + ", got " + type);
        }
    }

    /**
     * Parses the type-specific fields that follow the source connection id and precede the Length field.
     * @throws InvalidPacketException if these fields are invalid
     */
    protected abstract void parseAdditionalFields(ByteBuffer buffer) throws InvalidPacketException;
}