Skip to content

Commit b7511c8

Browse files
authored
multistream-select: Enforce io::error instead of empty protocols (#318)
This PR brings parity between the litep2p mutlistream-select implementation and the libp2p one. There was a mismatch in the litep2p implementation which resulted in decoding empty bytes into `Message::Protocols([ ])`. In contrast, libp2p returns an `io::error` since the message is invalid. While at it have added a few tests to ensure our implementation works as expected cc @paritytech/networking --------- Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>
1 parent 78d934f commit b7511c8

File tree

5 files changed

+130
-47
lines changed

5 files changed

+130
-47
lines changed

src/multistream_select/dialer_select.rs

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ use crate::{
2525
error::{self, Error, ParseError},
2626
multistream_select::{
2727
protocol::{
28-
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
28+
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
29+
ProtocolError,
2930
},
3031
Negotiated, NegotiationError, Version,
3132
},
@@ -305,7 +306,7 @@ pub enum HandshakeResult {
305306
/// Handshake state.
306307
#[derive(Debug)]
307308
enum HandshakeState {
308-
/// Wainting to receive any response from remote peer.
309+
/// Waiting to receive any response from remote peer.
309310
WaitingResponse,
310311

311312
/// Waiting to receive the actual application protocol from remote peer.
@@ -314,7 +315,7 @@ enum HandshakeState {
314315

315316
/// `multistream-select` dialer handshake state.
316317
#[derive(Debug)]
317-
pub struct DialerState {
318+
pub struct WebRtcDialerState {
318319
/// Proposed main protocol.
319320
protocol: ProtocolName,
320321

@@ -325,16 +326,16 @@ pub struct DialerState {
325326
state: HandshakeState,
326327
}
327328

328-
impl DialerState {
329+
impl WebRtcDialerState {
329330
/// Propose protocol to remote peer.
330331
///
331-
/// Return [`DialerState`] which is used to drive forward the negotiation and an encoded
332+
/// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded
332333
/// `multistream-select` message that contains the protocol proposal for the substream.
333334
pub fn propose(
334335
protocol: ProtocolName,
335336
fallback_names: Vec<ProtocolName>,
336337
) -> crate::Result<(Self, Vec<u8>)> {
337-
let message = encode_multistream_message(
338+
let message = webrtc_encode_multistream_message(
338339
std::iter::once(protocol.clone())
339340
.chain(fallback_names.clone())
340341
.filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
@@ -353,7 +354,7 @@ impl DialerState {
353354
))
354355
}
355356

356-
/// Register response to [`DialerState`].
357+
/// Register response to [`WebRtcDialerState`].
357358
pub fn register_response(
358359
&mut self,
359360
payload: Vec<u8>,
@@ -755,7 +756,7 @@ mod tests {
755756
#[test]
756757
fn propose() {
757758
let (mut dialer_state, message) =
758-
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
759+
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
759760
let message = bytes::BytesMut::from(&message[..]).freeze();
760761

761762
let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
@@ -777,7 +778,7 @@ mod tests {
777778

778779
#[test]
779780
fn propose_with_fallback() {
780-
let (mut dialer_state, message) = DialerState::propose(
781+
let (mut dialer_state, message) = WebRtcDialerState::propose(
781782
ProtocolName::from("/13371338/proto/1"),
782783
vec![ProtocolName::from("/sup/proto/1")],
783784
)
@@ -813,7 +814,7 @@ mod tests {
813814
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
814815

815816
let (mut dialer_state, _message) =
816-
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
817+
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
817818

818819
match dialer_state.register_response(bytes.freeze().to_vec()) {
819820
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
@@ -832,7 +833,7 @@ mod tests {
832833
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
833834

834835
let (mut dialer_state, _message) =
835-
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
836+
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
836837

837838
match dialer_state.register_response(bytes.freeze().to_vec()) {
838839
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
@@ -842,7 +843,7 @@ mod tests {
842843

843844
#[test]
844845
fn negotiate_main_protocol() {
845-
let message = encode_multistream_message(
846+
let message = webrtc_encode_multistream_message(
846847
vec![Message::Protocol(
847848
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
848849
)]
@@ -851,7 +852,7 @@ mod tests {
851852
.unwrap()
852853
.freeze();
853854

854-
let (mut dialer_state, _message) = DialerState::propose(
855+
let (mut dialer_state, _message) = WebRtcDialerState::propose(
855856
ProtocolName::from("/13371338/proto/1"),
856857
vec![ProtocolName::from("/sup/proto/1")],
857858
)
@@ -860,13 +861,13 @@ mod tests {
860861
match dialer_state.register_response(message.to_vec()) {
861862
Ok(HandshakeResult::Succeeded(negotiated)) =>
862863
assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")),
863-
_ => panic!("invalid event"),
864+
event => panic!("invalid event {event:?}"),
864865
}
865866
}
866867

867868
#[test]
868869
fn negotiate_fallback_protocol() {
869-
let message = encode_multistream_message(
870+
let message = webrtc_encode_multistream_message(
870871
vec![Message::Protocol(
871872
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
872873
)]
@@ -875,7 +876,7 @@ mod tests {
875876
.unwrap()
876877
.freeze();
877878

878-
let (mut dialer_state, _message) = DialerState::propose(
879+
let (mut dialer_state, _message) = WebRtcDialerState::propose(
879880
ProtocolName::from("/13371338/proto/1"),
880881
vec![ProtocolName::from("/sup/proto/1")],
881882
)

src/multistream_select/listener_select.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ use crate::{
2626
error::{self, Error},
2727
multistream_select::{
2828
protocol::{
29-
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
29+
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
30+
ProtocolError,
3031
},
3132
Negotiated, NegotiationError,
3233
},
@@ -324,7 +325,7 @@ where
324325
}
325326
}
326327

327-
/// Result of [`listener_negotiate()`].
328+
/// Result of [`webrtc_listener_negotiate()`].
328329
#[derive(Debug)]
329330
pub enum ListenerSelectResult {
330331
/// Requested protocol is available and substream can be accepted.
@@ -348,7 +349,7 @@ pub enum ListenerSelectResult {
348349
/// Parse protocols offered by the remote peer and check if any of the offered protocols match
349350
/// locally available protocols. If a match is found, return an encoded multistream-select
350351
/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
351-
pub fn listener_negotiate<'a>(
352+
pub fn webrtc_listener_negotiate<'a>(
352353
supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
353354
payload: Bytes,
354355
) -> crate::Result<ListenerSelectResult> {
@@ -382,9 +383,9 @@ pub fn listener_negotiate<'a>(
382383
if protocol.as_ref() == supported.as_bytes() {
383384
return Ok(ListenerSelectResult::Accepted {
384385
protocol: supported.clone(),
385-
message: encode_multistream_message(std::iter::once(Message::Protocol(
386-
protocol,
387-
)))?,
386+
message: webrtc_encode_multistream_message(std::iter::once(
387+
Message::Protocol(protocol),
388+
))?,
388389
});
389390
}
390391
}
@@ -396,7 +397,7 @@ pub fn listener_negotiate<'a>(
396397
);
397398

398399
Ok(ListenerSelectResult::Rejected {
399-
message: encode_multistream_message(std::iter::once(Message::NotAvailable))?,
400+
message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?,
400401
})
401402
}
402403

@@ -405,15 +406,15 @@ mod tests {
405406
use super::*;
406407

407408
#[test]
408-
fn listener_negotiate_works() {
409+
fn webrtc_listener_negotiate_works() {
409410
let mut local_protocols = vec![
410411
ProtocolName::from("/13371338/proto/1"),
411412
ProtocolName::from("/sup/proto/1"),
412413
ProtocolName::from("/13371338/proto/2"),
413414
ProtocolName::from("/13371338/proto/3"),
414415
ProtocolName::from("/13371338/proto/4"),
415416
];
416-
let message = encode_multistream_message(
417+
let message = webrtc_encode_multistream_message(
417418
vec![
418419
Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
419420
Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
@@ -423,7 +424,7 @@ mod tests {
423424
.unwrap()
424425
.freeze();
425426

426-
match listener_negotiate(&mut local_protocols.iter(), message) {
427+
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
427428
Err(error) => panic!("error received: {error:?}"),
428429
Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
429430
Ok(ListenerSelectResult::Accepted { protocol, message }) => {
@@ -441,14 +442,14 @@ mod tests {
441442
ProtocolName::from("/13371338/proto/3"),
442443
ProtocolName::from("/13371338/proto/4"),
443444
];
444-
let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![
445+
let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![
445446
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
446447
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
447448
])))
448449
.unwrap()
449450
.freeze();
450451

451-
match listener_negotiate(&mut local_protocols.iter(), message) {
452+
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
452453
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
453454
_ => panic!("invalid event"),
454455
}
@@ -469,7 +470,7 @@ mod tests {
469470
let message = Message::Header(HeaderLine::V1);
470471
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
471472

472-
match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
473+
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
473474
Err(error) => assert!(std::matches!(
474475
error,
475476
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
@@ -498,7 +499,7 @@ mod tests {
498499
]);
499500
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
500501

501-
match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
502+
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
502503
Err(error) => assert!(std::matches!(
503504
error,
504505
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
@@ -518,7 +519,7 @@ mod tests {
518519
ProtocolName::from("/13371338/proto/3"),
519520
ProtocolName::from("/13371338/proto/4"),
520521
];
521-
let message = encode_multistream_message(
522+
let message = webrtc_encode_multistream_message(
522523
vec![Message::Protocol(
523524
Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
524525
)]
@@ -527,12 +528,13 @@ mod tests {
527528
.unwrap()
528529
.freeze();
529530

530-
match listener_negotiate(&mut local_protocols.iter(), message) {
531+
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
531532
Err(error) => panic!("error received: {error:?}"),
532533
Ok(ListenerSelectResult::Rejected { message }) => {
533534
assert_eq!(
534535
message,
535-
encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap()
536+
webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))
537+
.unwrap()
536538
);
537539
}
538540
Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"),

src/multistream_select/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ mod negotiated;
7676
mod protocol;
7777

7878
pub use crate::multistream_select::{
79-
dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult},
79+
dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState},
8080
listener_select::{
81-
listener_negotiate, listener_select_proto, ListenerSelectFuture, ListenerSelectResult,
81+
listener_select_proto, webrtc_listener_negotiate, ListenerSelectFuture,
82+
ListenerSelectResult,
8283
},
8384
negotiated::{Negotiated, NegotiatedComplete, NegotiationError},
8485
protocol::{HeaderLine, Message, Protocol, ProtocolError},

src/multistream_select/protocol.rs

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ impl Message {
201201
let mut remaining: &[u8] = &msg;
202202
loop {
203203
// A well-formed message must be terminated with a newline.
204-
// TODO: don't do this
205-
if remaining == [b'\n'] || remaining.is_empty() {
204+
if remaining == [b'\n'] {
206205
break;
207206
} else if protocols.len() == MAX_PROTOCOLS {
208207
return Err(ProtocolError::TooManyProtocols);
@@ -228,7 +227,12 @@ impl Message {
228227
}
229228

230229
/// Create `multistream-select` message from an iterator of `Message`s.
231-
pub fn encode_multistream_message(
230+
///
231+
/// # Note
232+
///
233+
/// This is implementation is not compliant with the multistream-select protocol spec.
234+
/// The only purpose of this was to get the `multistream-select` protocol working with smoldot.
235+
pub fn webrtc_encode_multistream_message(
232236
messages: impl IntoIterator<Item = Message>,
233237
) -> crate::Result<BytesMut> {
234238
// encode `/multistream-select/1.0.0` header
@@ -245,6 +249,9 @@ pub fn encode_multistream_message(
245249
header.append(&mut proto_bytes);
246250
}
247251

252+
// For the `Message::Protocols` to be interpreted correctly, it must be followed by a newline.
253+
header.push(b'\n');
254+
248255
Ok(BytesMut::from(&header[..]))
249256
}
250257

@@ -468,3 +475,71 @@ impl From<uvi::decode::Error> for ProtocolError {
468475
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
469476
}
470477
}
478+
479+
#[cfg(test)]
480+
mod tests {
481+
use super::*;
482+
483+
#[test]
484+
fn test_decode_main_messages() {
485+
// Decode main messages.
486+
let bytes = Bytes::from_static(MSG_MULTISTREAM_1_0);
487+
assert_eq!(
488+
Message::decode(bytes).unwrap(),
489+
Message::Header(HeaderLine::V1)
490+
);
491+
492+
let bytes = Bytes::from_static(MSG_PROTOCOL_NA);
493+
assert_eq!(Message::decode(bytes).unwrap(), Message::NotAvailable);
494+
495+
let bytes = Bytes::from_static(MSG_LS);
496+
assert_eq!(Message::decode(bytes).unwrap(), Message::ListProtocols);
497+
}
498+
499+
#[test]
500+
fn test_decode_empty_message() {
501+
// Empty message should decode to an IoError, not Header::Protocols.
502+
let bytes = Bytes::from_static(b"");
503+
match Message::decode(bytes).unwrap_err() {
504+
ProtocolError::IoError(io) => assert_eq!(io.kind(), io::ErrorKind::InvalidData),
505+
err => panic!("Unexpected error: {:?}", err),
506+
};
507+
}
508+
509+
#[test]
510+
fn test_decode_protocols() {
511+
// Single protocol.
512+
let bytes = Bytes::from_static(b"/protocol-v1\n");
513+
assert_eq!(
514+
Message::decode(bytes).unwrap(),
515+
Message::Protocol(Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap())
516+
);
517+
518+
// Multiple protocols.
519+
let expected = Message::Protocols(vec![
520+
Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(),
521+
Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(),
522+
]);
523+
let mut encoded = BytesMut::new();
524+
expected.encode(&mut encoded).unwrap();
525+
526+
// `\r` is the length of the protocol names.
527+
let bytes = Bytes::from_static(b"\r/protocol-v1\n\r/protocol-v2\n\n");
528+
assert_eq!(encoded, bytes);
529+
530+
assert_eq!(
531+
Message::decode(bytes).unwrap(),
532+
Message::Protocols(vec![
533+
Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(),
534+
Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(),
535+
])
536+
);
537+
538+
// Check invalid length.
539+
let bytes = Bytes::from_static(b"\r/v1\n\n");
540+
assert_eq!(
541+
Message::decode(bytes).unwrap_err(),
542+
ProtocolError::InvalidMessage
543+
);
544+
}
545+
}

0 commit comments

Comments
 (0)