Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate JSON-RPC errors through the Rust subscription #929

Merged
merged 11 commits into from
Jul 19, 2021
Next Next commit
Propagate JSON-RPC errors through the Rust subscription
  • Loading branch information
romac committed Jul 15, 2021
commit 99d6f30e6bb930390ced4142713c4850c7ce3a90
6 changes: 3 additions & 3 deletions rpc/src/client/transport/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl MockClientDriver {
DriverCommand::Unsubscribe { query, result_tx } => {
self.unsubscribe(query, result_tx);
}
DriverCommand::Publish(event) => self.publish(event.as_ref()),
DriverCommand::Publish(event) => self.publish(*event),
DriverCommand::Terminate => return Ok(()),
}
}
Expand All @@ -179,8 +179,8 @@ impl MockClientDriver {
result_tx.send(Ok(())).unwrap();
}

fn publish(&mut self, event: &Event) {
self.router.publish(event);
fn publish(&mut self, event: Event) {
self.router.publish_event(event);
}
}

Expand Down
57 changes: 41 additions & 16 deletions rpc/src/client/transport/router.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
//! Event routing for subscriptions.

use crate::client::subscription::SubscriptionTx;
use crate::event::Event;
use std::borrow::BorrowMut;
use std::collections::{HashMap, HashSet};

use tracing::debug;

use crate::client::subscription::SubscriptionTx;
use crate::error::Error;
use crate::event::Event;

/// Provides a mechanism for tracking [`Subscription`]s and routing [`Event`]s
/// to those subscriptions.
///
Expand All @@ -17,36 +19,54 @@ pub struct SubscriptionRouter {
// their result channels. Used for publishing events relating to a specific
// query.
subscriptions: HashMap<String, HashMap<String, SubscriptionTx>>,

// A map of subscription ids to their queries
sub_to_query: HashMap<String, String>,
}

impl SubscriptionRouter {
pub fn publish_error(&mut self, id: &String, err: Error) -> PublishResult {
if let Some(query) = self.sub_to_query.get(id) {
let query = query.clone();
self.publish(query, Err(err))
} else {
PublishResult::NoSubscribers
}
}

pub fn publish_event(&mut self, ev: Event) -> PublishResult {
self.publish(ev.query.clone(), Ok(ev))
}

/// Publishes the given event to all of the subscriptions to which the
/// event is relevant. At present, it matches purely based on the query
/// associated with the event, and only queries that exactly match that of
/// the event's.
pub fn publish(&mut self, ev: &Event) -> PublishResult {
let subs_for_query = match self.subscriptions.get_mut(&ev.query) {
/// event is relevant, based on the given query.
pub fn publish(&mut self, query: String, ev: Result<Event, Error>) -> PublishResult {
let subs_for_query = match self.subscriptions.get_mut(&query) {
Some(s) => s,
None => return PublishResult::NoSubscribers,
};

// We assume here that any failure to publish an event is an indication
// that the receiver end of the channel has been dropped, which allows
// us to safely stop tracking the subscription.
let mut disconnected = HashSet::new();
for (id, event_tx) in subs_for_query.borrow_mut() {
if let Err(e) = event_tx.send(Ok(ev.clone())) {
for (id, event_tx) in subs_for_query.iter_mut() {
if let Err(e) = event_tx.send(ev.clone()) {
disconnected.insert(id.clone());
debug!(
"Automatically disconnecting subscription with ID {} for query \"{}\" due to failure to publish to it: {}",
id, ev.query, e
id, query, e
);
}
}

for id in disconnected {
subs_for_query.remove(&id);
self.sub_to_query.remove(&id);
}

if subs_for_query.is_empty() {
PublishResult::AllDisconnected
PublishResult::AllDisconnected(query)
} else {
PublishResult::Success
}
Expand All @@ -63,7 +83,11 @@ impl SubscriptionRouter {
self.subscriptions.get_mut(&query).unwrap()
}
};
subs_for_query.insert(id.to_string(), tx);

let id = id.to_string();

subs_for_query.insert(id.clone(), tx);
self.sub_to_query.insert(id, query);
}

/// Removes all the subscriptions relating to the given query.
Expand All @@ -90,6 +114,7 @@ impl Default for SubscriptionRouter {
fn default() -> Self {
Self {
subscriptions: HashMap::new(),
sub_to_query: HashMap::new(),
}
}
}
Expand All @@ -98,7 +123,7 @@ impl Default for SubscriptionRouter {
pub enum PublishResult {
Success,
NoSubscribers,
AllDisconnected,
AllDisconnected(String),
}

#[cfg(test)]
Expand Down Expand Up @@ -160,7 +185,7 @@ mod test {

let mut ev = read_event("event_new_block_1").await;
ev.query = "query1".into();
router.publish(&ev);
router.publish_event(ev.clone());

let subs1_ev = must_recv(&mut subs1_event_rx, 500).await.unwrap();
let subs2_ev = must_recv(&mut subs2_event_rx, 500).await.unwrap();
Expand All @@ -169,7 +194,7 @@ mod test {
assert_eq!(ev, subs2_ev);

ev.query = "query2".into();
router.publish(&ev);
router.publish_event(ev.clone());

must_not_recv(&mut subs1_event_rx, 50).await;
must_not_recv(&mut subs2_event_rx, 50).await;
Expand Down
45 changes: 39 additions & 6 deletions rpc/src/client/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ struct GenericJsonResponse(serde_json::Value);

impl Response for GenericJsonResponse {}

type SubscriptionId = String;

/// Drives the WebSocket connection for a `WebSocketClient` instance.
///
/// This is the primary component responsible for transport-level interaction
Expand All @@ -504,7 +506,7 @@ pub struct WebSocketClientDriver {
cmd_rx: ChannelRx<DriverCommand>,
// Commands we've received but have not yet completed, indexed by their ID.
// A Terminate command is executed immediately.
pending_commands: HashMap<String, DriverCommand>,
pending_commands: HashMap<SubscriptionId, DriverCommand>,
}

impl WebSocketClientDriver {
Expand Down Expand Up @@ -650,38 +652,69 @@ impl WebSocketClientDriver {
return Ok(());
}

let wrapper = match serde_json::from_str::<response::Wrapper<GenericJsonResponse>>(&msg) {
let wrapper: response::Wrapper<GenericJsonResponse> = match serde_json::from_str(&msg) {
Ok(w) => w,
Err(e) => {
error!(
"Failed to deserialize incoming message as a JSON-RPC message: {}",
e
);

debug!("JSON-RPC message: {}", msg);

return Ok(());
}
};

debug!("Generic JSON-RPC message: {:?}", wrapper);

let id = wrapper.id().to_string();

if let Some(e) = wrapper.into_error() {
self.publish_error(&id, e).await;
}

if let Some(pending_cmd) = self.pending_commands.remove(&id) {
return self.respond_to_pending_command(pending_cmd, msg).await;
self.respond_to_pending_command(pending_cmd, msg).await?;
};

// We ignore incoming messages whose ID we don't recognize (could be
// relating to a fire-and-forget unsubscribe request - see the
// publish_event() method below).
Ok(())
}

async fn publish_error(&mut self, id: &String, err: Error) {
if let PublishResult::AllDisconnected(query) = self.router.publish_error(id, err) {
debug!(
"All subscribers for query \"{}\" have disconnected. Unsubscribing from query...",
query
);

// If all subscribers have disconnected for this query, we need to
// unsubscribe from it. We issue a fire-and-forget unsubscribe
// message.
if let Err(e) = self
.send_request(Wrapper::new(unsubscribe::Request::new(query)))
.await
{
error!("Failed to send unsubscribe request: {}", e);
}
}
}

async fn publish_event(&mut self, ev: Event) {
if let PublishResult::AllDisconnected = self.router.publish(&ev) {
if let PublishResult::AllDisconnected(query) = self.router.publish_event(ev) {
debug!(
"All subscribers for query \"{}\" have disconnected. Unsubscribing from query...",
ev.query
query
);

// If all subscribers have disconnected for this query, we need to
// unsubscribe from it. We issue a fire-and-forget unsubscribe
// message.
if let Err(e) = self
.send_request(Wrapper::new(unsubscribe::Request::new(ev.query.clone())))
.send_request(Wrapper::new(unsubscribe::Request::new(query)))
.await
{
error!("Failed to send unsubscribe request: {}", e);
Expand Down
6 changes: 6 additions & 0 deletions rpc/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ where
&self.id
}

/// Convert this wrapper into the underlying error, if any
#[allow(dead_code)]
pub fn into_error(self) -> Option<Error> {
self.error
}

/// Convert this wrapper into a result type
pub fn into_result(self) -> Result<R, Error> {
// Ensure we're using a supported RPC version
Expand Down