diff --git a/Cargo.lock b/Cargo.lock index b72d1c3..6090a9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -956,7 +956,7 @@ dependencies = [ [[package]] name = "tokio-tungstenite" version = "0.15.0" -source = "git+https://github.com/kazk/tokio-tungstenite?branch=permessage-deflate#219d6ce9a27e6fd0a2d36e352591d3c02920e16e" +source = "git+https://github.com/kazk/tokio-tungstenite?branch=permessage-deflate#4133a28b800529ec920ebe052496785e1cf37b3a" dependencies = [ "futures-util", "log", @@ -1069,7 +1069,7 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" [[package]] name = "tungstenite" version = "0.15.0" -source = "git+https://github.com/kazk/tungstenite-rs?branch=permessage-deflate#54acf30635482227e748d53293f8f4672300cf0b" +source = "git+https://github.com/kazk/tungstenite-rs?branch=permessage-deflate#734a0b983070d7f965c984d5489c5be038e01530" dependencies = [ "base64", "byteorder", @@ -1168,7 +1168,7 @@ dependencies = [ [[package]] name = "warp" version = "0.3.1" -source = "git+https://github.com/kazk/warp?branch=permessage-deflate#d6ea7ccbe26207edeac81dfe7527ff59962d1ac8" +source = "git+https://github.com/kazk/warp?branch=permessage-deflate#a7425b320916e6f25e4f77b09ddd24490ccf62a4" dependencies = [ "bytes", "futures", diff --git a/Cargo.toml b/Cargo.toml index 104bb6f..742a7b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ serde = { version = "1.0.126", features = ["derive"] } serde_json = "1.0.64" url = "2.2.2" -tokio = { version = "1.6.1", features = ["fs", "process", "macros", "rt", "rt-multi-thread"] } +tokio = { version = "1.6.1", features = ["fs", "process", "macros", "rt", "rt-multi-thread", "time"] } tokio-util = { version = "0.6.7", features = ["codec"] } warp = { git = "https://github.com/kazk/warp", branch = "permessage-deflate", default-features = false, features = ["websocket"] } diff --git a/src/api/proxy.rs b/src/api/proxy.rs index fbbb6c7..8034722 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -2,7 +2,7 @@ use std::{convert::Infallible, process::Stdio, str::FromStr}; use futures_util::{ future::{select, Either}, - SinkExt, StreamExt, + stream, SinkExt, StreamExt, }; use tokio::{fs, process::Command}; use url::Url; @@ -108,7 +108,19 @@ async fn connected( let mut server_send = lsp::framed::writer(server.stdin.take().unwrap()); let mut server_recv = lsp::framed::reader(server.stdout.take().unwrap()); let (mut client_send, client_recv) = ws.split(); - let mut client_recv = client_recv.filter_map(filter_map_warp_ws_message).boxed(); + let client_recv = client_recv + .filter_map(filter_map_warp_ws_message) + // Chain this with `Done` so we know when the client disconnects + .chain(stream::once(async { Ok(Message::Done) })); + // Tick every 30s so we can ping the client to keep the connection alive + let ticks = stream::unfold( + tokio::time::interval(std::time::Duration::from_secs(30)), + |mut interval| async move { + interval.tick().await; + Some((Ok(Message::Tick), interval)) + }, + ); + let mut client_recv = stream::select(client_recv, ticks).boxed(); let mut client_msg = client_recv.next(); let mut server_msg = server_recv.next(); @@ -145,15 +157,26 @@ async fn connected( tracing::info!("received Close message"); } + // Ping the client to keep the connection alive + Some(Ok(Message::Tick)) => { + tracing::debug!("pinging the client"); + client_send.send(warp::ws::Message::ping(vec![])).await?; + } + + // Connection closed + Some(Ok(Message::Done)) => { + tracing::info!("connection closed"); + break; + } + // WebSocket Error Some(Err(err)) => { tracing::error!("websocket error: {}", err); } - // Connection closed None => { - tracing::info!("connection closed"); - break; + // Unreachable because of the interval stream + unreachable!("should never yield None"); } } @@ -206,6 +229,8 @@ async fn connected( } // Type to describe a message from the client conveniently. +#[allow(clippy::large_enum_variant)] +#[allow(clippy::enum_variant_names)] enum Message { // Valid LSP message Message(lsp::Message), @@ -213,6 +238,11 @@ enum Message { Invalid(String), // Close message Close, + // Ping the client to keep the connection alive. + // Note that this is from the interval stream and not actually from client. + Tick, + // Client disconnected. Necessary because the combined stream is infinite. + Done, } // Parse the message and ignore anything we don't care. diff --git a/src/lsp/framed/parser.rs b/src/lsp/framed/parser.rs index 7478a27..9439e32 100644 --- a/src/lsp/framed/parser.rs +++ b/src/lsp/framed/parser.rs @@ -24,7 +24,7 @@ pub fn parse_message(input: &[u8]) -> IResult<&[u8], &[u8]> { let header = terminated(terminated(content_len, opt(content_type)), crlf); - let header = map_res(header, |s: &[u8]| str::from_utf8(s)); + let header = map_res(header, str::from_utf8); let length = map_res(header, |s: &str| s.parse::()); let mut message = length_data(length); diff --git a/src/lsp/mod.rs b/src/lsp/mod.rs index ff5b939..8140eb2 100644 --- a/src/lsp/mod.rs +++ b/src/lsp/mod.rs @@ -18,6 +18,7 @@ use types::Unknown; #[derive(Clone, Debug, PartialEq, Deserialize)] #[serde(untagged)] +#[allow(clippy::large_enum_variant)] pub enum Message { Request(Request), diff --git a/src/lsp/request.rs b/src/lsp/request.rs index 034f104..c8a82da 100644 --- a/src/lsp/request.rs +++ b/src/lsp/request.rs @@ -10,6 +10,7 @@ use super::types::Id; /// [Request message]: https://microsoft.github.io/language-server-protocol/specifications/specification-current/#requestMessage #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(tag = "method")] +#[allow(clippy::large_enum_variant)] pub enum Request { // To Server /// > The [initialize] request is sent as the first request from the client