Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature Request] Support max concurrent workflow_instance.run() executions. #16215

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llama-index-core/llama_index/core/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
disable_validation: bool = False,
verbose: bool = False,
service_manager: Optional[ServiceManager] = None,
num_concurrent_runs: Optional[int] = None,
) -> None:
"""Create an instance of the workflow.

Expand All @@ -60,11 +61,17 @@ def __init__(
verbose: whether or not the workflow should print additional informative messages during execution.
service_manager: The instance of the `ServiceManager` used to make nested workflows available to this
workflow instance. The default value is the best choice unless you're customizing the workflow runtime.
num_concurrent_runs: maximum number of .run() executions occurring simultaneously. If set to `None`, there
is no limit to this number.
"""
# Configuration
self._timeout = timeout
self._verbose = verbose
self._disable_validation = disable_validation
self._num_concurrent_runs = num_concurrent_runs
self._sem = (
asyncio.Semaphore(num_concurrent_runs) if num_concurrent_runs else None
)
# Broker machinery
self._contexts: Set[Context] = set()
self._stepwise_context: Optional[Context] = None
Expand Down Expand Up @@ -297,6 +304,8 @@ def run(
result = WorkflowHandler(ctx=ctx)

async def _run_workflow() -> None:
if self._sem:
await self._sem.acquire()
try:
# Send the first event
ctx.send_event(StartEvent(**kwargs))
Expand Down Expand Up @@ -335,6 +344,9 @@ async def _run_workflow() -> None:
result.set_result(ctx._retval)
except Exception as e:
result.set_exception(e)
finally:
if self._sem:
self._sem.release()

asyncio.create_task(_run_workflow())
return result
Expand Down
85 changes: 85 additions & 0 deletions llama-index-core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import time
from unittest import mock
from typing import Type

import pytest

Expand Down Expand Up @@ -484,3 +485,87 @@ async def step2(self, ev: HumanResponseEvent) -> StopEvent:

final_result = await handler
assert final_result == "42"


class DummyWorkflowForConcurrentRunsTest(Workflow):
    """Workflow that tracks how many of its runs are active at once.

    Used to verify the `num_concurrent_runs` limit: each run increments a
    shared counter on entry and decrements it when the run completes, so a
    poller can observe the peak concurrency.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Guards num_active_runs against concurrent updates from parallel runs.
        self._lock = asyncio.Lock()
        # Number of runs currently between step_one entry and _done.
        self.num_active_runs = 0

    @step
    async def step_one(self, ev: StartEvent) -> StopEvent:
        """Single workflow step: mark the run active, idle briefly, finish."""
        run_num = ev.get("run_num")
        async with self._lock:
            self.num_active_runs += 1
        # Keep the run alive long enough for overlapping runs to be observed.
        await asyncio.sleep(0.1)
        return StopEvent(result=f"Run {run_num}: Done")

    @step
    async def _done(self, ctx: Context, ev: StopEvent) -> None:
        """Override of the terminal step: mark the run inactive, then delegate."""
        async with self._lock:
            self.num_active_runs -= 1
        # Preserve the base class's completion behavior.
        await super()._done(ctx, ev)

    async def get_active_runs(self) -> int:
        """Return the current number of active runs (lock-protected read)."""
        async with self._lock:
            return self.num_active_runs


class NumConcurrentRunsException(Exception):
    """Raised when the observed number of concurrent runs exceeds the allowed maximum."""


@pytest.mark.parametrize(
    (
        "workflow",
        "desired_max_concurrent_runs",
        "expected_exception",
    ),
    [
        # Protected workflow: the semaphore caps concurrency at 1, so the
        # poller never observes an excess and records no exception.
        (
            DummyWorkflowForConcurrentRunsTest(num_concurrent_runs=1),
            1,
            type(None),
        ),
        # This workflow is not protected, and so NumConcurrentRunsException is raised
        (
            DummyWorkflowForConcurrentRunsTest(),
            1,
            NumConcurrentRunsException,
        ),
    ],
)
async def test_workflow_run_num_concurrent(
    workflow: DummyWorkflowForConcurrentRunsTest,
    desired_max_concurrent_runs: int,
    expected_exception: Type,
):
    """Verify that `num_concurrent_runs` caps simultaneous .run() executions.

    A background poller samples the workflow's active-run counter while four
    runs execute concurrently, raising NumConcurrentRunsException if the
    observed count ever exceeds the desired maximum.
    """

    async def _poll_workflow(
        wf: DummyWorkflowForConcurrentRunsTest, desired_max_concurrent_runs: int
    ) -> None:
        """Check that number of concurrent runs is less than desired max amount."""
        for _ in range(100):
            num_active_runs = await wf.get_active_runs()
            if num_active_runs > desired_max_concurrent_runs:
                raise NumConcurrentRunsException
            await asyncio.sleep(0.01)

    poll_task = asyncio.create_task(
        _poll_workflow(
            wf=workflow, desired_max_concurrent_runs=desired_max_concurrent_runs
        ),
    )

    # Launch four overlapping runs and wait for all of them to finish.
    tasks = [workflow.run(run_num=ix) for ix in range(1, 5)]
    results = await asyncio.gather(*tasks)

    # Let the poller run to completion. Swallow its exception here so it is
    # inspected uniformly via Task.exception() below rather than propagating
    # out of the await and erroring the test.
    if not poll_task.done():
        try:
            await poll_task
        except NumConcurrentRunsException:
            pass
    e = poll_task.exception()

    # Exact-type comparison: `is`, not `==` (flake8 E721); type(None) is the
    # expected result when the poller completed without raising.
    assert type(e) is expected_exception
    assert results == [f"Run {ix}: Done" for ix in range(1, 5)]
Loading