Skip to content

Commit

Permalink
Fixing read race condition during pubsub (redis#1737)
Browse files Browse the repository at this point in the history
  • Loading branch information
barshaul authored Dec 23, 2021
1 parent ddc51c4 commit d6cb997
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 15 deletions.
74 changes: 68 additions & 6 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,18 +1288,17 @@ def __init__(
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection = None
self.subscribed_event = threading.Event()
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
self.health_check_response = [
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
self.health_check_response = [b"pong", self.health_check_response_b]
self.reset()

def __enter__(self):
Expand All @@ -1324,9 +1323,11 @@ def reset(self):
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
self.health_check_response_counter = 0
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
self.subscribed_event.clear()

def close(self):
self.reset()
Expand All @@ -1352,7 +1353,7 @@ def on_connect(self, connection):
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
return bool(self.channels or self.patterns)
return self.subscribed_event.is_set()

def execute_command(self, *args):
"Execute a publish/subscribe command"
Expand All @@ -1370,8 +1371,28 @@ def execute_command(self, *args):
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)

def clean_health_check_responses(self):
"""
If any health check responses are present, clean them
"""
ttl = 10
conn = self.connection
while self.health_check_response_counter > 0 and ttl > 0:
if self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
response = self._execute(conn, conn.read_response)
if self.is_health_check_response(response):
self.health_check_response_counter -= 1
else:
raise PubSubError(
"A non health check response was cleaned by "
"execute_command: {0}".format(response)
)
ttl -= 1

def _disconnect_raise_connect(self, conn, error):
"""
Close the connection and raise an exception
Expand Down Expand Up @@ -1411,11 +1432,23 @@ def parse_response(self, block=True, timeout=0):
return None
response = self._execute(conn, conn.read_response)

if conn.health_check_interval and response == self.health_check_response:
if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
self.health_check_response_counter -= 1
return None
return response

def is_health_check_response(self, response):
"""
Check if the response is a health check response.
If there are no subscriptions redis responds to PING command with a
bulk response, instead of a multi-bulk with "pong" and the response.
"""
return response in [
self.health_check_response, # If there was a subscription
self.health_check_response_b, # If there wasn't
]

def check_health(self):
conn = self.connection
if conn is None:
Expand All @@ -1426,6 +1459,7 @@ def check_health(self):

if conn.health_check_interval and time.time() > conn.next_health_check:
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
self.health_check_response_counter += 1

def _normalize_keys(self, data):
"""
Expand Down Expand Up @@ -1455,6 +1489,11 @@ def psubscribe(self, *args, **kwargs):
# for the reconnection.
new_patterns = self._normalize_keys(new_patterns)
self.patterns.update(new_patterns)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
self.pending_unsubscribe_patterns.difference_update(new_patterns)
return ret_val

Expand Down Expand Up @@ -1489,6 +1528,11 @@ def subscribe(self, *args, **kwargs):
# for the reconnection.
new_channels = self._normalize_keys(new_channels)
self.channels.update(new_channels)
if not self.subscribed:
# Set the subscribed_event flag to True
self.subscribed_event.set()
# Clear the health check counter
self.health_check_response_counter = 0
self.pending_unsubscribe_channels.difference_update(new_channels)
return ret_val

Expand Down Expand Up @@ -1520,6 +1564,20 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
before returning. Timeout should be specified as a floating point
number.
"""
if not self.subscribed:
# Wait for subscription
start_time = time.time()
if self.subscribed_event.wait(timeout) is True:
# The connection was subscribed during the timeout time frame.
# The timeout should be adjusted based on the time spent
# waiting for the subscription
time_spent = time.time() - start_time
timeout = max(0.0, timeout - time_spent)
else:
# The connection isn't subscribed to any channels or patterns,
# so no messages are available
return None

response = self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
Expand Down Expand Up @@ -1575,6 +1633,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
if not self.channels and not self.patterns:
# There are no subscriptions anymore, set subscribed_event flag
# to false
self.subscribed_event.clear()

if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
Expand Down
43 changes: 34 additions & 9 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
import time
from unittest import mock
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -348,15 +349,6 @@ def test_unicode_pattern_message_handler(self, r):
"pmessage", channel, "test message", pattern=pattern
)

def test_get_message_without_subscribe(self, r):
p = r.pubsub()
with pytest.raises(RuntimeError) as info:
p.get_message()
expect = (
"connection not set: " "did you forget to call subscribe() or psubscribe()?"
)
assert expect in info.exconly()


class TestPubSubAutoDecoding:
"These tests only validate that we get unicode values back"
Expand Down Expand Up @@ -549,6 +541,39 @@ def test_get_message_with_timeout_returns_none(self, r):
assert wait_for_message(p) == make_message("subscribe", "foo", 1)
assert p.get_message(timeout=0.01) is None

def test_get_message_not_subscribed_return_none(self, r):
p = r.pubsub()
assert p.subscribed is False
assert p.get_message() is None
assert p.get_message(timeout=0.1) is None
with patch.object(threading.Event, "wait") as mock:
mock.return_value = False
assert p.get_message(timeout=0.01) is None
assert mock.called

def test_get_message_subscribe_during_waiting(self, r):
p = r.pubsub()

def poll(ps, expected_res):
assert ps.get_message() is None
message = ps.get_message(timeout=1)
assert message == expected_res

subscribe_response = make_message("subscribe", "foo", 1)
poller = threading.Thread(target=poll, args=(p, subscribe_response))
poller.start()
time.sleep(0.2)
p.subscribe("foo")
poller.join()

def test_get_message_wait_for_subscription_not_being_called(self, r):
p = r.pubsub()
p.subscribe("foo")
with patch.object(threading.Event, "wait") as mock:
assert p.subscribed is True
assert wait_for_message(p) == make_message("subscribe", "foo", 1)
assert mock.called is False


class TestPubSubWorkerThread:
@pytest.mark.skipif(
Expand Down

0 comments on commit d6cb997

Please sign in to comment.