Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@


[workspace]
members = ["crates/rmcp", "crates/rmcp-macros", "examples/*"]
members = ["crates/rmcp", "crates/rmcp-macros"]
resolver = "2"

[workspace.dependencies]
Expand Down
1 change: 1 addition & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ documentation = "https://docs.rs/rmcp"
all-features = true

[dependencies]
async-stream = "0.3"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
thiserror = "2"
Expand Down
Binary file added crates/rmcp/src/.DS_Store
Binary file not shown.
40 changes: 35 additions & 5 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ use crate::{
},
transport::IntoTransport,
};

#[cfg(feature = "transport-sse-server")]
use axum::http::Extensions as AxumExtensions;

#[cfg(not(feature = "transport-sse-server"))]
#[derive(Debug, Clone, Default)]
pub struct AxumExtensions;

pub trait ProvidesAxiumExtensions {
fn get_extensions(&self) -> &AxumExtensions;
fn get_workspace_id(&self) -> String;
}

#[cfg(feature = "client")]
mod client;
#[cfg(feature = "client")]
Expand Down Expand Up @@ -109,7 +122,7 @@ pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
transport: T,
) -> impl Future<Output = Result<RunningService<R, Self>, E>> + Send
where
T: IntoTransport<R, E, A>,
T: IntoTransport<R, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Self: Sized,
{
Expand All @@ -121,7 +134,7 @@ pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
ct: CancellationToken,
) -> impl Future<Output = Result<RunningService<R, Self>, E>> + Send
where
T: IntoTransport<R, E, A>,
T: IntoTransport<R, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Self: Sized;
}
Expand Down Expand Up @@ -474,6 +487,8 @@ pub struct RequestContext<R: ServiceRole> {
pub extensions: Extensions,
/// An interface to fetch the remote client or server
pub peer: Peer<R>,
pub req_extensions: AxumExtensions,
pub workspace_id: String,
}

/// Use this function to skip initialization process
Expand All @@ -485,7 +500,7 @@ pub async fn serve_directly<R, S, T, E, A>(
where
R: ServiceRole,
S: Service<R>,
T: IntoTransport<R, E, A>,
T: IntoTransport<R, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + Send + Sync + 'static,
{
serve_directly_with_ct(service, transport, peer_info, Default::default()).await
Expand All @@ -501,11 +516,22 @@ pub async fn serve_directly_with_ct<R, S, T, E, A>(
where
R: ServiceRole,
S: Service<R>,
T: IntoTransport<R, E, A>,
T: IntoTransport<R, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + Send + Sync + 'static,
{
let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
serve_inner(service, transport, peer, peer_rx, ct).await
let req_extensions = transport.get_extensions().clone();
let workspace_id = transport.get_workspace_id();
serve_inner(
service,
transport,
peer,
peer_rx,
ct,
req_extensions,
workspace_id,
)
.await
}

#[instrument(skip_all)]
Expand All @@ -515,6 +541,8 @@ async fn serve_inner<R, S, T, E, A>(
peer: Peer<R>,
mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
ct: CancellationToken,
req_extensions: AxumExtensions,
workspace_id: String,
) -> Result<RunningService<R, S>, E>
where
R: ServiceRole,
Expand Down Expand Up @@ -669,6 +697,8 @@ where
peer: peer.clone(),
meta: request.get_meta().clone(),
extensions: request.extensions().clone(),
req_extensions: req_extensions.clone(),
workspace_id: workspace_id.clone(),
};
tokio::spawn(async move {
let result = service.handle_request(request, context).await;
Expand Down
20 changes: 16 additions & 4 deletions crates/rmcp/src/service/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
ct: CancellationToken,
) -> impl Future<Output = Result<RunningService<RoleClient, Self>, E>> + Send
where
T: IntoTransport<RoleClient, E, A>,
T: IntoTransport<RoleClient, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Self: Sized,
{
Expand All @@ -107,7 +107,7 @@ pub async fn serve_client<S, T, E, A>(
) -> Result<RunningService<RoleClient, S>, E>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
T: IntoTransport<RoleClient, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
{
serve_client_with_ct(service, transport, Default::default()).await
Expand All @@ -120,9 +120,11 @@ pub async fn serve_client_with_ct<S, T, E, A>(
) -> Result<RunningService<RoleClient, S>, E>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
T: IntoTransport<RoleClient, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
{
let req_extensions = transport.get_extensions().clone();
let workspace_id = transport.get_workspace_id();
let (sink, stream) = transport.into_transport();
let mut sink = Box::pin(sink);
let mut stream = Box::pin(stream);
Expand Down Expand Up @@ -175,7 +177,17 @@ where
);
sink.send(notification).await?;
let (peer, peer_rx) = Peer::new(id_provider, initialize_result);
serve_inner(service, (sink, stream), peer, peer_rx, ct).await

serve_inner(
service,
(sink, stream),
peer,
peer_rx,
ct,
req_extensions,
workspace_id,
)
.await
}

macro_rules! method {
Expand Down
21 changes: 17 additions & 4 deletions crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
ct: CancellationToken,
) -> impl Future<Output = Result<RunningService<RoleServer, Self>, E>> + Send
where
T: IntoTransport<RoleServer, E, A>,
T: IntoTransport<RoleServer, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
Self: Sized,
{
Expand All @@ -74,7 +74,7 @@ pub async fn serve_server<S, T, E, A>(
) -> Result<RunningService<RoleServer, S>, E>
where
S: Service<RoleServer>,
T: IntoTransport<RoleServer, E, A>,
T: IntoTransport<RoleServer, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
{
serve_server_with_ct(service, transport, CancellationToken::new()).await
Expand Down Expand Up @@ -129,9 +129,11 @@ pub async fn serve_server_with_ct<S, T, E, A>(
) -> Result<RunningService<RoleServer, S>, E>
where
S: Service<RoleServer>,
T: IntoTransport<RoleServer, E, A>,
T: IntoTransport<RoleServer, E, A> + ProvidesAxiumExtensions,
E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
{
let req_extensions = transport.get_extensions().clone();
let workspace_id = transport.get_workspace_id();
let (sink, stream) = transport.into_transport();
let mut sink = Box::pin(sink);
let mut stream = Box::pin(stream);
Expand Down Expand Up @@ -162,6 +164,8 @@ where
meta: request.get_meta().clone(),
extensions: request.extensions().clone(),
peer: peer.clone(),
req_extensions: req_extensions.clone(),
workspace_id: workspace_id.clone(),
};
// Send initialize response
let init_response = service.handle_request(request.clone(), context).await;
Expand Down Expand Up @@ -207,7 +211,16 @@ where
};
let _ = service.handle_notification(notification).await;
// Continue processing service
serve_inner(service, (sink, stream), peer, peer_rx, ct).await
serve_inner(
service,
(sink, stream),
peer,
peer_rx,
ct,
req_extensions,
workspace_id,
)
.await
}

macro_rules! method {
Expand Down
Loading