Skip to content

Commit

Permalink
Fix(Arq): fix integration with Worker settings as a dict (#3742)
Browse files Browse the repository at this point in the history
  • Loading branch information
saber-solooki authored Nov 6, 2024
1 parent 24e5359 commit c2dfbcc
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 14 deletions.
11 changes: 11 additions & 0 deletions sentry_sdk/integrations/arq.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ def _sentry_create_worker(*args, **kwargs):
# type: (*Any, **Any) -> Worker
settings_cls = args[0]

if isinstance(settings_cls, dict):
if "functions" in settings_cls:
settings_cls["functions"] = [
_get_arq_function(func) for func in settings_cls["functions"]
]
if "cron_jobs" in settings_cls:
settings_cls["cron_jobs"] = [
_get_arq_cron_job(cron_job)
for cron_job in settings_cls["cron_jobs"]
]

if hasattr(settings_cls, "functions"):
settings_cls.functions = [
_get_arq_function(func) for func in settings_cls.functions
Expand Down
113 changes: 99 additions & 14 deletions tests/integrations/arq/test_arq.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,65 @@ class WorkerSettings:
return inner


@pytest.fixture
def init_arq_with_dict_settings(sentry_init):
def inner(
cls_functions=None,
cls_cron_jobs=None,
kw_functions=None,
kw_cron_jobs=None,
allow_abort_jobs_=False,
):
cls_functions = cls_functions or []
cls_cron_jobs = cls_cron_jobs or []

kwargs = {}
if kw_functions is not None:
kwargs["functions"] = kw_functions
if kw_cron_jobs is not None:
kwargs["cron_jobs"] = kw_cron_jobs

sentry_init(
integrations=[ArqIntegration()],
traces_sample_rate=1.0,
send_default_pii=True,
)

server = FakeRedis()
pool = ArqRedis(pool_or_conn=server.connection_pool)

worker_settings = {
"functions": cls_functions,
"cron_jobs": cls_cron_jobs,
"redis_pool": pool,
"allow_abort_jobs": allow_abort_jobs_,
}

if not worker_settings["functions"]:
del worker_settings["functions"]
if not worker_settings["cron_jobs"]:
del worker_settings["cron_jobs"]

worker = arq.worker.create_worker(worker_settings, **kwargs)

return pool, worker

return inner


@pytest.mark.asyncio
async def test_job_result(init_arq):
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
async def test_job_result(init_arq_settings, request):
async def increase(ctx, num):
return num + 1

init_fixture_method = request.getfixturevalue(init_arq_settings)

increase.__qualname__ = increase.__name__

pool, worker = init_arq([increase])
pool, worker = init_fixture_method([increase])

job = await pool.enqueue_job("increase", 3)

Expand All @@ -105,14 +156,19 @@ async def increase(ctx, num):


@pytest.mark.asyncio
async def test_job_retry(capture_events, init_arq):
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
async def test_job_retry(capture_events, init_arq_settings, request):
async def retry_job(ctx):
if ctx["job_try"] < 2:
raise arq.worker.Retry

init_fixture_method = request.getfixturevalue(init_arq_settings)

retry_job.__qualname__ = retry_job.__name__

pool, worker = init_arq([retry_job])
pool, worker = init_fixture_method([retry_job])

job = await pool.enqueue_job("retry_job")

Expand All @@ -139,11 +195,18 @@ async def retry_job(ctx):
"source", [("cls_functions", "cls_cron_jobs"), ("kw_functions", "kw_cron_jobs")]
)
@pytest.mark.parametrize("job_fails", [True, False], ids=["error", "success"])
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
@pytest.mark.asyncio
async def test_job_transaction(capture_events, init_arq, source, job_fails):
async def test_job_transaction(
capture_events, init_arq_settings, source, job_fails, request
):
async def division(_, a, b=0):
return a / b

init_fixture_method = request.getfixturevalue(init_arq_settings)

division.__qualname__ = division.__name__

cron_func = async_partial(division, a=1, b=int(not job_fails))
Expand All @@ -152,7 +215,9 @@ async def division(_, a, b=0):
cron_job = cron(cron_func, minute=0, run_at_startup=True)

functions_key, cron_jobs_key = source
pool, worker = init_arq(**{functions_key: [division], cron_jobs_key: [cron_job]})
pool, worker = init_fixture_method(
**{functions_key: [division], cron_jobs_key: [cron_job]}
)

events = capture_events()

Expand Down Expand Up @@ -213,12 +278,17 @@ async def division(_, a, b=0):


@pytest.mark.parametrize("source", ["cls_functions", "kw_functions"])
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
@pytest.mark.asyncio
async def test_enqueue_job(capture_events, init_arq, source):
async def test_enqueue_job(capture_events, init_arq_settings, source, request):
async def dummy_job(_):
pass

pool, _ = init_arq(**{source: [dummy_job]})
init_fixture_method = request.getfixturevalue(init_arq_settings)

pool, _ = init_fixture_method(**{source: [dummy_job]})

events = capture_events()

Expand All @@ -236,13 +306,18 @@ async def dummy_job(_):


@pytest.mark.asyncio
async def test_execute_job_without_integration(init_arq):
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
async def test_execute_job_without_integration(init_arq_settings, request):
async def dummy_job(_ctx):
pass

init_fixture_method = request.getfixturevalue(init_arq_settings)

dummy_job.__qualname__ = dummy_job.__name__

pool, worker = init_arq([dummy_job])
pool, worker = init_fixture_method([dummy_job])
# remove the integration to trigger the edge case
get_client().integrations.pop("arq")

Expand All @@ -254,12 +329,17 @@ async def dummy_job(_ctx):


@pytest.mark.parametrize("source", ["cls_functions", "kw_functions"])
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
@pytest.mark.asyncio
async def test_span_origin_producer(capture_events, init_arq, source):
async def test_span_origin_producer(capture_events, init_arq_settings, source, request):
async def dummy_job(_):
pass

pool, _ = init_arq(**{source: [dummy_job]})
init_fixture_method = request.getfixturevalue(init_arq_settings)

pool, _ = init_fixture_method(**{source: [dummy_job]})

events = capture_events()

Expand All @@ -272,13 +352,18 @@ async def dummy_job(_):


@pytest.mark.asyncio
async def test_span_origin_consumer(capture_events, init_arq):
@pytest.mark.parametrize(
"init_arq_settings", ["init_arq", "init_arq_with_dict_settings"]
)
async def test_span_origin_consumer(capture_events, init_arq_settings, request):
async def job(ctx):
pass

init_fixture_method = request.getfixturevalue(init_arq_settings)

job.__qualname__ = job.__name__

pool, worker = init_arq([job])
pool, worker = init_fixture_method([job])

job = await pool.enqueue_job("retry_job")

Expand Down

0 comments on commit c2dfbcc

Please sign in to comment.