Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature Request] Ability to cancel workflow execution #16320

Merged
merged 4 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llama-index-core/llama_index/core/workflow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, workflow: "Workflow", stepwise: bool = False) -> None:
self._queues: Dict[str, asyncio.Queue] = {}
self._tasks: Set[asyncio.Task] = set()
self._broker_log: List[Event] = []
self._cancel_flag: asyncio.Event = asyncio.Event()
self._step_flags: Dict[str, asyncio.Event] = {}
self._step_event_holding: Optional[Event] = None
self._step_lock: asyncio.Lock = asyncio.Lock()
Expand Down
4 changes: 4 additions & 0 deletions llama-index-core/llama_index/core/workflow/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ class WorkflowRuntimeError(Exception):

class WorkflowDone(Exception):
pass


class WorkflowCancelledByUser(Exception):
pass
6 changes: 6 additions & 0 deletions llama-index-core/llama_index/core/workflow/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,9 @@ async def run_step(self) -> Optional[Event]:
raise ValueError("Context is not set!")

return retval

async def cancel_run(self) -> None:
"""Method to cancel a Workflow execution."""
if self.ctx:
self.ctx._cancel_flag.set()
await asyncio.sleep(0)
20 changes: 19 additions & 1 deletion llama-index-core/llama_index/core/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

from .decorators import StepConfig, step
from .context import Context
from .events import InputRequiredEvent, HumanResponseEvent, Event, StartEvent, StopEvent
from .events import (
InputRequiredEvent,
HumanResponseEvent,
Event,
StartEvent,
StopEvent,
)
from .errors import *
from .service import ServiceManager
from .utils import (
Expand Down Expand Up @@ -162,6 +168,7 @@ def _start(self, stepwise: bool = False, ctx: Optional[Context] = None) -> Conte
ctx._step_flags = {}
ctx._retval = None
ctx._step_event_holding = None
ctx._cancel_flag.clear()

for name, step_func in self._get_steps().items():
ctx._queues[name] = asyncio.Queue()
Expand Down Expand Up @@ -275,6 +282,17 @@ async def _task(
)
)

# add dedicated cancel task
async def _cancel_workflow_task() -> None:
await ctx._cancel_flag.wait()
raise WorkflowCancelledByUser

ctx._tasks.add(
asyncio.create_task(
_cancel_workflow_task(), name="cancel_workflow_task"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting method! Basically, this allows us to reuse all our existing logic for canceling tasks 💪🏻

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah! And it also I think matches the current working relationship between Worfklow and WorkflowHandler.

)
)

return ctx

def send_event(self, message: Event, step: Optional[str] = None) -> None:
Expand Down
16 changes: 16 additions & 0 deletions llama-index-core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
WorkflowTimeoutError,
WorkflowValidationError,
WorkflowRuntimeError,
WorkflowCancelledByUser,
)

from .conftest import AnotherTestEvent, LastEvent, OneTestEvent
Expand Down Expand Up @@ -64,6 +65,21 @@ async def test_workflow_run_step(workflow):
assert result == "Workflow completed"


@pytest.mark.asyncio()
async def test_workflow_cancelled_by_user(workflow):
handler = workflow.run(stepwise=True)

event = await handler.run_step()
assert isinstance(event, OneTestEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

await handler.cancel_run()
await asyncio.sleep(0.1) # let workflow get cancelled
assert handler.is_done()
assert type(handler.exception()) == WorkflowCancelledByUser


@pytest.mark.asyncio()
async def test_workflow_run_step_continue_context():
class DummyWorkflow(Workflow):
Expand Down
Loading