Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions web-transport-proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ thiserror = "2"
# Just for AsyncRead and AsyncWrite traits
tokio = { version = "1", default-features = false, features = ["io-util"] }
url = "2"

[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt", "io-util"] }
150 changes: 135 additions & 15 deletions web-transport-proto/src/capsule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ use std::sync::Arc;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use crate::{VarInt, VarIntUnexpectedEnd};
use crate::{VarInt, VarIntUnexpectedEnd, MAX_FRAME_SIZE};

// The spec (draft-ietf-webtrans-http3-06) says the type is 0x2843, which would
// varint-encode to 0x68 0x43. However, actual wire data shows 0x43 0x28 which
// decodes to 808. There may be a discrepancy in implementations or specs.
// Using 0x2843 as specified in the standard.
const CLOSE_WEBTRANSPORT_SESSION_TYPE: u64 = 0x2843;
const MAX_MESSAGE_SIZE: usize = 1024;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Capsule {
Expand All @@ -27,7 +26,7 @@ impl Capsule {
let mut payload = buf.take(length.into_inner() as usize);

// Check declared length first - reject immediately if too large
if payload.limit() > MAX_MESSAGE_SIZE {
if payload.limit() > MAX_FRAME_SIZE as usize {
return Err(CapsuleError::MessageTooLong);
}

Expand All @@ -52,7 +51,7 @@ impl Capsule {
let error_code = payload.get_u32();

let message_len = payload.remaining();
if message_len > MAX_MESSAGE_SIZE {
if message_len > MAX_FRAME_SIZE as usize {
return Err(CapsuleError::MessageTooLong);
}

Expand All @@ -78,22 +77,62 @@ impl Capsule {
}
}

/// Read a capsule from a stream, consuming only the exact bytes of the capsule.
///
/// Returns `Ok(None)` if the stream is cleanly closed (EOF before any bytes).
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Option<Self>, CapsuleError> {
let mut buf = Vec::new();
loop {
if stream.read_buf(&mut buf).await? == 0 {
if buf.is_empty() {
return Ok(None);
}
let typ = match VarInt::read(stream).await {
Ok(v) => v,
Err(_) => return Ok(None), // Clean EOF
};
let length = VarInt::read(stream)
.await
.map_err(|_| CapsuleError::UnexpectedEnd)?;

let length = length.into_inner();
let typ_val = typ.into_inner();

if length > MAX_FRAME_SIZE {
return Err(CapsuleError::MessageTooLong);
}

let mut payload = stream.take(length);

if let Some(num) = is_grease(typ_val) {
let n = tokio::io::copy(&mut payload, &mut tokio::io::sink()).await?;
if n < length {
return Err(CapsuleError::UnexpectedEnd);
}
return Ok(Some(Self::Grease { num }));
}

let mut limit = std::io::Cursor::new(&buf);
match Self::decode(&mut limit) {
Ok(capsule) => return Ok(Some(capsule)),
Err(CapsuleError::UnexpectedEnd) => continue,
Err(e) => return Err(e),
let mut buf = Vec::with_capacity(length as usize);
payload.read_to_end(&mut buf).await?;

if buf.len() < length as usize {
return Err(CapsuleError::UnexpectedEnd);
}

match typ_val {
CLOSE_WEBTRANSPORT_SESSION_TYPE => {
let mut data = buf.as_slice();
if data.remaining() < 4 {
return Err(CapsuleError::UnexpectedEnd);
}

let error_code = data.get_u32();
let error_message =
String::from_utf8(data.to_vec()).map_err(|_| CapsuleError::InvalidUtf8)?;

Ok(Some(Self::CloseWebTransportSession {
code: error_code,
reason: error_message,
}))
}
_ => Ok(Some(Self::Unknown {
typ,
payload: Bytes::from(buf),
})),
}
}

Expand Down Expand Up @@ -373,4 +412,85 @@ mod tests {
let mut buf = Vec::new();
capsule.encode(&mut buf);
}

#[tokio::test]
async fn test_read_exact_consumption() {
let capsule = Capsule::CloseWebTransportSession {
code: 42,
reason: "bye".to_string(),
};
let mut wire = Vec::new();
capsule.encode(&mut wire);
let trailing = b"leftover";
wire.extend_from_slice(trailing);

let mut cursor = std::io::Cursor::new(wire);
let decoded = Capsule::read(&mut cursor).await.unwrap().unwrap();
assert_eq!(capsule, decoded);

let pos = cursor.position() as usize;
let remaining = &cursor.into_inner()[pos..];
assert_eq!(remaining, trailing);
}

#[tokio::test]
async fn test_read_roundtrip() {
let capsule = Capsule::CloseWebTransportSession {
code: 100,
reason: "test".to_string(),
};
let mut wire = Vec::new();
capsule.encode(&mut wire);

let mut cursor = std::io::Cursor::new(wire);
let decoded = Capsule::read(&mut cursor).await.unwrap().unwrap();
assert_eq!(capsule, decoded);
}

#[tokio::test]
async fn test_read_eof_returns_none() {
let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
let result = Capsule::read(&mut cursor).await.unwrap();
assert!(result.is_none());
}

#[tokio::test]
async fn test_read_rejects_too_large() {
let mut wire = Vec::new();
VarInt::from_u64(0x2843).unwrap().encode(&mut wire); // type
VarInt::from_u64(MAX_FRAME_SIZE + 1)
.unwrap()
.encode(&mut wire); // too large

let mut cursor = std::io::Cursor::new(wire);
let err = Capsule::read(&mut cursor).await.unwrap_err();
assert!(matches!(err, CapsuleError::MessageTooLong));
}

#[tokio::test]
async fn test_read_truncated_payload() {
// CloseWebTransportSession needs at least 4 bytes for error code,
// but the stream is shorter than the declared length.
let mut wire = Vec::new();
VarInt::from_u64(0x2843).unwrap().encode(&mut wire);
VarInt::from_u32(100).encode(&mut wire); // claims 100 bytes
wire.extend_from_slice(b"short"); // only 5 bytes

let mut cursor = std::io::Cursor::new(wire);
let err = Capsule::read(&mut cursor).await.unwrap_err();
assert!(matches!(err, CapsuleError::UnexpectedEnd));
}

#[tokio::test]
async fn test_read_truncated_grease() {
// GREASE capsule type (0x17 = first grease value), claims 50 bytes, only 2 present.
let mut wire = Vec::new();
VarInt::from_u32(0x17).encode(&mut wire);
VarInt::from_u32(50).encode(&mut wire);
wire.extend_from_slice(b"ab");

let mut cursor = std::io::Cursor::new(wire);
let err = Capsule::read(&mut cursor).await.unwrap_err();
assert!(matches!(err, CapsuleError::UnexpectedEnd));
}
}
Loading