Skip to content

Commit 50d70f0

Browse files
authored
fix(redis): Prevent counter corruption from concurrent mark handled in Redis RQ (#1878)
### Description - Fixes counter corruption on concurrent `mark_request_as_handled` calls for the same request. When two coroutines concurrently called `mark_request_as_handled` for the same request, both could pass the `hexists` check before either executed the pipeline, causing counters to be updated multiple times for the same request, breaking the queue operation. ### Issues - Closes: #1873 ### Testing - Added a test to verify that counters are updated correctly during concurrent execution. - Added a test to verify that the request is correctly restored to `in_progress` after failure.
1 parent 095d6cb commit 50d70f0

2 files changed

Lines changed: 98 additions & 25 deletions

File tree

src/crawlee/storage_clients/_redis/_request_queue_client.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -396,41 +396,56 @@ async def get_request(self, unique_key: str) -> Request | None:
396396
@retry_on_error(RedisError)
397397
@override
398398
async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None:
399-
# Check if the request is in progress.
400-
check_in_progress = await await_redis_response(self._redis.hexists(self._in_progress_key, request.unique_key))
401-
if not check_in_progress:
399+
# `hdel` is an atomic operation, so we can be sure that if it returns 1, the request was in progress and is now
400+
# removed from `in_progress`.
401+
deleted = await await_redis_response(self._redis.hdel(self._in_progress_key, request.unique_key))
402+
403+
# If not deleted, the request was not in progress.
404+
if not deleted:
402405
logger.warning(f'Marking request {request.unique_key} as handled that is not in progress.')
403406
return None
404407

405408
# Update the request's handled_at timestamp.
406409
if request.handled_at is None:
407410
request.handled_at = datetime.now(timezone.utc)
408411

409-
async with self._get_pipeline() as pipe:
410-
if self._dedup_strategy == 'default':
411-
await await_redis_response(pipe.sadd(self._handled_set_key, request.unique_key))
412-
await await_redis_response(pipe.srem(self._pending_set_key, request.unique_key))
413-
elif self._dedup_strategy == 'bloom':
414-
await await_redis_response(pipe.bf().add(self._handled_filter_key, request.unique_key))
412+
try:
413+
async with self._get_pipeline() as pipe:
414+
if self._dedup_strategy == 'default':
415+
await await_redis_response(pipe.sadd(self._handled_set_key, request.unique_key))
416+
await await_redis_response(pipe.srem(self._pending_set_key, request.unique_key))
417+
elif self._dedup_strategy == 'bloom':
418+
await await_redis_response(pipe.bf().add(self._handled_filter_key, request.unique_key))
415419

416-
await await_redis_response(pipe.hdel(self._in_progress_key, request.unique_key))
417-
await await_redis_response(pipe.hset(self._data_key, request.unique_key, request.model_dump_json()))
420+
await await_redis_response(pipe.hset(self._data_key, request.unique_key, request.model_dump_json()))
418421

419-
await self._update_metadata(
420-
pipe,
421-
**_QueueMetadataUpdateParams(
422-
update_accessed_at=True,
423-
update_modified_at=True,
424-
delta_handled_request_count=1,
425-
delta_pending_request_count=-1,
426-
),
427-
)
422+
await self._update_metadata(
423+
pipe,
424+
**_QueueMetadataUpdateParams(
425+
update_accessed_at=True,
426+
update_modified_at=True,
427+
delta_handled_request_count=1,
428+
delta_pending_request_count=-1,
429+
),
430+
)
428431

429-
return ProcessedRequest(
430-
unique_key=request.unique_key,
431-
was_already_present=True,
432-
was_already_handled=True,
433-
)
432+
return ProcessedRequest(
433+
unique_key=request.unique_key,
434+
was_already_present=True,
435+
was_already_handled=True,
436+
)
437+
except Exception:
438+
blocked_until = int(datetime.now(tz=timezone.utc).timestamp() * 1000) + self._BLOCK_REQUEST_TIME
439+
# If we fail to mark the request as handled after removing it from in_progress, we restore request in
440+
# `in_progress` hash.
441+
await await_redis_response(
442+
self._redis.hset(
443+
self._in_progress_key,
444+
request.unique_key,
445+
json.dumps({'client_id': self.client_key, 'blocked_until_timestamp': blocked_until}),
446+
)
447+
)
448+
raise
434449

435450
@retry_on_error(RedisError)
436451
@override

tests/unit/storage_clients/_redis/test_redis_rq_client.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,61 @@ async def test_get_metadata_does_not_retry_on_unexpected_exception(rq_client: Re
292292

293293
# Verify that retry logic was not attempted
294294
assert mock_sleep.call_count == 0
295+
296+
297+
async def test_mark_request_as_handled_concurrent_no_double_decrement(
298+
rq_client: RedisRequestQueueClient,
299+
) -> None:
300+
"""Test that concurrent calls to mark_request_as_handled decrement pending_request_count exactly once."""
301+
request = Request.from_url('https://example.com/concurrent')
302+
await rq_client.add_batch_of_requests([request])
303+
304+
fetched = await rq_client.fetch_next_request()
305+
assert fetched is not None
306+
307+
results = await asyncio.gather(
308+
rq_client.mark_request_as_handled(fetched),
309+
rq_client.mark_request_as_handled(fetched),
310+
rq_client.mark_request_as_handled(fetched),
311+
rq_client.mark_request_as_handled(fetched),
312+
rq_client.mark_request_as_handled(fetched),
313+
)
314+
315+
successful = [result for result in results if result is not None]
316+
assert len(successful) == 1
317+
318+
metadata = await rq_client.get_metadata()
319+
assert metadata.pending_request_count == 0
320+
assert metadata.handled_request_count == 1
321+
322+
323+
async def test_mark_request_as_handled_restores_in_progress_on_pipeline_failure(
324+
rq_client: RedisRequestQueueClient,
325+
) -> None:
326+
"""Test that 'request' is restored to 'in_progress' when the pipeline fails after 'hdel'."""
327+
request = Request.from_url('https://example.com/restore')
328+
await rq_client.add_batch_of_requests([request])
329+
330+
fetched = await rq_client.fetch_next_request()
331+
assert fetched is not None
332+
333+
mock_pipe = MagicMock()
334+
mock_pipe.execute = AsyncMock(side_effect=RedisError('connection lost'))
335+
336+
mock_pipeline_ctx = MagicMock()
337+
mock_pipeline_ctx.__aenter__ = AsyncMock(return_value=mock_pipe)
338+
mock_pipeline_ctx.__aexit__ = AsyncMock(return_value=None)
339+
340+
with (
341+
patch('crawlee._utils.retry._retry_sleep', new_callable=AsyncMock),
342+
patch.object(rq_client.redis, 'pipeline', return_value=mock_pipeline_ctx),
343+
pytest.raises(RedisError),
344+
):
345+
await rq_client.mark_request_as_handled(fetched)
346+
347+
in_progress = await await_redis_response(rq_client.redis.hexists(rq_client._in_progress_key, fetched.unique_key))
348+
assert in_progress
349+
350+
metadata = await rq_client.get_metadata()
351+
assert metadata.pending_request_count == 1
352+
assert metadata.handled_request_count == 0

0 commit comments

Comments
 (0)