Skip to content

Commit

Permalink
Migrations for commit 982a38e from main branch
Browse files Browse the repository at this point in the history
  • Loading branch information
deligianp committed Sep 10, 2024
2 parents 5c81fab + 982a38e commit f403ff4
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 55 deletions.
4 changes: 2 additions & 2 deletions schema-api/api/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.db import models


class _TaskStatus(models.IntegerChoices):
class TaskStatus(models.IntegerChoices):
UNKNOWN = -1, 'UNKNOWN'
SUBMITTED = 0, 'SUBMITTED'
APPROVED = 1, 'APPROVED'
Expand Down Expand Up @@ -34,4 +34,4 @@ class ErrorMessages:
RESOURCE_SET_DISK_GB_MIN_VIOLATION = f'Amount of disk in GBs, must be at least 1GB'
RESOURCE_SET_RAM_GB_MIN_VIOLATION = f'Amount of RAM in GBs, must be at least 1GB'
TASK_STATUS_ENUM_VIOLATION = f'Task status must be any of the following ' \
f'values: {", ".join(_.label for _ in _TaskStatus)}'
f'values: {", ".join(_.label for _ in TaskStatus)}'
20 changes: 17 additions & 3 deletions schema-api/api/filtersets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from django.db.models import Q
from django.db.models import Q, OuterRef, Subquery
from django_filters import rest_framework as filters

from api.constants import _TaskStatus
from api.constants import TaskStatus
from api.models import StatusHistoryPoint


class TaskFilter(filters.FilterSet):
name = filters.CharFilter(field_name='name', lookup_expr='icontains')
status = filters.MultipleChoiceFilter(choices=_TaskStatus.choices)
status = filters.MultipleChoiceFilter(choices=[(c.label, c.value) for c in TaskStatus], method='filter_by_status')
after = filters.DateTimeFilter(field_name='submitted_at', lookup_expr='gte')
before = filters.DateTimeFilter(field_name='submitted_at', lookup_expr='lt')
order = filters.OrderingFilter(
Expand All @@ -22,3 +23,16 @@ def filter_by_search(self, queryset, name, value):
return queryset.filter(
Q(uuid__icontains=value) | Q(name__icontains=value)
)

def filter_by_status(self, queryset, name, value):
target_statuses = [TaskStatus[v.upper()].value for v in value]

latest_statuses = StatusHistoryPoint.objects.filter(
task=OuterRef('pk')
).order_by('-created_at')

tasks_with_latest_status = queryset.annotate(
latest_status=Subquery(latest_statuses.values('status')[:1])
)

return tasks_with_latest_status.filter(latest_status__in=target_statuses)
12 changes: 6 additions & 6 deletions schema-api/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.db import models
from django.db.models import CheckConstraint, Q, UniqueConstraint, F

from api.constants import MountPointTypes, _TaskStatus
from api.constants import MountPointTypes, TaskStatus
from util.constraints import ApplicationUniqueConstraint
from util.decorators import update_fields
from util.defaults import get_current_datetime
Expand Down Expand Up @@ -106,18 +106,18 @@ def outputs(self):
class StatusHistoryPoint(models.Model):
task = models.ForeignKey(Task, on_delete=models.CASCADE, related_name='status_history_points')
created_at = models.DateTimeField(default=get_current_datetime)
status = models.IntegerField(choices=_TaskStatus.choices)
status = models.IntegerField(choices=TaskStatus.choices)

class Meta:
constraints = [
CheckConstraint(
check=Q(status__in=[choice.value for choice in _TaskStatus]),
check=Q(status__in=[choice.value for choice in TaskStatus]),
name='status_history_enum'
),
]

def __str__(self):
return f'{self.task.uuid}: {_TaskStatus(self.status).label}({self.created_at.isoformat()})'
return f'{self.task.uuid}: {TaskStatus(self.status).label}({self.created_at.isoformat()})'


# A model modified in such way that save() method is in effect only when it is the initial save of the instance
Expand Down Expand Up @@ -236,8 +236,8 @@ class Meta:
]


class TempTag(models.Model):
tasks = models.ManyToManyField(Task, related_name='temptags')
class Tag(models.Model):
tasks = models.ManyToManyField(Task, related_name='tags')
value = models.CharField(max_length=255)

class Meta:
Expand Down
12 changes: 6 additions & 6 deletions schema-api/api/quotas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db.models import Q, Sum

from api.constants import _TaskStatus
from api.constants import TaskStatus
from api.models import Task
from util.exceptions import ApplicationTaskQuotaDepletedError, ApplicationTaskQuotaExceedingRequestError

Expand Down Expand Up @@ -36,7 +36,7 @@ class DefaultQuotaPolicy(AbstractQuotaPolicy):

def _check_max_cpu_cores(self, task: Task):
allocated_cpu = Task.objects.filter(
~Q(status__in=(_TaskStatus.SUBMITTED, _TaskStatus.REJECTED)), pending=True, context=task.context
~Q(status__in=(TaskStatus.SUBMITTED, TaskStatus.REJECTED)), pending=True, context=task.context
).aggregate(Sum('resourceset__cpu_cores'))['resourceset__cpu_cores__sum'] or 0
if allocated_cpu >= task.context.quotas.max_cpu:
raise ApplicationTaskQuotaDepletedError('All CPU cores for this context are currently allocated')
Expand All @@ -49,19 +49,19 @@ def _check_max_cpu_cores(self, task: Task):
def _check_max_tasks(self, task: Task):
context_tasks_qs = Task.objects.filter(context=task.context)
n_completed_tasks = context_tasks_qs.filter(
Q(Q(status=_TaskStatus.COMPLETED) | Q(status=_TaskStatus.ERROR))).count()
Q(Q(status=TaskStatus.COMPLETED) | Q(status=TaskStatus.ERROR))).count()
if n_completed_tasks >= task.context.quotas.max_tasks:
raise ApplicationTaskQuotaDepletedError('All tasks reserved for this context have already been ran')

n_running_tasks = context_tasks_qs.filter(~Q(status__in=(_TaskStatus.REJECTED, _TaskStatus.SUBMITTED)),
n_running_tasks = context_tasks_qs.filter(~Q(status__in=(TaskStatus.REJECTED, TaskStatus.SUBMITTED)),
pending=True).count()
if n_completed_tasks + n_running_tasks >= task.context.quotas.max_tasks:
raise ApplicationTaskQuotaDepletedError(
'Currently running tasks have allocated the reserved number of tasks for this context')

def _check_max_ram_gb(self, task: Task):
allocated_ram_gb = Task.objects.filter(
~Q(status__in=(_TaskStatus.SUBMITTED, _TaskStatus.REJECTED)), pending=True, context=task.context
~Q(status__in=(TaskStatus.SUBMITTED, TaskStatus.REJECTED)), pending=True, context=task.context
).aggregate(Sum('resourceset__ram_gb'))['resourceset__ram_gb__sum'] or 0
if allocated_ram_gb >= task.context.quotas.max_ram_gb:
raise ApplicationTaskQuotaDepletedError('All RAM for this context is currently allocated')
Expand All @@ -72,7 +72,7 @@ def _check_max_ram_gb(self, task: Task):
)

def _check_max_active_tasks(self, task: Task):
if Task.objects.filter(~Q(status__in=(_TaskStatus.SUBMITTED, _TaskStatus.REJECTED)),
if Task.objects.filter(~Q(status__in=(TaskStatus.SUBMITTED, TaskStatus.REJECTED)),
pending=True, context=task.context).count() >= task.context.quotas.max_active_tasks:
raise ApplicationTaskQuotaDepletedError(
'Maximum number of active/concurrent tasks for this context is currently running')
2 changes: 1 addition & 1 deletion schema-api/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class TaskSerializer(BaseSerializer):
volumes = ModelMemberRelatedField(target_field_name='path', child=serializers.CharField(), allow_empty=False,
required=False)
tags = ModelMemberRelatedField(target_field_name='value', child=serializers.CharField(), allow_empty=False,
required=False, source='temptags')
required=False)


class TasksBasicListSerializer(serializers.Serializer):
Expand Down
27 changes: 15 additions & 12 deletions schema-api/api/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from django.utils import timezone

from api import taskapis
from api.constants import _TaskStatus
from api.constants import TaskStatus
from api.models import Task, Executor, Env, MountPoint, Volume, ResourceSet, ExecutorOutputLog, Context, \
Participation, StatusHistoryPoint, TempTag
Participation, StatusHistoryPoint, Tag
from api_auth.constants import AuthEntityType
from api_auth.models import AuthEntity
from quotas.evaluators import ActiveResourcesDbQuotasEvaluator, RequestedResourcesQuotasEvaluator, TasksQuotasEvaluator
Expand Down Expand Up @@ -47,13 +47,13 @@ def submit_task(self, *, executors: Iterable[Executor], **optional):
input_mount_points = optional.pop('inputs', None)
output_mount_points = optional.pop('outputs', None)
volumes = optional.pop('volumes', None)
tags = optional.pop('temptags', None)
tags = optional.pop('tags', None)
resource_set = optional.pop('resources', None)

task = Task.objects.create(context=self.context, user=self.auth_entity, **optional)

status_history_point_service = StatusHistoryPointService(task)
status_history_point_service.update_status(_TaskStatus.SUBMITTED)
status_history_point_service.update_status(TaskStatus.SUBMITTED)

i = 0
for executor in executors:
Expand Down Expand Up @@ -85,16 +85,16 @@ def submit_task(self, *, executors: Iterable[Executor], **optional):
if tags:
tag_set = set(tags)
for tag in tag_set:
temp_tag = TempTag.objects.get_or_create(value=tag)[0]
task.temptags.add(temp_tag)
temp_tag = Tag.objects.get_or_create(value=tag)[0]
task.tags.add(temp_tag)

quotas_service = QuotasService(task.context, task.user)
context_quotas, participation_quotas = quotas_service.get_qualified_quotas()
RequestedResourcesQuotasEvaluator.evaluate(context_quotas, participation_quotas, task)
TasksQuotasEvaluator.evaluate(context_quotas, participation_quotas, task)
ActiveResourcesDbQuotasEvaluator.evaluate(context_quotas, participation_quotas, task)

status_history_point_service.update_status(_TaskStatus.APPROVED)
status_history_point_service.update_status(TaskStatus.APPROVED)

if settings.TASK_API["TASK_API_CLASS"] and not settings.DISABLE_TASK_SCHEDULING:
task_api_class = taskapis.get_task_api_class()
Expand All @@ -106,7 +106,7 @@ def submit_task(self, *, executors: Iterable[Executor], **optional):
task.latest_update = timezone.now()
task.save()

status_history_point_service.update_status(_TaskStatus.SCHEDULED)
status_history_point_service.update_status(TaskStatus.SCHEDULED)
return task

@transaction.atomic
Expand All @@ -120,7 +120,7 @@ def _check_if_update_task(self, task):
task_info = task_api.get_task_info(task.task_id)

task_status = task_info['status']
if task_status in [_TaskStatus.COMPLETED, _TaskStatus.ERROR, _TaskStatus.CANCELED]:
if task_status in [TaskStatus.COMPLETED, TaskStatus.ERROR, TaskStatus.CANCELED]:
task.pending = False

status_history_point_service = StatusHistoryPointService(task)
Expand Down Expand Up @@ -149,7 +149,10 @@ def _check_if_update_task(self, task):
return task

def get_task(self, task_uuid: uuid.UUID):
task = Task.objects.get(context=self.context, uuid=task_uuid)
try:
task = Task.objects.get(context=self.context, uuid=task_uuid)
except Task.DoesNotExist as dne:
raise ApplicationNotFoundError(f'No task was found with UUID "{task_uuid}"') from dne

task = self._check_if_update_task(task)

Expand Down Expand Up @@ -191,7 +194,7 @@ def cancel_task(self, task_uuid: uuid.UUID) -> None:
raise ApplicationValidationError({'uuid': f'Task with UUID \'{task_uuid}\' has already terminated'})

status_history_point_service = StatusHistoryPointService(task)
status_history_point_service.update_status(_TaskStatus.CANCELED)
status_history_point_service.update_status(TaskStatus.CANCELED)
return

raise ApplicationValidationError({'uuid': f'Task with UUID \'{task_uuid}\' has already terminated'})
Expand All @@ -202,7 +205,7 @@ class StatusHistoryPointService:
def __init__(self, task: Task):
self.task = task

def update_status(self, status: _TaskStatus, update_time: datetime.datetime = None) -> StatusHistoryPoint:
def update_status(self, status: TaskStatus, update_time: datetime.datetime = None) -> StatusHistoryPoint:
update_time = update_time or timezone.now()
return StatusHistoryPoint.objects.create(task=self.task, status=status, created_at=update_time)

Expand Down
20 changes: 10 additions & 10 deletions schema-api/api/taskapis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django.conf import settings
from rest_framework import status

from api.constants import _TaskStatus
from api.constants import TaskStatus
from api.models import Task
from api.serializers import TaskSerializer

Expand Down Expand Up @@ -107,15 +107,15 @@ def get_executor_logs(self, task_id, executor_index: int = None, task_content=No
return result

TES_SCHEMA_STATUS_MAP = {
'UNKNOWN': _TaskStatus.UNKNOWN,
'INITIALIZING': _TaskStatus.INITIALIZING,
'QUEUED': _TaskStatus.INITIALIZING,
'RUNNING': _TaskStatus.RUNNING,
'PAUSED': _TaskStatus.RUNNING,
'COMPLETE': _TaskStatus.COMPLETED,
'EXECUTOR_ERROR': _TaskStatus.ERROR,
'SYSTEM_ERROR': _TaskStatus.ERROR,
'CANCELED': _TaskStatus.CANCELED
'UNKNOWN': TaskStatus.UNKNOWN,
'INITIALIZING': TaskStatus.INITIALIZING,
'QUEUED': TaskStatus.INITIALIZING,
'RUNNING': TaskStatus.RUNNING,
'PAUSED': TaskStatus.RUNNING,
'COMPLETE': TaskStatus.COMPLETED,
'EXECUTOR_ERROR': TaskStatus.ERROR,
'SYSTEM_ERROR': TaskStatus.ERROR,
'CANCELED': TaskStatus.CANCELED
}

def _get_task(self, task_id):
Expand Down
16 changes: 8 additions & 8 deletions schema-api/api/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from django.test import TestCase
from django.utils import timezone

from api.constants import _TaskStatus
from api.constants import TaskStatus
from api.models import Task, StatusHistoryPoint


Expand All @@ -18,7 +18,7 @@ def setUpTestData(cls):

def test_valid_save_accepts_all_provided_values(self):
dt = timezone.now()
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
status_history_point = StatusHistoryPoint(task=self.task, created_at=dt, status=status)
status_history_point.save()
status_history_point.refresh_from_db()
Expand All @@ -28,29 +28,29 @@ def test_valid_save_accepts_all_provided_values(self):

def test_save_without_a_referenced_task_raises_error(self):
dt = timezone.now()
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
status_history_point = StatusHistoryPoint(created_at=dt, status=status)
with self.assertRaises(IntegrityError):
status_history_point.save()

def test_save_with_a_referenced_task_as_none_raises_error(self):
dt = timezone.now()
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
status_history_point = StatusHistoryPoint(task=None, created_at=dt, status=status)
with self.assertRaises(IntegrityError):
status_history_point.save()

def test_save_without_created_at_uses_timezone_now_default(self):
mocked_datetime = dt(2020, 1, 1, 1, 1, 1, tzinfo=datetime.timezone.utc)
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
with patch('django.utils.timezone.now', return_value=mocked_datetime):
print(timezone.now())
status_history_point = StatusHistoryPoint.objects.create(task=self.task, status=status)
print(status_history_point.created_at)
self.assertEqual(status_history_point.created_at, mocked_datetime)

def test_save_with_created_at_as_none_raises_error(self):
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
status_history_point = StatusHistoryPoint(task=self.task, created_at=None, status=status)
with self.assertRaises(IntegrityError):
status_history_point.save()
Expand All @@ -76,14 +76,14 @@ def test_save_with_invalid_task_status_choice_raises_error(self):

def test_str_returns_status_history_point_description(self):
dt = timezone.now()
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
status_history_point = StatusHistoryPoint(task=self.task, created_at=dt, status=status)
self.assertEqual(f'{self.task.uuid}: {status.label}({dt.isoformat()})', str(status_history_point))

def test_deleting_referenced_task_deletes_status_history_point(self):
task = Task.objects.create(name='sample-task-2')
dt = timezone.now()
status = _TaskStatus.SUBMITTED
status = TaskStatus.SUBMITTED
status_history_point = StatusHistoryPoint(task=task, created_at=dt, status=status)
task.delete()
with self.assertRaises(StatusHistoryPoint.DoesNotExist):
Expand Down
4 changes: 2 additions & 2 deletions schema-api/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rest_framework.response import Response
from rest_framework.views import APIView

from api.constants import _TaskStatus
from api.constants import TaskStatus
from api.filtersets import TaskFilter
from api.models import Task
from api.serializers import TaskSerializer, TasksListQPSerializer, TasksBasicListSerializer, \
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_serializer_class(self):
allow_blank=False, many=False, ),
OpenApiParameter('status', OpenApiTypes.STR, OpenApiParameter.QUERY,
description='Status to filter tasks on', required=False,
allow_blank=False, many=False, enum=[x.label for x in _TaskStatus]),
allow_blank=False, many=False, enum=[x.label for x in TaskStatus]),
OpenApiParameter('before', OpenApiTypes.DATETIME, OpenApiParameter.QUERY,
description='Retrieve tasks submitted before this date', required=False,
allow_blank=False, many=False),
Expand Down
6 changes: 3 additions & 3 deletions schema-api/migrations/api/0004_migrate_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db import migrations

from api.constants import _TaskStatus
from api.constants import TaskStatus


def migrate_status_to_status_history_point(apps, schema_editor):
Expand All @@ -11,12 +11,12 @@ def migrate_status_to_status_history_point(apps, schema_editor):

for task in Task.objects.all():
new_status = None
for s in _TaskStatus:
for s in TaskStatus:
if s.label == task.status:
new_status = s
break
if new_status is None:
new_status = _TaskStatus.UNKNOWN
new_status = TaskStatus.UNKNOWN

created_at = task.latest_update or task.submitted_at
StatusHistoryPoint.objects.create(
Expand Down
Loading

0 comments on commit f403ff4

Please sign in to comment.