Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List, Optional

from django.core.management.base import BaseCommand
from django.db import connections
from django.db.utils import OperationalError

from django_tasks import DEFAULT_TASK_BACKEND_ALIAS, tasks
Expand Down Expand Up @@ -97,6 +98,9 @@ def start(self) -> None:
finally:
self.running_task = False

for conn in connections.all(initialized_only=True):
conn.close()

if self.batch and task_result is None:
# If we're running in "batch" mode, terminate the loop (and thus the worker)
return None
Expand Down
20 changes: 10 additions & 10 deletions tests/tests/test_database_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def test_run_enqueued_task(self) -> None:

self.assertEqual(result.status, ResultStatus.NEW)

with self.assertNumQueries(8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()

self.assertEqual(result.status, ResultStatus.NEW)
Expand All @@ -287,7 +287,7 @@ def test_batch_processes_all_tasks(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 4)

with self.assertNumQueries(23):
with self.assertNumQueries(27 if connection.vendor == "mysql" else 23):
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -308,7 +308,7 @@ def test_doesnt_process_different_queue(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker(queue_name=result.task.queue_name)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -323,7 +323,7 @@ def test_process_all_queues(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker(queue_name="*")

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -332,7 +332,7 @@ def test_failing_task(self) -> None:
result = test_tasks.failing_task_value_error.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()

self.assertEqual(result.status, ResultStatus.NEW)
Expand All @@ -358,9 +358,9 @@ def test_complex_exception(self) -> None:
result = test_tasks.complex_exception.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(8), self.assertLogs(
"django_tasks.backends.database", level="ERROR"
):
with self.assertNumQueries(
9 if connection.vendor == "mysql" else 8
), self.assertLogs("django_tasks.backends.database", level="ERROR"):
self.run_worker()

self.assertEqual(result.status, ResultStatus.NEW)
Expand All @@ -387,7 +387,7 @@ def test_doesnt_process_different_backend(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker(backend_name=result.backend)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand Down Expand Up @@ -447,7 +447,7 @@ def test_run_after(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(8):
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand Down