diff --git a/server/asyncio_extensions.py b/server/asyncio_extensions.py index 4a44d214e..b989a1a04 100644 --- a/server/asyncio_extensions.py +++ b/server/asyncio_extensions.py @@ -89,11 +89,14 @@ def _synchronize( lock: Optional[asyncio.Lock] = None ) -> AsyncFunc: """Wrap an async function with an async lock.""" - if lock is None: - lock = asyncio.Lock() - @wraps(function) async def wrapped(*args, **kwargs): + nonlocal lock + + # During testing, functions are called from multiple loops + if lock is None or lock._loop != asyncio.get_event_loop(): + lock = asyncio.Lock() + async with lock: return await function(*args, **kwargs) diff --git a/tests/unit_tests/test_matchmaker_queue.py b/tests/unit_tests/test_matchmaker_queue.py index 6c0d77f37..4f1f7d282 100644 --- a/tests/unit_tests/test_matchmaker_queue.py +++ b/tests/unit_tests/test_matchmaker_queue.py @@ -1,7 +1,9 @@ import asyncio import functools +import time from concurrent.futures import CancelledError, TimeoutError +import mock import pytest from hypothesis import given from hypothesis import strategies as st @@ -406,3 +408,33 @@ async def find_matches(): matchmaker_queue.on_match_found.assert_called_once_with( s2, s3, matchmaker_queue ) + + +@pytest.mark.asyncio +async def test_find_matches_synchronized(queue_factory): + is_matching = False + + def make_matches(*args): + nonlocal is_matching + + assert not is_matching, "Function call not synchronized" + is_matching = True + + time.sleep(0.2) + + is_matching = False + return [] + + with mock.patch( + "server.matchmaker.matchmaker_queue.make_matches", + make_matches + ): + queues = [queue_factory(f"Queue{i}") for i in range(5)] + # Ensure that find_matches does not short circuit + for queue in queues: + queue._queue = {mock.Mock(): 1, mock.Mock(): 2} + queue.find_teams = mock.Mock() + + await asyncio.gather(*[ + queue.find_matches() for queue in queues + ])