Skip to content

Commit

Permalink
fix does not allow SUBSCRIBE to be sent to server if already subscribed
Browse files Browse the repository at this point in the history
  • Loading branch information
mcatanzariti committed May 25, 2024
1 parent 527e6ec commit 37336c4
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 111 deletions.
246 changes: 135 additions & 111 deletions src/network/network_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl NetworkHandler {
loop {
select! {
msg = self.msg_receiver.next().fuse() => {
if !self.handle_message(msg).await { break; }
if !self.try_handle_message(msg).await { break; }
} ,
result = self.connection.read().fuse() => {
if !self.handle_result(result).await { break; }
Expand All @@ -162,118 +162,12 @@ impl NetworkHandler {
Ok(())
}

async fn handle_message(&mut self, mut msg: Option<Message>) -> bool {
async fn try_handle_message(&mut self, mut msg: Option<Message>) -> bool {
let is_channel_closed: bool;

loop {
if let Some(mut msg) = msg {
trace!(
"[{}][{:?}] Will handle message: {msg:?}",
self.tag,
self.status
);
let pub_sub_senders = msg.pub_sub_senders.take();
if let Some(pub_sub_senders) = pub_sub_senders {
let subscription_type = match &msg.commands {
Commands::Single(command, _) => match command.name {
"SUBSCRIBE" => SubscriptionType::Channel,
"PSUBSCRIBE" => SubscriptionType::Pattern,
"SSUBSCRIBE" => SubscriptionType::ShardChannel,
_ => unreachable!(),
},
_ => unreachable!(),
};

let num_pending_subscriptions = pub_sub_senders.len();
let pending_subscriptions = pub_sub_senders.into_iter().enumerate().map(
|(index, (channel_or_pattern, sender))| PendingSubscription {
channel_or_pattern,
subscription_type,
sender,
more_to_come: index < num_pending_subscriptions - 1,
},
);

self.pending_subscriptions.extend(pending_subscriptions);
}

let push_sender = msg.push_sender.take();
if let Some(push_sender) = push_sender {
debug!("[{}] Registering push_sender", self.tag);
self.push_sender = Some(push_sender);
}

match &self.status {
Status::Connected => {
for command in &msg.commands {
match command.name {
"SUBSCRIBE" | "PSUBSCRIBE" | "SSUBSCRIBE" => {
self.status = Status::Subscribing;
}
"MONITOR" => {
self.status = Status::EnteringMonitor;
}
_ => (),
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::Subscribing => {
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::Subscribed => {
for command in &msg.commands {
let subscription_type = match command.name {
"UNSUBSCRIBE" => Some(SubscriptionType::Channel),
"PUNSUBSCRIBE" => Some(SubscriptionType::Pattern),
"SUNSUBSCRIBE" => Some(SubscriptionType::ShardChannel),
_ => None,
};
if let Some(subscription_type) = subscription_type {
self.pending_unsubscriptions.push_back(
command
.args
.into_iter()
.map(|a| (a.to_vec(), subscription_type))
.collect(),
);
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::Disconnected => {
if msg.retry_on_error {
debug!(
"[{}] network disconnected, queuing command: {:?}",
self.tag, msg.commands
);
self.messages_to_send.push_back(MessageToSend::new(msg));
} else {
debug!(
"[{}] network disconnected, ending command in error: {:?}",
self.tag, msg.commands
);
msg.commands.send_error(
&self.tag,
Error::Client("Disconnected from server".to_string()),
);
}
}
Status::EnteringMonitor => {
self.messages_to_send.push_back(MessageToSend::new(msg))
}
Status::Monitor => {
for command in &msg.commands {
if command.name == "RESET" {
self.status = Status::LeavingMonitor;
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::LeavingMonitor => {
self.messages_to_send.push_back(MessageToSend::new(msg));
}
}
if let Some(msg) = msg {
self.handle_message(msg).await;
} else {
is_channel_closed = true;
break;
Expand All @@ -296,6 +190,136 @@ impl NetworkHandler {
!is_channel_closed
}

async fn handle_message(&mut self, mut msg: Message) {
trace!(
"[{}][{:?}] Will handle message: {msg:?}",
self.tag,
self.status
);
let pub_sub_senders = msg.pub_sub_senders.take();
if let Some(pub_sub_senders) = pub_sub_senders {
let subscription_type = match &msg.commands {
Commands::Single(command, _) => match command.name {
"SUBSCRIBE" => SubscriptionType::Channel,
"PSUBSCRIBE" => SubscriptionType::Pattern,
"SSUBSCRIBE" => SubscriptionType::ShardChannel,
_ => unreachable!(),
},
_ => unreachable!(),
};

for (channel_or_pattern, _sender) in pub_sub_senders.iter() {
if self.subscriptions.contains_key(channel_or_pattern) {
debug!(
"[{}][{:?}] There is already a subscription on channel `{}`",
self.tag,
self.status,
String::from_utf8_lossy(channel_or_pattern)
);
msg.commands.send_error(
&self.tag,
Error::Client(
format!(
"There is already a subscription on channel `{}`",
String::from_utf8_lossy(channel_or_pattern)
)
.to_string(),
),
);
return;
}
}

let num_pending_subscriptions = pub_sub_senders.len();
let pending_subscriptions = pub_sub_senders.into_iter().enumerate().map(
|(index, (channel_or_pattern, sender))| PendingSubscription {
channel_or_pattern,
subscription_type,
sender,
more_to_come: index < num_pending_subscriptions - 1,
},
);

self.pending_subscriptions.extend(pending_subscriptions);
}

let push_sender = msg.push_sender.take();
if let Some(push_sender) = push_sender {
debug!("[{}] Registering push_sender", self.tag);
self.push_sender = Some(push_sender);
}

match &self.status {
Status::Connected => {
for command in &msg.commands {
match command.name {
"SUBSCRIBE" | "PSUBSCRIBE" | "SSUBSCRIBE" => {
self.status = Status::Subscribing;
}
"MONITOR" => {
self.status = Status::EnteringMonitor;
}
_ => (),
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::Subscribing => {
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::Subscribed => {
for command in &msg.commands {
let subscription_type = match command.name {
"UNSUBSCRIBE" => Some(SubscriptionType::Channel),
"PUNSUBSCRIBE" => Some(SubscriptionType::Pattern),
"SUNSUBSCRIBE" => Some(SubscriptionType::ShardChannel),
_ => None,
};
if let Some(subscription_type) = subscription_type {
self.pending_unsubscriptions.push_back(
command
.args
.into_iter()
.map(|a| (a.to_vec(), subscription_type))
.collect(),
);
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::Disconnected => {
if msg.retry_on_error {
debug!(
"[{}] network disconnected, queuing command: {:?}",
self.tag, msg.commands
);
self.messages_to_send.push_back(MessageToSend::new(msg));
} else {
debug!(
"[{}] network disconnected, ending command in error: {:?}",
self.tag, msg.commands
);
msg.commands.send_error(
&self.tag,
Error::Client("Disconnected from server".to_string()),
);
}
}
Status::EnteringMonitor => self.messages_to_send.push_back(MessageToSend::new(msg)),
Status::Monitor => {
for command in &msg.commands {
if command.name == "RESET" {
self.status = Status::LeavingMonitor;
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::LeavingMonitor => {
self.messages_to_send.push_back(MessageToSend::new(msg));
}
}
}

async fn send_messages(&mut self) {
if log_enabled!(Level::Debug) {
let num_commands = self
Expand Down Expand Up @@ -755,7 +779,7 @@ impl NetworkHandler {
let delay = end.duration_since(Instant::now());
let result = timeout(delay, self.msg_receiver.next().fuse()).await;
if let Ok(msg) = result {
if !self.handle_message(msg).await {
if !self.try_handle_message(msg).await {
return false;
}
} else {
Expand Down
1 change: 1 addition & 0 deletions src/tests/pub_sub_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ async fn subscribe_multiple_times_to_the_same_channel() -> Result<()> {
let mut pub_sub_stream = pub_sub_client.subscribe("mychannel").await?;
assert!(pub_sub_stream.subscribe("mychannel").await.is_err());
assert!(pub_sub_client.subscribe("mychannel").await.is_err());
regular_client.publish("mychannel", "mymessage").await?;

pub_sub_stream.psubscribe("pattern").await?;
assert!(pub_sub_stream.psubscribe("pattern").await.is_err());
Expand Down

0 comments on commit 37336c4

Please sign in to comment.