Skip to content

Commit 568b2c6

Browse files
committed
Use ProtocolName for ALPN protocol pervasively
1 parent 229dfe2 commit 568b2c6

File tree

9 files changed

+41
-28
lines changed

9 files changed

+41
-28
lines changed

rustls/src/client/client_conn.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::error::Error;
1919
use crate::kernel::KernelConnection;
2020
use crate::log::trace;
2121
use crate::msgs::enums::NamedGroup;
22-
use crate::msgs::handshake::ClientExtension;
22+
use crate::msgs::handshake::{ClientExtension, ProtocolName};
2323
use crate::msgs::persist;
2424
use crate::suites::{ExtractedSecrets, SupportedCipherSuite};
2525
use crate::sync::Arc;
@@ -851,6 +851,10 @@ impl ConnectionCore<ClientConnectionData> {
851851
common_state.enable_secret_extraction = config.enable_secret_extraction;
852852
common_state.fips = config.fips();
853853
let mut data = ClientConnectionData::new();
854+
let alpn_protocols = alpn_protocols
855+
.into_iter()
856+
.map(ProtocolName::from)
857+
.collect();
854858

855859
let mut cx = hs::ClientContext {
856860
common: &mut common_state,

rustls/src/client/common.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use alloc::vec::Vec;
44
use super::ResolvesClientCert;
55
use crate::log::{debug, trace};
66
use crate::msgs::enums::ExtensionType;
7-
use crate::msgs::handshake::{CertificateChain, DistinguishedName, ServerExtension};
7+
use crate::msgs::handshake::{CertificateChain, DistinguishedName, ProtocolName, ServerExtension};
88
use crate::sync::Arc;
99
use crate::{SignatureScheme, compress, sign};
1010

@@ -35,14 +35,14 @@ impl<'a> ServerCertDetails<'a> {
3535
}
3636

3737
pub(super) struct ClientHelloDetails {
38-
pub(super) alpn_protocols: Vec<Vec<u8>>,
38+
pub(super) alpn_protocols: Vec<ProtocolName>,
3939
pub(super) sent_extensions: Vec<ExtensionType>,
4040
pub(super) extension_order_seed: u16,
4141
pub(super) offered_cert_compression: bool,
4242
}
4343

4444
impl ClientHelloDetails {
45-
pub(super) fn new(alpn_protocols: Vec<Vec<u8>>, extension_order_seed: u16) -> Self {
45+
pub(super) fn new(alpn_protocols: Vec<ProtocolName>, extension_order_seed: u16) -> Self {
4646
Self {
4747
alpn_protocols,
4848
sent_extensions: Vec::new(),

rustls/src/client/hs.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ fn find_session(
9999

100100
pub(super) fn start_handshake(
101101
server_name: ServerName<'static>,
102-
alpn_protocols: Vec<Vec<u8>>,
102+
alpn_protocols: Vec<ProtocolName>,
103103
extra_exts: Vec<ClientExtension>,
104104
config: Arc<ClientConfig>,
105105
cx: &mut ClientContext<'_>,
@@ -359,12 +359,7 @@ fn emit_client_hello_for_retry(
359359
// Add ALPN extension if we have any protocols
360360
if !input.hello.alpn_protocols.is_empty() {
361361
exts.push(ClientExtension::Protocols(
362-
input
363-
.hello
364-
.alpn_protocols
365-
.iter()
366-
.map(|proto| ProtocolName::from(proto.clone()))
367-
.collect::<Vec<_>>(),
362+
input.hello.alpn_protocols.clone(),
368363
));
369364
}
370365

@@ -668,10 +663,10 @@ fn prepare_resumption<'a>(
668663

669664
pub(super) fn process_alpn_protocol(
670665
common: &mut CommonState,
671-
offered_protocols: &[Vec<u8>],
672-
proto: Option<&[u8]>,
666+
offered_protocols: &[ProtocolName],
667+
selected: Option<&ProtocolName>,
673668
) -> Result<(), Error> {
674-
common.alpn_protocol = proto.map(ToOwned::to_owned);
669+
common.alpn_protocol = selected.map(ToOwned::to_owned);
675670

676671
if let Some(alpn_protocol) = &common.alpn_protocol {
677672
if !offered_protocols.contains(alpn_protocol) {

rustls/src/common_state.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::msgs::base::Payload;
1414
use crate::msgs::codec::Codec;
1515
use crate::msgs::enums::{AlertLevel, KeyUpdateRequest};
1616
use crate::msgs::fragmenter::MessageFragmenter;
17-
use crate::msgs::handshake::{CertificateChain, HandshakeMessagePayload};
17+
use crate::msgs::handshake::{CertificateChain, HandshakeMessagePayload, ProtocolName};
1818
use crate::msgs::message::{
1919
Message, MessagePayload, OutboundChunks, OutboundOpaqueMessage, OutboundPlainMessage,
2020
PlainMessage,
@@ -35,7 +35,7 @@ pub struct CommonState {
3535
pub(crate) record_layer: record_layer::RecordLayer,
3636
pub(crate) suite: Option<SupportedCipherSuite>,
3737
pub(crate) kx_state: KxState,
38-
pub(crate) alpn_protocol: Option<Vec<u8>>,
38+
pub(crate) alpn_protocol: Option<ProtocolName>,
3939
pub(crate) aligned_handshake: bool,
4040
pub(crate) may_send_application_data: bool,
4141
pub(crate) may_receive_application_data: bool,

rustls/src/msgs/handshake.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,20 @@ wrapped_payload!(
360360
pub(crate) struct ProtocolName, PayloadU8<NonEmpty>,
361361
);
362362

363+
impl PartialEq for ProtocolName {
364+
fn eq(&self, other: &Self) -> bool {
365+
self.0 == other.0
366+
}
367+
}
368+
369+
impl Deref for ProtocolName {
370+
type Target = [u8];
371+
372+
fn deref(&self) -> &Self::Target {
373+
self.as_ref()
374+
}
375+
}
376+
363377
/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
364378
impl TlsListElement for ProtocolName {
365379
const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
@@ -372,8 +386,8 @@ impl TlsListElement for ProtocolName {
372386
pub(crate) struct SingleProtocolName(ProtocolName);
373387

374388
impl SingleProtocolName {
375-
pub(crate) fn new(bytes: Vec<u8>) -> Self {
376-
Self(ProtocolName::from(bytes))
389+
pub(crate) fn new(single: ProtocolName) -> Self {
390+
Self(single)
377391
}
378392

379393
const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
@@ -2224,10 +2238,10 @@ pub(crate) trait HasServerExtensions {
22242238
.find(|x| x.ext_type() == ext)
22252239
}
22262240

2227-
fn alpn_protocol(&self) -> Option<&[u8]> {
2241+
fn alpn_protocol(&self) -> Option<&ProtocolName> {
22282242
let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
22292243
match ext {
2230-
ServerExtension::Protocols(protos) => Some(protos.as_ref()),
2244+
ServerExtension::Protocols(protos) => Some(&protos.0),
22312245
_ => None,
22322246
}
22332247
}

rustls/src/msgs/handshake_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ fn sample_server_hello_payload() -> ServerHelloPayload {
10061006
ServerExtension::ServerNameAck,
10071007
ServerExtension::SessionTicketAck,
10081008
ServerExtension::RenegotiationInfo(PayloadU8::new(vec![0])),
1009-
ServerExtension::Protocols(SingleProtocolName::new(vec![0])),
1009+
ServerExtension::Protocols(SingleProtocolName::new(ProtocolName::from(vec![0]))),
10101010
ServerExtension::KeyShare(KeyShareEntry::new(NamedGroup::X25519, &[1, 2, 3][..])),
10111011
ServerExtension::PresharedKey(3),
10121012
ServerExtension::ExtendedMasterSecretAck,

rustls/src/msgs/persist.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ use crate::enums::{CipherSuite, ProtocolVersion};
99
use crate::error::InvalidMessage;
1010
use crate::msgs::base::{MaybeEmpty, PayloadU8, PayloadU16};
1111
use crate::msgs::codec::{Codec, Reader};
12-
use crate::msgs::handshake::CertificateChain;
1312
#[cfg(feature = "tls12")]
1413
use crate::msgs::handshake::SessionId;
14+
use crate::msgs::handshake::{CertificateChain, ProtocolName};
1515
use crate::sync::{Arc, Weak};
1616
#[cfg(feature = "tls12")]
1717
use crate::tls12::Tls12CipherSuite;
@@ -399,7 +399,7 @@ impl ServerSessionValue {
399399
cs: CipherSuite,
400400
ms: &[u8],
401401
client_cert_chain: Option<CertificateChain<'static>>,
402-
alpn: Option<Vec<u8>>,
402+
alpn: Option<ProtocolName>,
403403
application_data: Vec<u8>,
404404
creation_time: UnixTime,
405405
age_obfuscation_offset: u32,
@@ -411,7 +411,7 @@ impl ServerSessionValue {
411411
master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
412412
extended_ms: false,
413413
client_cert_chain,
414-
alpn: alpn.map(PayloadU8::new),
414+
alpn: alpn.map(|p| PayloadU8::new(p.as_ref().to_vec())),
415415
application_data: PayloadU16::new(application_data),
416416
creation_time_sec: creation_time.as_secs(),
417417
age_obfuscation_offset,

rustls/src/server/hs.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ use crate::msgs::enums::{Compression, ExtensionType, NamedGroup};
2121
#[cfg(feature = "tls12")]
2222
use crate::msgs::handshake::SessionId;
2323
use crate::msgs::handshake::{
24-
ClientHelloPayload, HandshakePayload, KeyExchangeAlgorithm, Random, ServerExtension,
25-
ServerNamePayload, SingleProtocolName,
24+
ClientHelloPayload, HandshakePayload, KeyExchangeAlgorithm, ProtocolName, Random,
25+
ServerExtension, ServerNamePayload, SingleProtocolName,
2626
};
2727
use crate::msgs::message::{Message, MessagePayload};
2828
use crate::msgs::persist;
@@ -89,7 +89,7 @@ impl ExtensionProcessing {
8989
.iter()
9090
.any(|theirs| theirs.as_ref() == ours.as_slice())
9191
})
92-
.cloned();
92+
.map(|bytes| ProtocolName::from(bytes.clone()));
9393
if let Some(selected_protocol) = &cx.common.alpn_protocol {
9494
debug!("Chosen ALPN protocol {selected_protocol:?}");
9595
self.exts

rustls/src/server/tls13.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ mod client_hello {
652652
&& resume.is_fresh()
653653
&& Some(resume.version) == cx.common.negotiated_version
654654
&& resume.cipher_suite == suite.common.suite
655-
&& resume.alpn.as_ref().map(|x| &x.0) == cx.common.alpn_protocol.as_ref();
655+
&& resume.alpn.as_ref().map(|p| &p.0[..]) == cx.common.alpn_protocol.as_deref();
656656

657657
if early_data_configured && early_data_possible && !cx.data.early_data.was_rejected() {
658658
EarlyDataDecision::Accepted

0 commit comments

Comments
 (0)