From 85d901da0d03a30ec802b61ea2a4c53850ee6172 Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 16 Jun 2023 10:02:19 -0700 Subject: [PATCH] [propolis] address cancel-safety issues with InstanceSerialConsoleHelper::recv (#435) See #434 and inline comments. --- Cargo.lock | 1 + bin/propolis-cli/src/main.rs | 76 +++--- lib/propolis-client/Cargo.toml | 4 + lib/propolis-client/src/lib.rs | 450 ++++++++++++++++++++++++++++----- 4 files changed, 435 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96f58fed6..61ebe269b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3090,6 +3090,7 @@ dependencies = [ name = "propolis-client" version = "0.1.0" dependencies = [ + "async-trait", "base64 0.21.2", "crucible-client-types", "futures", diff --git a/bin/propolis-cli/src/main.rs b/bin/propolis-cli/src/main.rs index 87548d110..a329ab73d 100644 --- a/bin/propolis-cli/src/main.rs +++ b/bin/propolis-cli/src/main.rs @@ -17,7 +17,7 @@ use propolis_client::handmade::{ }, Client, }; -use propolis_client::support::InstanceSerialConsoleHelper; +use propolis_client::support::{InstanceSerialConsoleHelper, WSClientOffset}; use slog::{o, Drain, Level, Logger}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_tungstenite::tungstenite::{ @@ -385,38 +385,43 @@ async fn serial( } msg = ws_console.recv() => { match msg { - Some(Ok(Message::Binary(input))) => { - stdout.write_all(&input).await?; - stdout.flush().await?; - } - Some(Ok(Message::Close(Some(CloseFrame {code, reason})))) => { - eprint!("\r\nConnection closed: {:?}\r\n", code); - match code { - CloseCode::Abnormal - | CloseCode::Error - | CloseCode::Extension - | CloseCode::Invalid - | CloseCode::Policy - | CloseCode::Protocol - | CloseCode::Size - | CloseCode::Unsupported => { - anyhow::bail!("{}", reason); + Some(Ok(msg)) => { + match msg.process().await { + Ok(Message::Binary(input)) => { + stdout.write_all(&input).await?; + stdout.flush().await?; + } + Ok(Message::Close(Some(CloseFrame {code, reason}))) => { + eprint!("\r\nConnection closed: {:?}\r\n", code); + match code { + CloseCode::Abnormal + | CloseCode::Error + | CloseCode::Extension + | CloseCode::Invalid + | CloseCode::Policy + | CloseCode::Protocol + | CloseCode::Size + | CloseCode::Unsupported => { + anyhow::bail!("{}", reason); + } + _ => break, + } } - _ => break, + Ok(Message::Close(None)) => { + eprint!("\r\nConnection closed.\r\n"); + break; + } + // note: migration events via Message::Text are + // already handled within ws_console.recv(), but + // would still be available to match here if we want + // to indicate that it happened to the user + _ => continue, } } - Some(Ok(Message::Close(None))) => { - eprint!("\r\nConnection closed.\r\n"); - break; - } None => { eprint!("\r\nConnection lost.\r\n"); break; } - // note: migration events via Message::Text are already - // handled within ws_console.recv(), but would still be - // available to match here if we want to indicate that it - // happened to the user _ => continue, } } @@ -431,20 +436,13 @@ async fn serial_connect( byte_offset: Option, log: Logger, ) -> anyhow::Result { - let client = propolis_client::Client::new(&format!("http://{}", addr)); - let mut req = client.instance_serial(); + let offset = match byte_offset { + Some(x) if x >= 0 => WSClientOffset::FromStart(x as u64), + Some(x) => WSClientOffset::MostRecent(-x as u64), + None => WSClientOffset::MostRecent(16384), + }; - match byte_offset { - Some(x) if x >= 0 => req = req.from_start(x as u64), - Some(x) => req = req.most_recent(-x as u64), - None => req = req.most_recent(16384), - } - let upgraded = req - .send() - .await - .map_err(|e| anyhow!("Failed to upgrade connection: {}", e))? - .into_inner(); - Ok(InstanceSerialConsoleHelper::new(upgraded, Some(log)).await) + Ok(InstanceSerialConsoleHelper::new(addr, offset, Some(log)).await?) } async fn migrate_instance( diff --git a/lib/propolis-client/Cargo.toml b/lib/propolis-client/Cargo.toml index 628426135..c3207ec64 100644 --- a/lib/propolis-client/Cargo.toml +++ b/lib/propolis-client/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] propolis_types.workspace = true +async-trait.workspace = true reqwest = { workspace = true, features = ["json", "rustls-tls"] } base64.workspace = true futures = { workspace = true, optional = true } @@ -23,6 +24,9 @@ tokio = { workspace = true, features = [ "net" ], optional = true } tokio-tungstenite = { workspace = true, optional = true } crucible-client-types.workspace = true +[dev-dependencies] +tokio = { workspace = true, features = ["test-util"] } + [features] default = [] generated = ["progenitor", "tokio", "tokio-tungstenite", "futures"] diff --git a/lib/propolis-client/src/lib.rs b/lib/propolis-client/src/lib.rs index b4ce8c421..e20a4f235 100644 --- a/lib/propolis-client/src/lib.rs +++ b/lib/propolis-client/src/lib.rs @@ -66,6 +66,8 @@ mod _compat_impls { #[cfg(feature = "generated")] pub mod support { + use std::net::SocketAddr; + use crate::generated::Client as PropolisClient; use crate::handmade::api::InstanceSerialConsoleControlMessage; use futures::{SinkExt, StreamExt}; @@ -79,32 +81,112 @@ pub mod support { use self::tungstenite::http; pub use tokio_tungstenite::{tungstenite, WebSocketStream}; - trait SerialConsoleStream: AsyncRead + AsyncWrite + Unpin + Send {} + /// A trait representing a console stream. + pub trait SerialConsoleStream: + AsyncRead + AsyncWrite + Unpin + Send + { + } impl SerialConsoleStream for T {} + /// Represents a way to build a serial console stream. + #[async_trait::async_trait] + pub(crate) trait SerialConsoleStreamBuilder { + async fn build( + &mut self, + address: SocketAddr, + offset: WSClientOffset, + ) -> Result, WSError>; + } + + /// A serial console builder that uses a Propolis client to build the + /// socket. + #[derive(Debug)] + struct PropolisSerialBuilder {} + + impl PropolisSerialBuilder { + /// Creates a new `PropolisSerialBuilder`. + pub fn new() -> Self { + Self {} + } + } + + #[async_trait::async_trait] + impl SerialConsoleStreamBuilder for PropolisSerialBuilder { + async fn build( + &mut self, + address: SocketAddr, + offset: WSClientOffset, + ) -> Result, WSError> { + let client = PropolisClient::new(&format!("http://{}", address)); + let mut req = client.instance_serial(); + + match offset { + WSClientOffset::FromStart(offset) => { + req = req.from_start(offset); + } + WSClientOffset::MostRecent(offset) => { + req = req.most_recent(offset); + } + } + + let upgraded = req + .send() + .await + .map_err(|e| { + WSError::Http(http::Response::new(Some(e.to_string()))) + })? + .into_inner(); + + Ok(Box::new(upgraded)) + } + } + + pub enum WSClientOffset { + FromStart(u64), + MostRecent(u64), + } + /// This is a trivial abstraction wrapping the websocket connection /// returned by [crate::generated::Client::instance_serial], providing /// the additional functionality of connecting to the new propolis-server /// when an instance is migrated (thus providing the illusion of the /// connection being seamlessly maintained through migration) pub struct InstanceSerialConsoleHelper { + stream_builder: Box, ws_stream: WebSocketStream>, log: Option, } impl InstanceSerialConsoleHelper { - /// Typical use: Pass the [reqwest::Upgraded] connection to the - /// /instance/serial channel, i.e. the value returned by - /// `client.instance_serial().send().await?.into_inner()`. - pub async fn new( - upgraded: T, + /// Creates a new serial console helper by using a Propolis client to + /// connect to the provided address and using the given offset. + /// + /// Returns an error if the helper failed to connect to the address. + pub async fn new( + address: SocketAddr, + offset: WSClientOffset, log: Option, - ) -> Self { - let stream: Box = Box::new(upgraded); + ) -> Result { + let stream_builder = PropolisSerialBuilder::new(); + Self::new_with_builder(stream_builder, address, offset, log).await + } + + // Currently used for testing, and not exposed to clients. + pub(crate) async fn new_with_builder( + mut stream_builder: impl SerialConsoleStreamBuilder + 'static, + address: SocketAddr, + offset: WSClientOffset, + log: Option, + ) -> Result { + let stream = stream_builder.build(address, offset).await?; let ws_stream = WebSocketStream::from_raw_socket(stream, Role::Client, None) .await; - Self { ws_stream, log } + Ok(Self { + stream_builder: Box::new(stream_builder), + ws_stream, + log, + }) } /// Sends the given [WSMessage] to the server. @@ -113,56 +195,131 @@ pub mod support { self.ws_stream.send(input).await } - /// Receive the next [WSMessage] from the server. + /// Receives the next [WSMessage] from the server, holding it in + /// abeyance until it is processed. + /// /// Returns [Option::None] if the connection has been terminated. + /// + /// # Cancel safety + /// + /// This method is cancel-safe and can be used in a `select!` loop + /// without causing any messages to be dropped. However, + /// [InstanceSerialConsoleMessage::process] must be awaited to retrieve + /// the inner [WSMessage], and that portion is not cancel-safe. + pub async fn recv( + &mut self, + ) -> Option, WSError>> { + // Note that ws_stream.next() eventually calls tungstenite's + // read_message. From manual inspection, it looks like read_message + // is written in a cancel-safe fashion so pending packets are + // buffered before being written out. + // + // We currently assume and don't test that ws_stream.next() is + // cancel-safe. That would be a good test to add in the future but + // will require some testing infrastructure to insert delays in the + // I/O stream manually. + let message = self.ws_stream.next().await?; + match message { + Ok(message) => Some(Ok(InstanceSerialConsoleMessage { + helper: self, + message, + })), + Err(error) => Some(Err(error)), + } + } + } + + /// A [`WSMessage`] that has been received but not processed yet. + pub struct InstanceSerialConsoleMessage<'a> { + helper: &'a mut InstanceSerialConsoleHelper, + message: WSMessage, + } + + impl<'a> InstanceSerialConsoleMessage<'a> { + /// Processes this [WSMessage]. + /// /// - [WSMessage::Binary] are character output from the serial console. /// - [WSMessage::Close] is a close frame. /// - [WSMessage::Text] contain metadata, i.e. about a migration, which - /// this function still returns after connecting to the new server - /// in case the application needs to take further action (e.g. log - /// an event, or show a UI indicator that a migration has occurred). - pub async fn recv( - &mut self, - ) -> Option< - Result, //Box>, - > { - let value = self.ws_stream.next().await; - if let Some(Ok(WSMessage::Text(json))) = &value { + /// this function still returns after connecting to the new server in + /// case the application needs to take further action (e.g. log an + /// event, or show a UI indicator that a migration has occurred). + /// + /// # Cancel safety + /// + /// This method is *not* cancel-safe and should *not* be called directly + /// in a `select!` loop. If this future is not awaited to completion, + /// then not only will messages will be dropped, any pending migrations + /// will not complete. + /// + /// Like other non-cancel-safe futures, it is OK to create this future + /// *once*, then call it in a `select!` loop by pinning it and selecting + /// over a `&mut` reference to it. An example is shown in [Resuming an + /// async + /// operation](https://tokio.rs/tokio/tutorial/select#resuming-an-async-operation). + /// + /// # Why this approach? + /// + /// There are two general approaches we can take here to deal with + /// cancel safety: + /// + /// 1. Break apart processing into cancel-safe + /// [`InstanceSerialConsoleHelper::recv`] and non-cancel-safe (this + /// method) sections. This is the approach chosen here. + /// 2. Make all of [`InstanceSerialConsoleHelper::recv`] cancel-safe. + /// This approach was prototyped in [this propolis + /// PR](https://github.com/oxidecomputer/propolis/pull/438), but was + /// not chosen. + /// + /// Why was approach 1 chosen over 2? It comes down to three reasons: + /// + /// 1. This approach is significantly simpler to understand and involves + /// less state fiddling. + /// 2. Once we've received a `Migrating` message, the migration is + /// actually *done*. From there onwards, connecting to the new server + /// should be very quick and it's OK to block on that. + /// 3. Once we've received a `Migrating` message, we shouldn't be + /// sending further messages to the old websocket stream. With + /// approach 2, we'd have to do extra work to buffer up those old + /// messages, then send them after migration is complete. That isn't + /// an issue with approach 1. + /// + /// The current implementation does have an issue where if a migration + /// is happening and we haven't received the `Migrating` message yet, + /// we'll send messages over the old websocket stream. This can be + /// addressed in several ways: + /// + /// - Maintain a sequence number and a local bounded buffer for + /// messages, and include the sequence number in the `Migrating` + /// message. Replay messages starting from the sequence number + /// afterwards. + /// - Buffer messages received during migration on the server rather + /// than the client. + pub async fn process(self) -> Result { + if let WSMessage::Text(json) = &self.message { match serde_json::from_str(json) { Ok(InstanceSerialConsoleControlMessage::Migrating { destination, from_start, }) => { - let client = PropolisClient::new(&format!( - "http://{}", - destination - )); - match client - .instance_serial() - .from_start(from_start) - .send() - .await - { - Ok(resp) => { - let stream: Box = - Box::new(resp.into_inner()); - self.ws_stream = - WebSocketStream::from_raw_socket( - stream, - Role::Client, - None, - ) - .await; - } - Err(e) => { - return Some(Err(WSError::Http( - http::Response::new(Some(e.to_string())), - ))) - } - } + let stream = self + .helper + .stream_builder + .build( + destination, + WSClientOffset::FromStart(from_start), + ) + .await?; + self.helper.ws_stream = + WebSocketStream::from_raw_socket( + stream, + Role::Client, + None, + ) + .await; } Err(e) => { - if let Some(log) = &self.log { + if let Some(log) = &self.helper.log { slog::warn!( log, "Unsupported control message {:?}: {:?}", @@ -174,32 +331,95 @@ pub mod support { } } } - value.map(|x| x.map_err(Into::into)) + + Ok(self.message) } } #[cfg(test)] mod tests { + use super::tungstenite::http; use super::InstanceSerialConsoleControlMessage; use super::InstanceSerialConsoleHelper; use super::Role; + use super::SerialConsoleStream; + use super::SerialConsoleStreamBuilder; + use super::WSClientOffset; use super::WSError; use super::WSMessage; use super::WebSocketStream; use futures::{SinkExt, StreamExt}; + use std::collections::HashMap; + use std::net::IpAddr; + use std::net::Ipv6Addr; use std::net::SocketAddr; + use std::time::Duration; + use tokio::io::AsyncRead; + use tokio::io::AsyncWrite; + use tokio::io::DuplexStream; + use tokio::time::Instant; + + struct DuplexBuilder { + client_conns_and_delays: + HashMap, + } + + impl DuplexBuilder { + pub fn new( + client_conns_and_delays: impl IntoIterator< + Item = (SocketAddr, Duration, DuplexStream), + >, + ) -> Self { + Self { + client_conns_and_delays: client_conns_and_delays + .into_iter() + .map(|(address, delay, stream)| { + (address, (delay, stream)) + }) + .collect(), + } + } + } + + #[async_trait::async_trait] + impl SerialConsoleStreamBuilder for DuplexBuilder { + async fn build( + &mut self, + address: SocketAddr, + // offset is currently unused by this builder. Worth testing in + // the future. + _offset: WSClientOffset, + ) -> Result, WSError> { + if let Some((delay, stream)) = + self.client_conns_and_delays.remove(&address) + { + tokio::time::sleep(delay).await; + Ok(Box::new(stream)) + } else { + Err(WSError::Http(http::Response::new(Some(format!( + "no duplex connection found for address {address}" + ))))) + } + } + } + #[tokio::test] async fn test_connection_helper() { + let address = + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 12000); let (client_conn, server_conn) = tokio::io::duplex(1024); + let stream_builder = + DuplexBuilder::new([(address, Duration::ZERO, client_conn)]); - let mut client = - InstanceSerialConsoleHelper::new(client_conn, None).await; - let mut server = WebSocketStream::from_raw_socket( - server_conn, - Role::Server, + let mut client = InstanceSerialConsoleHelper::new_with_builder( + stream_builder, + address, + WSClientOffset::FromStart(0), None, ) - .await; + .await + .unwrap(); + let mut server = make_ws_server(server_conn).await; let sent = WSMessage::Binary(vec![1, 3, 3, 7]); client.send(sent.clone()).await.unwrap(); @@ -208,7 +428,8 @@ pub mod support { let sent = WSMessage::Binary(vec![2, 4, 6, 8]); server.send(sent.clone()).await.unwrap(); - let received = client.recv().await.unwrap().unwrap(); + let received = + client.recv().await.unwrap().unwrap().process().await.unwrap(); assert_eq!(sent, received); // just check that it *tries* to connect @@ -221,8 +442,123 @@ pub mod support { .unwrap(); let sent = WSMessage::Text(payload); server.send(sent).await.unwrap(); - let received = client.recv().await.unwrap().unwrap_err(); + let received = client + .recv() + .await + .unwrap() + .unwrap() + .process() + .await + .unwrap_err(); assert!(matches!(received, WSError::Http(_))); } + + // start_paused = true means that the durations passed in are used to + // just provide a total ordering for awaits -- we don't actually wait + // that long. + #[tokio::test(start_paused = true)] + async fn test_recv_cancel_safety() { + let address_1 = + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 12000); + let address_2 = + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 14000); + + let (client_conn_1, server_conn_1) = tokio::io::duplex(1024); + let (client_conn_2, server_conn_2) = tokio::io::duplex(1024); + + let stream_builder = DuplexBuilder::new([ + (address_1, Duration::ZERO, client_conn_1), + // Add a delay before connecting to client 2 to test cancel safety. + (address_2, Duration::from_secs(1), client_conn_2), + ]); + + let mut client = InstanceSerialConsoleHelper::new_with_builder( + stream_builder, + address_1, + WSClientOffset::FromStart(0), + None, + ) + .await + .unwrap(); + + let mut server_1 = make_ws_server(server_conn_1).await; + let mut server_2 = make_ws_server(server_conn_2).await; + + let payload = serde_json::to_string( + &InstanceSerialConsoleControlMessage::Migrating { + destination: address_2, + from_start: 0, + }, + ) + .unwrap(); + let migration_message = WSMessage::Text(payload); + + let expected = vec![ + migration_message.clone(), + WSMessage::Binary([5, 6, 7, 8].into()), + WSMessage::Close(None), + ]; + + // Spawn a separate task that feeds values into all the servers with + // a delay. This means that the recv() future is sometimes cancelled + // in the select! loop below, so we can test cancel safety. + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + server_1.send(migration_message).await.unwrap(); + + // This message sent on server 1 is *ignored* because it is sent + // after the "migrating" message. + let sent = WSMessage::Binary([1, 2, 3, 4].into()); + server_1.send(sent).await.unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + let sent = WSMessage::Binary([5, 6, 7, 8].into()); + server_2.send(sent).await.unwrap(); + + server_2.close(None).await.unwrap(); + }); + + let mut received = Vec::new(); + + // This sends periodic messages which causes client.recv() to be + // canceled sometimes. + let start = Instant::now(); + let mut interval = + tokio::time::interval(Duration::from_millis(250)); + loop { + tokio::select! { + message = client.recv() => { + // XXX At the end of client.recv() we should receive + // None, but in reality we receive a BrokenPipe message, + // why? + let message = message.expect("we terminate this loop before receiving None"); + let message = message + .expect("received a message") + .process() + .await + .expect("no migration error occurred"); + + println!("received message: {message:?}"); + received.push(message.clone()); + + if let WSMessage::Close(_) = message { + break; + } + } + _ = interval.tick() => { + println!("interval tick, {:?} elapsed", start.elapsed()); + } + } + } + + assert_eq!(received, expected); + } + + async fn make_ws_server(conn: S) -> WebSocketStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + WebSocketStream::from_raw_socket(conn, Role::Server, None).await + } } }