From 6e35c1b32a72016053be510def371bd1a820417a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20=C5=A0pa=C4=8Dek?= Date: Tue, 16 Jul 2024 21:36:30 +0200 Subject: [PATCH] Add keepalive request --- src/client/client.rs | 20 +++++++++++++++++++- src/client/conn.rs | 23 +++++++++++++++++++---- src/client/recv.rs | 2 ++ tests/compat/smoke_test.rs | 2 ++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 47ba34c..faa58b0 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -332,6 +332,24 @@ impl Client { Ok((channel, channel_rx, result.confirm_payload)) } + /// Send a keepalive request. + /// + /// This sends a `keepalive@openssh.com` global request to the server. The server will respond + /// with an error, because this request is not defined, but this should be enough to keep the + /// connection alive on the server. (This is the keepalive mechanism used by the OpenSSH client.) + /// + /// This method will wait until you are authenticated before it sends the request, and it will + /// ignore the response (which should be an error). + pub fn send_keepalive(&self) -> Result<()> { + let (reply_tx, _reply_rx) = oneshot::channel(); + let req = GlobalReq { + request_type: "keepalive@openssh.com".to_owned(), + payload: Bytes::new(), + reply_tx: Some(reply_tx), + }; + self.send_request(req) + } + /// Send a global request (low level API). /// /// This sends `SSH_MSG_GLOBAL_REQUEST` to the server (RFC 4254, section 4). We simply enqueue @@ -404,7 +422,7 @@ pub struct GlobalReq { pub enum GlobalReply { /// Successful reply (`SSH_MSG_REQUEST_SUCCESS`) with response specific payload. Success(Bytes), - /// Failure reply (`SSH_MSG_REQUEST_FAILURE`). + /// Failure reply (`SSH_MSG_REQUEST_FAILURE` or `SSH_MSG_UNIMPLEMENTED`). Failure, } diff --git a/src/client/conn.rs b/src/client/conn.rs index 616125d..66d18d3 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -84,6 +84,7 @@ pub(super) struct AcceptedChannelResult { #[derive(Debug)] struct RecvReply { reply_tx: oneshot::Sender, + packet_seq: u32, } @@ -98,9 +99,9 @@ pub(super) fn pump_conn(st: &mut ClientState, cx: &mut Context) -> Result if negotiate::is_ready(st) { if let Some(req) = st.conn_st.send_reqs.pop_front() { - send_global_request(st, &req); + let packet_seq = send_global_request(st, &req); if let Some(reply_tx) = req.reply_tx { - st.conn_st.recv_replies.push_back(RecvReply { reply_tx }); + st.conn_st.recv_replies.push_back(RecvReply { reply_tx, packet_seq }); } return Ok(Pump::Progress) } @@ -520,14 +521,15 @@ pub(super) fn send_request(st: &mut ClientState, req: GlobalReq) -> Result<()> { Ok(()) } -fn send_global_request(st: &mut ClientState, req: &GlobalReq) { +fn send_global_request(st: &mut ClientState, req: &GlobalReq) -> u32 { let mut payload = PacketEncode::new(); payload.put_u8(msg::GLOBAL_REQUEST); payload.put_str(&req.request_type); payload.put_bool(req.reply_tx.is_some()); payload.put_raw(&req.payload); - st.codec.send_pipe.feed_packet(&payload.finish()); + let packet_seq = st.codec.send_pipe.feed_packet(&payload.finish()); log::debug!("sending SSH_MSG_GLOBAL_REQUEST {:?}", req.request_type); + packet_seq } fn recv_request_success(st: &mut ClientState, payload: &mut PacketDecode) -> ResultRecvState { @@ -549,6 +551,19 @@ fn recv_request_failure(st: &mut ClientState) -> ResultRecvState { Ok(None) } +pub(super) fn recv_unimplemented(st: &mut ClientState, packet_seq: u32) -> bool { + if let Some(reply) = st.conn_st.recv_replies.pop_front() { + // tinyssh seems to send `packet_seq` which is off by one from the correct one + if reply.packet_seq == packet_seq || reply.packet_seq + 1 == packet_seq { + let _: Result<_, _> = reply.reply_tx.send(GlobalReply::Failure); + return true + } else { + st.conn_st.recv_replies.push_front(reply); + } + } + false +} + fn packet_len_max_to_len_max(packet_len_max: usize) -> usize { diff --git a/src/client/recv.rs b/src/client/recv.rs index d506c0c..6d821e1 100644 --- a/src/client/recv.rs +++ b/src/client/recv.rs @@ -81,6 +81,8 @@ fn recv_unimplemented(st: &mut ClientState, payload: &mut PacketDecode) -> Resul log::debug!("received SSH_MSG_UNIMPLEMENTED for packet seq {}", packet_seq); if negotiate::recv_unimplemented(st, packet_seq)? { Ok(None) + } else if conn::recv_unimplemented(st, packet_seq) { + Ok(None) } else { Err(Error::PeerRejectedPacket(packet_seq)) } diff --git a/tests/compat/smoke_test.rs b/tests/compat/smoke_test.rs index 38f52a2..15b5344 100644 --- a/tests/compat/smoke_test.rs +++ b/tests/compat/smoke_test.rs @@ -112,6 +112,8 @@ async fn smoke_test(socket: TcpStream, config: makiko::ClientConfig) -> Result<( nursery.spawn(enclose!{(nursery) async move { authenticate_alice(&client).await?; + client.send_keepalive()?; + let (session, mut session_rx) = client.open_session(makiko::ChannelConfig::default()).await?; let (stdout_tx, stdout_rx) = oneshot::channel();