Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.

feat: Add background-task for kernel-pull-progress #477

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/477.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a background task that keeps updating the progress reporter from `KernelPullProgressEvent` until image pulling is done.
80 changes: 77 additions & 3 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import trafaret as t
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from ..background import ProgressReporter

from ai.backend.common import redis, validators as tx
from ai.backend.common.docker import ImageRef
Expand All @@ -71,6 +72,7 @@
SessionStartedEvent,
SessionSuccessEvent,
SessionTerminatedEvent,
KernelPullProgressEvent,
)
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import cancel_tasks, str_to_timedelta
Expand Down Expand Up @@ -519,8 +521,19 @@ async def _create(request: web.Request, params: Any) -> web.Response:
resp['status'] = 'PENDING'
resp['servicePorts'] = []
resp['created'] = True

if not params['enqueue_only']:
if params['enqueue_only']:
task_id = await root_ctx.background_task_manager.start(
functools.partial(
monitor_kernel_preparation,
kernel_id=kernel_id,
root_ctx=root_ctx,
app=request.app,
),
name='monitor-kernel-preparation',
)
resp['background_task'] = str(task_id)
return web.json_response(resp, status=201)
else:
app_ctx.pending_waits.add(current_task)
max_wait = params['max_wait_seconds']
try:
Expand All @@ -530,7 +543,18 @@ async def _create(request: web.Request, params: Any) -> web.Response:
else:
await start_event.wait()
except asyncio.TimeoutError:
task_id = await root_ctx.background_task_manager.start(
functools.partial(
monitor_kernel_preparation,
kernel_id=kernel_id,
root_ctx=root_ctx,
app=request.app,
),
name='monitor-kernel-preparation',
)
resp['background_task'] = str(task_id)
resp['status'] = 'TIMEOUT'
return web.json_response(resp, status=201)
else:
await asyncio.sleep(0.5)
async with root_ctx.db.begin_readonly() as conn:
Expand Down Expand Up @@ -1014,7 +1038,6 @@ async def create_cluster(request: web.Request, params: Any) -> web.Response:
resp['status'] = 'PENDING'
resp['servicePorts'] = []
resp['created'] = True

if not params['enqueue_only']:
app_ctx.pending_waits.add(current_task)
max_wait = params['max_wait_seconds']
Expand Down Expand Up @@ -1241,6 +1264,57 @@ async def handle_agent_heartbeat(
await root_ctx.registry.handle_heartbeat(source, event.agent_info)


async def monitor_kernel_preparation(
    reporter: ProgressReporter,
    kernel_id: uuid.UUID,
    root_ctx: RootContext,
    app: web.Application,
) -> None:
    """
    Relay image-pull progress to a background-task progress reporter.

    Subscribes to ``KernelPullProgressEvent`` and then, every 0.5 seconds,
    reads the kernel's status from the database and copies the most recently
    received progress numbers into *reporter*.  The loop ends when the status
    becomes ``KernelStatus.RUNNING`` or when the kernel row cannot be found.
    The event subscription is always removed on exit.

    :param reporter: Reporter whose ``current_progress`` / ``total_progress``
        attributes are overwritten and whose ``update()`` is awaited on each
        poll.
    :param kernel_id: ID of the kernel row to poll.
    :param root_ctx: Provides ``db`` (for the read-only status query) and
        ``event_dispatcher`` (for the subscription).
    :param app: Passed to the dispatcher as the subscription context; the
        dispatcher hands it back to the handler as its first argument.
    """
    # Latest [current, total] numbers received from the event stream.
    # A mutable list lets the nested handler publish values to the poll loop.
    progress = [0, 0]

    async def _get_status(kernel_id):
        """Return the (id, status) row of the kernel, or None if it is absent."""
        query = (
            sa.select([
                kernels.c.id,
                kernels.c.status,
            ])
            .select_from(kernels)
            .where(kernels.c.id == kernel_id)
        )
        async with root_ctx.db.begin_readonly() as conn:
            result = await conn.execute(query)
            # Fetch while the connection is still open.
            return result.first()

    async def _update_progress(
        app: web.Application,
        source: AgentId,
        event: KernelPullProgressEvent,
    ) -> None:
        """Remember the numbers carried by the latest pull-progress event."""
        # NOTE(review): the event is not matched against ``kernel_id`` here, so
        # progress events for any kernel overwrite these numbers -- confirm
        # whether the event carries a kernel identifier that should be checked.
        progress[0] = int(event.current_progress)
        progress[1] = int(event.total_progress)

    progress_handler = root_ctx.event_dispatcher.subscribe(
        KernelPullProgressEvent,
        app,
        _update_progress,
    )
    try:
        while True:
            row = await _get_status(kernel_id)
            if row is None:
                # Without this guard, subscripting ``None`` below would raise
                # TypeError; there is nothing left to monitor, so stop.
                log.warning(
                    'monitor_kernel_preparation: kernel {} not found, stopping',
                    kernel_id,
                )
                break
            status = row['status']
            if status == KernelStatus.PREPARING:
                await reporter.update(0)
            if status == KernelStatus.RUNNING:
                break
            # NOTE(review): RUNNING is the only status that ends the loop; if
            # the kernel ends up in some other final state this task keeps
            # polling -- consider exiting on those statuses as well.
            reporter.current_progress = progress[0]
            reporter.total_progress = progress[1]
            # update(0): publish the overwritten numbers without incrementing.
            await reporter.update(0)
            await asyncio.sleep(0.5)
    finally:
        root_ctx.event_dispatcher.unsubscribe(progress_handler)


@catch_unexpected(log)
async def check_agent_lost(root_ctx: RootContext, interval: float) -> None:
try:
Expand Down