Fix rmq channel leak #189

Merged · 2 commits · Jun 28, 2024
59 changes: 29 additions & 30 deletions servicelayer/taskqueue.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
import json
import time
import threading
@@ -17,11 +17,13 @@
from random import randrange

import pika.spec
+from pika.adapters.blocking_connection import BlockingChannel
from prometheus_client import start_http_server

from structlog.contextvars import clear_contextvars, bind_contextvars
import pika
from banal import ensure_list
+from redis import Redis

from servicelayer.cache import get_redis, make_key
from servicelayer.util import pack_now, unpack_int
@@ -434,7 +436,7 @@ class Worker(ABC):
def __init__(
self,
queues,
-conn=None,
+conn: Redis = None,
num_threads=settings.WORKER_THREADS,
version=None,
prefetch_count_mapping=defaultdict(lambda: 1),
@@ -491,19 +493,18 @@ def on_message(self, channel, method, properties, body, args):
We have to make sure it doesn't block for long to ensure that RabbitMQ
heartbeats are not interrupted.
"""
-connection = args[0]
task = get_task(body, method.delivery_tag)
# the task needs to be acknowledged in the same channel that it was
# received. So store the channel. This is useful when executing batched
# indexing tasks since they are acknowledged late.
task._channel = channel
-self.local_queue.put((task, channel, connection))
+self.local_queue.put((task, channel))

def process_blocking(self):
"""Blocking worker thread - executes tasks from a queue and periodic tasks"""
while True:
try:
-(task, channel, connection) = self.local_queue.get(timeout=TIMEOUT)
+(task, channel) = self.local_queue.get(timeout=TIMEOUT)
apply_task_context(task, v=self.version)
success, retry = self.handle(task, channel)
log.debug(
@@ -514,7 +515,7 @@ def process_blocking(self):
cb = functools.partial(self.ack_message, task, channel)
else:
cb = functools.partial(self.nack_message, task, channel, retry)
-connection.add_callback_threadsafe(cb)
+channel.connection.add_callback_threadsafe(cb)
except Empty:
pass
finally:
@@ -523,8 +524,7 @@

def process_nonblocking(self):
"""Non-blocking worker is used for tests only."""
-connection = get_rabbitmq_connection()
-channel = connection.channel()
+channel = get_rabbitmq_channel()
queue_active = {queue: True for queue in self.queues}
while True:
for queue in self.queues:
@@ -632,7 +632,7 @@ def periodic(self):
"""Periodic tasks to run."""
pass

-def ack_message(self, task, channel):
+def ack_message(self, task, channel, multiple=False):
"""Acknowledge a task after execution.

RabbitMQ requires that the channel used for receiving the message must be used
@@ -653,7 +653,7 @@ def ack_message(self, task, channel):
# Sync state to redis
dataset.mark_done(task)
if channel.is_open:
-channel.basic_ack(task.delivery_tag)
+channel.basic_ack(task.delivery_tag, multiple=multiple)
clear_contextvars()

def nack_message(self, task, channel, requeue=True):
@@ -697,9 +697,8 @@ def process():

log.info(f"Worker has {self.num_threads} worker threads.")

-connection = get_rabbitmq_connection()
-channel = connection.channel()
-on_message_callback = functools.partial(self.on_message, args=(connection,))
+channel = get_rabbitmq_channel()
+on_message_callback = functools.partial(self.on_message, args=(channel,))

for queue in self.queues:
declare_rabbitmq_queue(
@@ -709,13 +708,14 @@
channel.start_consuming()


-def get_rabbitmq_connection():
+def get_rabbitmq_channel() -> BlockingChannel:
for attempt in service_retries():
try:
if (
not hasattr(local, "connection")
or not local.connection
or not local.connection.is_open
+or not local.channel
or attempt > 0
):
log.debug(
@@ -735,16 +735,17 @@ def get_rabbitmq_connection():
)
)
local.connection = connection
+local.channel = connection.channel()

# Check that the connection is alive
-result = local.connection.channel().exchange_declare(
+result = local.channel.exchange_declare(
exchange="amq.topic",
exchange_type=pika.exchange_type.ExchangeType.topic,
passive=True,
)
assert isinstance(result.method, pika.spec.Exchange.DeclareOk)

-return local.connection
+return local.channel

except (
pika.exceptions.AMQPConnectionError,
@@ -764,6 +765,7 @@ def get_rabbitmq_connection():
f"Attempt: {attempt}/{service_retries().stop}"
)
local.connection = None
+local.channel = None

backoff(failures=attempt)
raise RuntimeError("Could not connect to RabbitMQ")
@@ -804,7 +806,13 @@ def dataset_from_collection(collection):


def queue_task(
-rmq_conn, redis_conn, collection_id, stage, job_id=None, context=None, **payload
+rmq_channel: BlockingChannel,
+redis_conn,
+collection_id: int,
+stage: str,
+job_id=None,
+context=None,
+**payload,
):
task_id = uuid.uuid4().hex
priority = get_priority(collection_id, redis_conn)
@@ -818,9 +826,8 @@
"priority": priority,
}
try:
-channel = rmq_conn.channel()
-channel.confirm_delivery()
-channel.basic_publish(
+rmq_channel.confirm_delivery()
+rmq_channel.basic_publish(
exchange="",
routing_key=stage,
body=json.dumps(body),
@@ -832,23 +839,15 @@
dataset.add_task(task_id, stage)
except (pika.exceptions.UnroutableError, pika.exceptions.AMQPConnectionError):
log.exception("Error while queuing task")
-finally:
-try:
-if channel:
-channel.close()
-except pika.exceptions.ChannelWrongStateError:
-log.exception("Failed to explicitly close RabbitMQ channel.")


-def flush_queues(rmq_conn, redis_conn, queues):
+def flush_queues(rmq_channel: BlockingChannel, redis_conn: Redis, queues: List[str]):
try:
-channel = rmq_conn.channel()
for queue in queues:
try:
-channel.queue_purge(queue)
+rmq_channel.queue_purge(queue)
except ValueError:
logging.exception(f"Error while flushing the {queue} queue")
-channel.close()
except pika.exceptions.AMQPError:
logging.exception("Error while flushing task queue")
for key in redis_conn.scan_iter(PREFIX + "*"):
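
The core of the fix above is that callers now share one long-lived channel per thread instead of opening (and leaking) a fresh channel on every publish. As a rough standalone sketch of that pattern — assuming a local broker on default settings, with `get_channel` and `publish` as illustrative names rather than this library's API — it could look like this:

import threading

import pika
from pika.adapters.blocking_connection import BlockingChannel

_local = threading.local()


def get_channel() -> BlockingChannel:
    """Return a per-thread cached channel, (re)opening connection and channel if needed."""
    conn = getattr(_local, "connection", None)
    chan = getattr(_local, "channel", None)
    if conn is None or not conn.is_open or chan is None or not chan.is_open:
        conn = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
        chan = conn.channel()
        _local.connection = conn
        _local.channel = chan
    return chan


def publish(queue: str, body: bytes) -> None:
    """Publish on the shared channel; no per-call channel open/close to leak."""
    chan = get_channel()
    chan.queue_declare(queue=queue, durable=True)
    chan.basic_publish(exchange="", routing_key=queue, body=body)

The liveness checks play the same role as the reconnect logic in `get_rabbitmq_channel` in the diff above.
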
24 changes: 11 additions & 13 deletions tests/test_taskqueue.py
@@ -12,7 +12,7 @@
Worker,
Dataset,
Task,
-get_rabbitmq_connection,
+get_rabbitmq_channel,
dataset_from_collection_id,
declare_rabbitmq_queue,
flush_queues,
@@ -49,8 +49,7 @@ def test_task_queue(self):
"payload": {},
"priority": priority,
}
-connection = get_rabbitmq_connection()
-channel = connection.channel()
+channel = get_rabbitmq_channel()
declare_rabbitmq_queue(channel, test_queue_name)
channel.queue_purge(test_queue_name)
channel.basic_publish(
@@ -84,7 +83,7 @@ def test_task_queue(self):
assert task.get_retry_count(conn) == 1

with patch("servicelayer.settings.WORKER_RETRY", 0):
-channel = connection.channel()
+channel = get_rabbitmq_channel()
channel.queue_purge(test_queue_name)
channel.basic_publish(
properties=pika.BasicProperties(priority=priority),
@@ -138,8 +137,7 @@ def test_task_that_shouldnt_execute(self, mock_should_execute):
"collection_id": 2,
}

-connection = get_rabbitmq_connection()
-channel = connection.channel()
+channel = get_rabbitmq_channel()
declare_rabbitmq_queue(channel, test_queue_name)
channel.queue_purge(test_queue_name)
channel.basic_publish(
@@ -171,16 +169,14 @@ def did_nack():
return_value=None,
) as dispatch_fn:
with patch.object(
-pika.channel.Channel,
+channel,
attribute="basic_nack",
return_value=None,
) as nack_fn:
worker.process(blocking=False)
-nack_fn.assert_any_call(delivery_tag=1, multiple=False, requeue=True)
+nack_fn.assert_called_once()
dispatch_fn.assert_not_called()

-channel.close()

status = dataset.get_active_dataset_status(conn=conn)
stage = status["datasets"]["2"]["stages"][0]
assert stage["pending"] == 1
@@ -190,14 +186,16 @@

def test_get_priority_bucket():
redis = get_fakeredis()
-rmq = get_rabbitmq_connection()
-flush_queues(rmq, redis, ["index"])
+rmq_channel = get_rabbitmq_channel()
+rmq_channel.queue_delete("index")
+declare_rabbitmq_queue(rmq_channel, "index")
+flush_queues(rmq_channel, redis, ["index"])
collection_id = 1

assert get_task_count(collection_id, redis) == 0
assert get_priority(collection_id, redis) in (7, 8)

-queue_task(rmq, redis, collection_id, "index")
+queue_task(rmq_channel, redis, collection_id, "index")

assert get_task_count(collection_id, redis) == 1
assert get_priority(collection_id, redis) in (7, 8)
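
On the consuming side, the change keeps the rule that a delivery is acknowledged on the channel it arrived on, with the ack scheduled back onto the connection's I/O thread via `add_callback_threadsafe` (now reached through `channel.connection`). A minimal standalone sketch of that hand-off — the queue name "tasks" and the `handle` helper are assumptions for illustration, not this project's code — could look like:

import functools
import threading

import pika


def on_message(channel, method, properties, body):
    # Keep the consumer callback fast so RabbitMQ heartbeats are not starved:
    # hand the work to a thread and return immediately.
    threading.Thread(
        target=handle, args=(channel, method.delivery_tag, body), daemon=True
    ).start()


def handle(channel, delivery_tag, body):
    # ... run the actual task here ...
    # Channels are not thread-safe, so schedule the ack on the connection's
    # I/O loop instead of calling basic_ack from this worker thread.
    ack = functools.partial(channel.basic_ack, delivery_tag)
    channel.connection.add_callback_threadsafe(ack)


connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()
channel.queue_declare(queue="tasks", durable=True)
channel.basic_consume(queue="tasks", on_message_callback=on_message)
channel.start_consuming()
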