Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion django_tasks/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, TypeVar

from asgiref.sync import sync_to_async
from django.conf import settings
from django.core.checks import messages
from django.db import connections
from django.utils import timezone
Expand Down Expand Up @@ -83,7 +84,11 @@ def validate_task(self, task: Task) -> None:
if not self.supports_defer and task.run_after is not None:
raise InvalidTaskError("Backend does not support run_after")

if task.run_after is not None and not timezone.is_aware(task.run_after):
if (
settings.USE_TZ
and task.run_after is not None
and not timezone.is_aware(task.run_after)
):
raise InvalidTaskError("run_after must be an aware datetime")

if self.queues and task.queue_name not in self.queues:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime

from django.conf import settings
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.state import StateApps
Expand Down Expand Up @@ -55,7 +56,12 @@ class Migration(migrations.Migration):
name="run_after",
field=models.DateTimeField(
default=datetime.datetime(
9999, 1, 1, 0, 0, tzinfo=datetime.timezone.utc
9999,
1,
1,
0,
0,
tzinfo=datetime.timezone.utc if settings.USE_TZ else None,
),
verbose_name="run after",
),
Expand Down
12 changes: 9 additions & 3 deletions django_tasks/backends/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar

import django
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.db import models
from django.db.models import F, Q
Expand Down Expand Up @@ -48,7 +49,10 @@ def __class_getitem__(cls, _):
return cls


DATE_MAX = datetime.datetime(9999, 1, 1, tzinfo=datetime.timezone.utc)
def get_date_max() -> datetime.datetime:
return datetime.datetime(
9999, 1, 1, tzinfo=datetime.timezone.utc if settings.USE_TZ else None
)


class DBTaskResultQuerySet(models.QuerySet):
Expand All @@ -58,7 +62,9 @@ def ready(self) -> "DBTaskResultQuerySet":
"""
return self.filter(
status=ResultStatus.READY,
).filter(models.Q(run_after=DATE_MAX) | models.Q(run_after__lte=timezone.now()))
).filter(
models.Q(run_after=get_date_max()) | models.Q(run_after__lte=timezone.now())
)

def succeeded(self) -> "DBTaskResultQuerySet":
return self.filter(status=ResultStatus.SUCCEEDED)
Expand Down Expand Up @@ -157,7 +163,7 @@ def task(self) -> Task[P, T]:
return task.using(
priority=self.priority,
queue_name=self.queue_name,
run_after=None if self.run_after == DATE_MAX else self.run_after,
run_after=None if self.run_after == get_date_max() else self.run_after,
backend=self.backend_name,
)

Expand Down
4 changes: 2 additions & 2 deletions django_tasks/backends/database/signal_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from django.db.models.signals import pre_save
from django.dispatch import receiver

from .models import DATE_MAX, DBTaskResult
from .models import DBTaskResult, get_date_max


@receiver(pre_save, sender=DBTaskResult)
def set_run_after(sender: Any, instance: DBTaskResult, **kwargs: Any) -> None:
if instance.run_after is None:
instance.run_after = DATE_MAX
instance.run_after = get_date_max()
40 changes: 39 additions & 1 deletion tests/tests/test_database_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import sys
import time
import uuid
import warnings
from collections import Counter
from contextlib import redirect_stderr
from datetime import timedelta
from datetime import datetime, timedelta
from functools import partial
from io import StringIO
from typing import Any, List, Optional, Sequence, Union, cast
Expand Down Expand Up @@ -436,6 +437,43 @@ def test_index_scan_for_ready(self) -> None:
else:
self.fail("Unknown database engine")

def test_run_after_tz(self) -> None:
for use_tz in [True, False]:
with self.subTest(use_tz=use_tz):
with override_settings(USE_TZ=use_tz):
result = test_tasks.noop_task.enqueue()
self.assertIsNone(
DBTaskResult.objects.get(id=result.id).task.run_after
)

def test_run_after_null_0016_migration(self) -> None:
from datetime import timezone

for use_tz in [True, False]:
with self.subTest(use_tz=use_tz):
with override_settings(USE_TZ=use_tz):
result = test_tasks.noop_task.enqueue()

db_result = DBTaskResult.objects.get(id=result.id)

# Literal taken from migration
db_result.run_after = datetime(
9999,
1,
1,
tzinfo=timezone.utc if use_tz else None,
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", module="django.db", category=RuntimeWarning
)
db_result.save()

self.assertIsNone(
DBTaskResult.objects.get(id=result.id).task.run_after
)


@override_settings(
TASKS={
Expand Down