Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions COMMIT_MESSAGE_ISSUE_378.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fix(rmcp-client): label phase errors, retry transient init (#378)
8 changes: 8 additions & 0 deletions PR_BODY_ISSUE_378.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
## Summary
- add explicit phase context to initialize/list/call MCP failures so logs show where the request died
- retry the Streamable HTTP initialize handshake once on obvious transient network errors before surfacing failure
- cover the new helpers with unit tests for phase labeling and retry gating

## Testing
- cargo test -p code-rmcp-client
- ./build-fast.sh
269 changes: 230 additions & 39 deletions code-rs/rmcp-client/src/rmcp_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::collections::HashMap;
use std::error::Error as StdError;
use std::ffi::OsString;
use std::fmt;
use std::future::Future;
use std::io;
use std::process::Stdio;
use std::sync::Arc;
Expand All @@ -24,9 +27,14 @@ use rmcp::service::{self};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use reqwest::Error as ReqwestError;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;

const INITIALIZE_RETRY_BASE_DELAY_MS: u64 = 200;
const INITIALIZE_RETRY_MAX_DELAY_MS: u64 = 1_600;
const INITIALIZE_MAX_RETRIES: usize = 3;
use tokio::sync::Mutex;
use tokio::time;
use tracing::info;
Expand All @@ -41,7 +49,11 @@ use crate::utils::run_with_timeout;

enum PendingTransport {
ChildProcess(TokioChildProcess),
StreamableHttp(StreamableHttpClientTransport<reqwest::Client>),
StreamableHttp {
transport: StreamableHttpClientTransport<reqwest::Client>,
url: String,
bearer_token: Option<String>,
},
}

enum ClientState {
Expand All @@ -53,6 +65,23 @@ enum ClientState {
},
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Phase {
Initialize,
ListTools,
CallTool,
}

impl Phase {
fn as_str(self) -> &'static str {
match self {
Phase::Initialize => "initialize",
Phase::ListTools => "list_tools",
Phase::CallTool => "call_tool",
}
}
}

/// MCP client implemented on top of the official `rmcp` SDK.
/// https://github.com/modelcontextprotocol/rust-sdk
pub struct RmcpClient {
Expand Down Expand Up @@ -105,16 +134,15 @@ impl RmcpClient {
}

pub fn new_streamable_http_client(url: String, bearer_token: Option<String>) -> Result<Self> {
let mut config = StreamableHttpClientTransportConfig::with_uri(url);
if let Some(token) = bearer_token {
config = config.auth_header(format!("Bearer {token}"));
}

let transport = StreamableHttpClientTransport::from_config(config);
let transport = build_streamable_http_transport(&url, bearer_token.as_deref());

Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(PendingTransport::StreamableHttp(transport)),
transport: Some(PendingTransport::StreamableHttp {
transport,
url,
bearer_token,
}),
}),
})
}
Expand All @@ -126,52 +154,66 @@ impl RmcpClient {
params: InitializeRequestParams,
timeout: Option<Duration>,
) -> Result<InitializeResult> {
let transport = {
let pending_transport = {
let mut guard = self.state.lock().await;
match &mut *guard {
ClientState::Connecting { transport } => transport
.take()
.ok_or_else(|| anyhow!("client already initializing"))?,
ClientState::Ready { .. } => {
return Err(anyhow!("client already initialized"));
}
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
}
};

let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
let client_handler = LoggingClientHandler::new(client_info);
let service_future = match transport {
let service = match pending_transport {
PendingTransport::ChildProcess(transport) => {
service::serve_client(client_handler.clone(), transport).boxed()
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
let client_handler = LoggingClientHandler::new(client_info);
let service_future = service::serve_client(client_handler.clone(), transport).boxed();
await_handshake(service_future, timeout)
.await
.map_err(|err| annotate_phase_error(Phase::Initialize, err))?
}
PendingTransport::StreamableHttp(transport) => {
service::serve_client(client_handler, transport).boxed()
PendingTransport::StreamableHttp {
mut transport,
url,
bearer_token,
} => {
let mut attempt = 0;
loop {
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
let client_handler = LoggingClientHandler::new(client_info);
let service_future = service::serve_client(client_handler.clone(), transport).boxed();
match await_handshake(service_future, timeout).await {
Ok(service) => break service,
Err(err) => {
let err = annotate_phase_error(Phase::Initialize, err);
if let Some(delay) = retry_delay_for_initialize(&err, attempt) {
attempt += 1;
time::sleep(delay).await;
transport = build_streamable_http_transport(&url, bearer_token.as_deref());
continue;
}
return Err(err);
}
}
}
}
};

let service = match timeout {
Some(duration) => match time::timeout(duration, service_future).await {
Ok(Ok(service)) => service,
Ok(Err(err)) => return Err(handshake_failed_error(err)),
Err(_) => return Err(handshake_timeout_error(duration)),
},
None => match service_future.await {
Ok(service) => service,
Err(err) => return Err(handshake_failed_error(err)),
},
};

let initialize_result_rmcp = service
.peer()
.peer_info()
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
.ok_or_else(|| annotate_phase_error(Phase::Initialize, anyhow!("handshake succeeded but server info was missing")))?;
let initialize_result: InitializeResult = convert_to_mcp(initialize_result_rmcp)?;

if initialize_result.protocol_version != MCP_SCHEMA_VERSION {
let reported_version = initialize_result.protocol_version.clone();
return Err(anyhow!(
"MCP server reported protocol version {reported_version}, but this client expects {}. Update either side so both speak the same schema.",
MCP_SCHEMA_VERSION
return Err(annotate_phase_error(
Phase::Initialize,
anyhow!(
"MCP server reported protocol version {reported_version}, but this client expects {}. Update either side so both speak the same schema.",
MCP_SCHEMA_VERSION
),
));
}

Expand All @@ -196,7 +238,9 @@ impl RmcpClient {
.transpose()?;

let fut = service.list_tools(rmcp_params);
let result = run_with_timeout(fut, timeout, "tools/list").await?;
let result = run_with_timeout(fut, timeout, "tools/list")
.await
.map_err(|err| annotate_phase_error(Phase::ListTools, err))?;
convert_to_mcp(result)
}

Expand All @@ -210,7 +254,9 @@ impl RmcpClient {
let params = CallToolRequestParams { arguments, name };
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
let fut = service.call_tool(rmcp_params);
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
let rmcp_result = run_with_timeout(fut, timeout, "tools/call")
.await
.map_err(|err| annotate_phase_error(Phase::CallTool, err))?;
convert_call_tool_result(rmcp_result)
}

Expand All @@ -229,6 +275,88 @@ impl RmcpClient {
}
}

async fn await_handshake<F, E>(
future: F,
timeout: Option<Duration>,
) -> Result<RunningService<RoleClient, LoggingClientHandler>>
where
F: Future<
Output = Result<
RunningService<RoleClient, LoggingClientHandler>,
E,
>,
>,
E: Into<anyhow::Error>,
{
if let Some(duration) = timeout {
match time::timeout(duration, future).await {
Ok(Ok(service)) => Ok(service),
Ok(Err(err)) => Err(handshake_failed_error(err)),
Err(_) => Err(handshake_timeout_error(duration)),
}
} else {
future.await.map_err(handshake_failed_error)
}
}

fn annotate_phase_error(phase: Phase, err: anyhow::Error) -> anyhow::Error {
err.context(format!("phase={}", phase.as_str()))
}

fn retry_delay_for_initialize(err: &anyhow::Error, attempt: usize) -> Option<Duration> {
if attempt >= INITIALIZE_MAX_RETRIES {
return None;
}

let retryable = err.chain().any(|source| {
if let Some(reqwest_err) = source.downcast_ref::<ReqwestError>() {
if reqwest_err.is_timeout() || reqwest_err.is_connect() {
return true;
}
}

if let Some(io_err) = source.downcast_ref::<io::Error>() {
if matches!(
io_err.kind(),
io::ErrorKind::TimedOut
| io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionReset
| io::ErrorKind::BrokenPipe
| io::ErrorKind::NotConnected
| io::ErrorKind::WouldBlock,
) {
return true;
}
}

source.downcast_ref::<HandshakeTimeoutError>().is_some()
});

if retryable {
Some(initialize_retry_delay(attempt))
} else {
None
}
}

fn initialize_retry_delay(attempt: usize) -> Duration {
let capped_attempt = attempt.min(4);
let multiplier = 1u64 << capped_attempt;
let delay = INITIALIZE_RETRY_BASE_DELAY_MS.saturating_mul(multiplier);
Duration::from_millis(delay.min(INITIALIZE_RETRY_MAX_DELAY_MS))
}

fn build_streamable_http_transport(
url: &str,
bearer_token: Option<&str>,
) -> StreamableHttpClientTransport<reqwest::Client> {
let mut config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
if let Some(token) = bearer_token {
config = config.auth_header(format!("Bearer {token}"));
}
StreamableHttpClientTransport::from_config(config)
}

fn handshake_failed_error(err: impl Into<anyhow::Error>) -> anyhow::Error {
let err = err.into();
anyhow!(
Comment on lines 360 to 362

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Propagate original handshake error so retries can detect transients

The new retry path in initialize relies on should_retry_initialize walking the error chain for ReqwestError or io::Error, but handshake_failed_error recreates the error with anyhow!(...) and drops the original source. After this wrapper runs the chain only contains the formatted string, so should_retry_initialize will never see the underlying transport error and the streamable HTTP initialize handshake will never be retried. Consider wrapping the incoming error instead of formatting it (e.g. err.context(...)) so its type survives downcasting.

Useful? React with 👍 / 👎.

Expand All @@ -237,14 +365,29 @@ fn handshake_failed_error(err: impl Into<anyhow::Error>) -> anyhow::Error {
}

fn handshake_timeout_error(duration: Duration) -> anyhow::Error {
anyhow!(
"timed out handshaking with MCP server after {duration:?} (expected MCP schema version {MCP_SCHEMA_VERSION})"
)
anyhow!(HandshakeTimeoutError(duration))
}

#[derive(Debug)]
struct HandshakeTimeoutError(Duration);

impl fmt::Display for HandshakeTimeoutError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"timed out awaiting MCP handshake after {:?}",
self.0
)
}
}

impl StdError for HandshakeTimeoutError {}

#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
use std::time::Duration;

#[test]
fn mcp_schema_version_is_well_formed() {
Expand All @@ -257,4 +400,52 @@ mod tests {
);
assert!(parts.iter().all(|segment| !segment.trim().is_empty()));
}

#[test]
fn annotate_phase_error_adds_phase_label() {
let err = annotate_phase_error(Phase::ListTools, anyhow!("boom"));
let message = err.to_string();
assert_eq!(message, "phase=list_tools");
let sources: Vec<String> = err.chain().map(|source| source.to_string()).collect();
assert!(sources.iter().any(|s| s.contains("boom")), "sources: {sources:?}");
}

#[test]
fn retry_delay_for_initialize_detects_transient_errors() {
let timeout_err = annotate_phase_error(
Phase::Initialize,
anyhow!(io::Error::new(io::ErrorKind::TimedOut, "timed out")),
);
assert_eq!(
retry_delay_for_initialize(&timeout_err, 0),
Some(Duration::from_millis(INITIALIZE_RETRY_BASE_DELAY_MS))
);
assert_eq!(retry_delay_for_initialize(&timeout_err, INITIALIZE_MAX_RETRIES), None);

let mismatch_err = annotate_phase_error(Phase::Initialize, anyhow!("protocol mismatch"));
assert_eq!(retry_delay_for_initialize(&mismatch_err, 0), None);
}

#[test]
fn retry_delay_handles_handshake_timeout() {
let err = annotate_phase_error(
Phase::Initialize,
handshake_timeout_error(Duration::from_secs(1)),
);
assert!(retry_delay_for_initialize(&err, 0).is_some());
}

#[test]
fn initialize_retry_delay_exponential_and_capped() {
let first = initialize_retry_delay(0);
let second = initialize_retry_delay(1);
let capped = initialize_retry_delay(10);

assert_eq!(first, Duration::from_millis(INITIALIZE_RETRY_BASE_DELAY_MS));
assert_eq!(second, Duration::from_millis(INITIALIZE_RETRY_BASE_DELAY_MS * 2));
assert_eq!(
capped,
Duration::from_millis(INITIALIZE_RETRY_MAX_DELAY_MS)
);
}
}
Loading