Skip to content

Commit

Permalink
fix: Audit Log records don't get created with threaded task processing (
Browse files Browse the repository at this point in the history
  • Loading branch information
khvn26 authored Nov 13, 2023
1 parent 65351e2 commit 716b228
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 87 deletions.
5 changes: 5 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,11 @@
"handlers": ["console"],
"propagate": False,
},
"webhooks": {
"level": LOG_LEVEL,
"handlers": ["console"],
"propagate": False,
},
},
}

Expand Down
16 changes: 15 additions & 1 deletion api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ pytest-cov = "~4.1.0"
datamodel-code-generator = "~0.22"
requests-mock = "^1.11.0"
pdbpp = "^0.10.3"
django-capture-on-commit-callbacks = "^1.11.0"

[build-system]
requires = ["poetry-core>=1.5.0"]
Expand Down
216 changes: 137 additions & 79 deletions api/task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,107 +6,165 @@
from threading import Thread

from django.conf import settings
from django.db.transaction import on_commit
from django.utils import timezone

from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import register_task
from task_processor.task_run_method import TaskRunMethod

P = typing.ParamSpec("P")

logger = logging.getLogger(__name__)


class TaskHandler(typing.Generic[P]):
    """
    Wrap a task function so it can be run synchronously, in a thread, or via
    the task processor, while registering it under a stable identifier.

    Replaces the previous closure-based `register_task_handler` decorator so
    that per-task configuration (queue size, priority, on-commit behaviour)
    lives on the handler instance.
    """

    __slots__ = (
        "unwrapped",
        "queue_size",
        "priority",
        "transaction_on_commit",
        "task_identifier",
    )

    # The original, undecorated callable — exposed for unit tests.
    unwrapped: typing.Callable[P, None]

    def __init__(
        self,
        f: typing.Callable[P, None],
        *,
        task_name: str | None = None,
        queue_size: int | None = None,
        priority: TaskPriority = TaskPriority.NORMAL,
        transaction_on_commit: bool = True,
    ) -> None:
        self.unwrapped = f
        self.queue_size = queue_size
        self.priority = priority
        self.transaction_on_commit = transaction_on_commit

        task_name = task_name or f.__name__
        # Identifier is "<module>.<task_name>" using only the last module path
        # segment, matching what the task processor looks up at run time.
        task_module = getmodule(f).__name__.rsplit(".")[-1]
        self.task_identifier = task_identifier = f"{task_module}.{task_name}"
        register_task(task_identifier, f)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
        # Validate serialisability even on direct calls so unit tests catch
        # arguments that would fail in a task-processor environment.
        _validate_inputs(*args, **kwargs)
        return self.unwrapped(*args, **kwargs)

def delay(
self,
*,
delay_until: datetime | None = None,
# TODO @khvn26 consider typing `args` and `kwargs` with `ParamSpec`
# (will require a change to the signature)
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
) -> Task | None:
logger.debug("Request to run task '%s' asynchronously.", self.task_identifier)

kwargs = kwargs or {}

if delay_until and settings.TASK_RUN_METHOD != TaskRunMethod.TASK_PROCESSOR:
logger.warning(
"Cannot schedule tasks to run in the future without task processor."
)
return

if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
_validate_inputs(*args, **kwargs)
self.unwrapped(*args, **kwargs)
elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
logger.debug("Running task '%s' in separate thread", self.task_identifier)
self.run_in_thread(args=args, kwargs=kwargs)
else:
logger.debug("Creating task for function '%s'...", self.task_identifier)
try:
task = Task.create(
task_identifier=self.task_identifier,
scheduled_for=delay_until or timezone.now(),
priority=self.priority,
queue_size=self.queue_size,
args=args,
kwargs=kwargs,
)
except TaskQueueFullError as e:
logger.warning(e)
return

if settings.TASK_RUN_METHOD == TaskRunMethod.SYNCHRONOUSLY:
_validate_inputs(*args, **kwargs)
f(*args, **kwargs)
elif settings.TASK_RUN_METHOD == TaskRunMethod.SEPARATE_THREAD:
logger.debug("Running task '%s' in separate thread", task_identifier)
run_in_thread(args=args, kwargs=kwargs)
else:
logger.debug("Creating task for function '%s'...", task_identifier)
try:
task = Task.create(
task_identifier=task_identifier,
scheduled_for=delay_until or timezone.now(),
priority=priority,
queue_size=queue_size,
args=args,
kwargs=kwargs,
)
except TaskQueueFullError as e:
logger.warning(e)
return

task.save()
return task

def run_in_thread(*, args: typing.Tuple = (), kwargs: typing.Dict = None):
logger.info("Running function %s in unmanaged thread.", f.__name__)
_validate_inputs(*args, **kwargs)
Thread(target=f, args=args, kwargs=kwargs, daemon=True).start()

def _wrapper(*args, **kwargs):
"""
Execute the function after validating the arguments. Ensures that, in unit testing,
the arguments are validated to prevent issues with serialization in an environment
that utilises the task processor.
"""
_validate_inputs(*args, **kwargs)
return f(*args, **kwargs)

_wrapper.delay = delay
_wrapper.run_in_thread = run_in_thread
_wrapper.task_identifier = task_identifier

# patch the original unwrapped function onto the wrapped version for testing
_wrapper.unwrapped = f

return _wrapper
task.save()
return task

def run_in_thread(
self,
*,
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
) -> None:
_validate_inputs(*args, **kwargs)
thread = Thread(target=self.unwrapped, args=args, kwargs=kwargs, daemon=True)

def _start() -> None:
logger.info(
"Running function %s in unmanaged thread.", self.unwrapped.__name__
)
thread.start()

if self.transaction_on_commit:
return on_commit(_start)
return _start()


def register_task_handler(  # noqa: C901
    *,
    task_name: str | None = None,
    queue_size: int | None = None,
    priority: TaskPriority = TaskPriority.NORMAL,
    transaction_on_commit: bool = True,
) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
    """
    Turn a function into an asynchronous task.

    :param str task_name: task name. Defaults to function name.
    :param int queue_size: (`TASK_PROCESSOR` task run method only)
        max queue size for the task. Task runs exceeding the max size get dropped by
        the task processor. Defaults to `None` (infinite).
    :param TaskPriority priority: task priority.
    :param bool transaction_on_commit: (`SEPARATE_THREAD` task run method only)
        Whether to wrap the task call in `transaction.on_commit`. Defaults to `True`.
        We need this for the task to be able to access data committed with the current
        transaction. If the task is invoked outside of a transaction, it will start
        immediately.
        Pass `False` if you want the task to start immediately regardless of current
        transaction.
    :rtype: TaskHandler
    """

    def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
        return TaskHandler(
            f,
            task_name=task_name,
            queue_size=queue_size,
            priority=priority,
            transaction_on_commit=transaction_on_commit,
        )

    return wrapper


def register_recurring_task(
run_every: timedelta,
task_name: str = None,
args: typing.Tuple = (),
kwargs: typing.Dict = None,
first_run_time: time = None,
):
task_name: str | None = None,
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
return lambda f: f

def decorator(f: typing.Callable):
def decorator(f: typing.Callable[..., None]) -> RecurringTask:
nonlocal task_name

task_name = task_name or f.__name__
Expand All @@ -118,8 +176,8 @@ def decorator(f: typing.Callable):
task, _ = RecurringTask.objects.update_or_create(
task_identifier=task_identifier,
defaults={
"serialized_args": RecurringTask.serialize_data(args or tuple()),
"serialized_kwargs": RecurringTask.serialize_data(kwargs or dict()),
"serialized_args": RecurringTask.serialize_data(args or ()),
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
},
Expand All @@ -129,9 +187,9 @@ def decorator(f: typing.Callable):
return decorator


def _validate_inputs(*args: typing.Any, **kwargs: typing.Any) -> None:
    """
    Ensure the given task arguments are serializable.

    :raises InvalidArgumentsError: when `args` or `kwargs` cannot be
        serialized by `Task.serialize_data` (chained from the `TypeError`).
    """
    try:
        Task.serialize_data(args or ())
        Task.serialize_data(kwargs or {})
    except TypeError as e:
        raise InvalidArgumentsError("Inputs are not serializable.") from e
Loading

3 comments on commit 716b228

@vercel
Copy link

@vercel vercel bot commented on 716b228 Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vercel
Copy link

@vercel vercel bot commented on 716b228 Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vercel
Copy link

@vercel vercel bot commented on 716b228 Nov 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

docs – ./docs

docs-git-main-flagsmith.vercel.app
docs.bullet-train.io
docs-flagsmith.vercel.app
docs.flagsmith.com

Please sign in to comment.