Re-factor WorkflowHandler.run_step() so user manually emits Event to start next step in workflow #16277

Merged on Oct 1, 2024 (11 commits)
12 changes: 9 additions & 3 deletions docs/docs/understanding/workflows/observability.md
@@ -107,13 +107,19 @@ In a notebook environment it can be helpful to run a workflow step by step. You
w = ConcurrentFlow(timeout=10, verbose=True)
handler = w.run()

async for _ in handler.run_step():
# inspect context
while not handler.is_done():
# run_step returns the step's output event
ev = await handler.run_step()
# can make modifications to the results before dispatching the event
# val = ev.get("some_key")
# ev.set("some_key", new_val)
# can also inspect context
# val = await handler.ctx.get("key")
handler.ctx.send_event(ev)
continue

# get the result
result = await handler
result = handler.result()
```

You can call `run_step` multiple times to step through the workflow one step at a time.
8 changes: 8 additions & 0 deletions llama-index-core/llama_index/core/workflow/context.py
@@ -30,6 +30,14 @@ def __init__(self, workflow: "Workflow", stepwise: bool = False) -> None:
self._tasks: Set[asyncio.Task] = set()
self._broker_log: List[Event] = []
self._step_flags: Dict[str, asyncio.Event] = {}
self._step_event_holding: Optional[Event] = None
self._step_lock: asyncio.Lock = asyncio.Lock()
self._step_condition: asyncio.Condition = asyncio.Condition(
lock=self._step_lock
)
self._step_event_written: asyncio.Condition = asyncio.Condition(
lock=self._step_lock
)
self._accepted_events: List[Tuple[str, str]] = []
self._retval: Any = None
# Streaming machinery
97 changes: 64 additions & 33 deletions llama-index-core/llama_index/core/workflow/handler.py
@@ -31,44 +31,75 @@ async def stream_events(self) -> AsyncGenerator[Event, None]:
if type(ev) is StopEvent:
break

async def run_step(self) -> Optional[Any]:
async def run_step(self) -> Optional[Event]:
"""Runs the next workflow step and returns the output Event.

If return is None, then the workflow is considered done.

Examples:
```python
handler = workflow.run(stepwise=True)
while not handler.is_done():
ev = await handler.run_step()
handler.ctx.send_event(ev)

result = handler.result()
print(result)
```
"""
# since event is sent before calling this method, we need to unblock the event loop
await asyncio.sleep(0)

if self.ctx and not self.ctx.stepwise:
raise ValueError("Stepwise context is required to run stepwise.")

if self.ctx:
# Unblock all pending steps
for flag in self.ctx._step_flags.values():
flag.set()

# Yield back control to the event loop to give an unblocked step
# the chance to run (we won't actually sleep here).
await asyncio.sleep(0)

# See if we're done, or if a step raised any error
we_done = False
exception_raised = None
for t in self.ctx._tasks:
# Check if we're done
if not t.done():
continue

we_done = True
e = t.exception()
if type(e) != WorkflowDone:
exception_raised = e

retval = None
if we_done:
# Remove any reference to the tasks
try:
# Unblock all pending steps
for flag in self.ctx._step_flags.values():
flag.set()

# Yield back control to the event loop to give an unblocked step
# the chance to run (we won't actually sleep here).
await asyncio.sleep(0)

# check if we're done, or if a step raised error
we_done = False
exception_raised = None
retval = None
for t in self.ctx._tasks:
t.cancel()
await asyncio.sleep(0)
retval = self.ctx.get_result()

self.set_result(retval)

if exception_raised:
raise exception_raised
# Check if we're done
if not t.done():
continue
Collaborator

Hmm, is this check actually needed? If a StopEvent is emitted, we don't care whether other tasks are still running, right?

Contributor Author

oh yes, good point.

Contributor Author

So, I had to modify this slightly as it wasn't handling any errors raised in a step correctly.

Collaborator

The changes make sense!
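
To illustrate the error-handling concern raised in this thread, here is a minimal sketch in plain asyncio. It is not the code from this PR: `Done` and `inspect_tasks` are hypothetical names, `Done` only stands in for the role `WorkflowDone` plays in the diff, and skipping cancelled tasks is an assumption of the sketch.

```python
import asyncio
from typing import Iterable, Optional, Tuple


class Done(Exception):
    """Hypothetical sentinel, standing in for WorkflowDone in the diff."""


def inspect_tasks(
    tasks: Iterable[asyncio.Task],
) -> Tuple[bool, Optional[BaseException]]:
    """Return (any_task_finished, first_real_error_or_None)."""
    finished = False
    error: Optional[BaseException] = None
    for t in tasks:
        # Assumption of this sketch: ignore tasks that are still running or
        # were cancelled (Task.exception() raises on a cancelled task).
        if not t.done() or t.cancelled():
            continue
        finished = True
        exc = t.exception()
        # A task that ended with the sentinel means "workflow finished";
        # anything else is a genuine failure to surface to the caller.
        if exc is not None and not isinstance(exc, Done) and error is None:
            error = exc
    return finished, error


async def demo():
    async def still_running() -> None:
        await asyncio.sleep(3600)

    async def finishes() -> None:
        raise Done()

    async def fails() -> None:
        raise ValueError("boom")

    running = asyncio.create_task(still_running())
    stopped = asyncio.create_task(finishes())
    await asyncio.sleep(0)  # give both tasks one turn on the event loop
    first = inspect_tasks([running, stopped])

    failed = asyncio.create_task(fails())
    await asyncio.sleep(0)
    second = inspect_tasks([running, stopped, failed])

    # Clean up the long-running task so the loop can shut down quietly.
    running.cancel()
    await asyncio.gather(running, return_exceptions=True)
    return first, second


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

In this sketch a finished task alone is enough to report completion, while a non-sentinel exception is returned separately so the caller can decide when to re-raise it.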


we_done = True
e = t.exception()
if type(e) != WorkflowDone:
exception_raised = e

if we_done:
# Remove any reference to the tasks
for t in self.ctx._tasks:
t.cancel()
await asyncio.sleep(0)

if exception_raised:
raise exception_raised

self.set_result(self.ctx.get_result())
else: # continue with running next step
# notify unblocked task that we're ready to accept next event
async with self.ctx._step_condition:
self.ctx._step_condition.notify()

# Wait to be notified that the new_ev has been written
async with self.ctx._step_event_written:
await self.ctx._step_event_written.wait()
retval = self.ctx._step_event_holding
except Exception as e:
if not self.is_done(): # Avoid InvalidStateError edge case
self.set_exception(e)
raise
else:
raise ValueError("Context is not set!")

62 changes: 8 additions & 54 deletions llama-index-core/llama_index/core/workflow/workflow.py
@@ -161,6 +161,7 @@ def _start(self, stepwise: bool = False, ctx: Optional[Context] = None) -> Conte
ctx._queues = {}
ctx._step_flags = {}
ctx._retval = None
ctx._step_event_holding = None

for name, step_func in self._get_steps().items():
ctx._queues[name] = asyncio.Queue()
@@ -258,7 +259,13 @@ async def _task(
elif isinstance(new_ev, InputRequiredEvent):
ctx.write_event_to_stream(new_ev)
else:
ctx.send_event(new_ev)
if stepwise:
async with ctx._step_condition:
await ctx._step_condition.wait()
ctx._step_event_holding = new_ev
ctx._step_event_written.notify() # shares same lock
else:
ctx.send_event(new_ev)

for _ in range(step_config.num_workers):
ctx._tasks.add(
@@ -351,59 +358,6 @@ async def _run_workflow() -> None:
asyncio.create_task(_run_workflow())
return result

@dispatcher.span
async def run_step(self, **kwargs: Any) -> Optional[Any]:
"""Runs the workflow stepwise until completion."""
warnings.warn(
"run_step() is deprecated, use `workflow.run(stepwise=True)` instead.\n"
"handler = workflow.run(stepwise=True)\n"
"while not handler.is_done():\n"
" result = await handler.run_step()\n"
" print(result)\n"
)

# Check if we need to start a new session
if self._stepwise_context is None:
self._validate()
self._stepwise_context = self._start(stepwise=True)
# Run the first step
self._stepwise_context.send_event(StartEvent(**kwargs))

# Unblock all pending steps
for flag in self._stepwise_context._step_flags.values():
flag.set()

# Yield back control to the event loop to give an unblocked step
# the chance to run (we won't actually sleep here).
await asyncio.sleep(0)

# See if we're done, or if a step raised any error
we_done = False
exception_raised = None
for t in self._stepwise_context._tasks:
# Check if we're done
if not t.done():
continue

we_done = True
e = t.exception()
if type(e) != WorkflowDone:
exception_raised = e

retval = None
if we_done:
# Remove any reference to the tasks
for t in self._stepwise_context._tasks:
t.cancel()
await asyncio.sleep(0)
retval = self._stepwise_context._retval
self._stepwise_context = None

if exception_raised:
raise exception_raised

return retval

def is_done(self) -> bool:
"""Checks if the workflow is done."""
return self._stepwise_context is None
49 changes: 13 additions & 36 deletions llama-index-core/tests/workflow/test_workflow.py
@@ -37,54 +37,31 @@ async def test_workflow_run(workflow):
assert result == "Workflow completed"


@pytest.mark.asyncio()
async def test_deprecated_workflow_run_step(workflow):
workflow._verbose = True

# First step
result = await workflow.run_step()
assert result is None
assert not workflow.is_done()

# Second step
result = await workflow.run_step()
assert result is None
assert not workflow.is_done()

# Final step
result = await workflow.run_step()
assert not workflow.is_done()
assert result is None

# Cleanup step
result = await workflow.run_step()
assert result == "Workflow completed"
assert workflow.is_done()
Comment on lines -40 to -62
Contributor Author

I think we should probably delete this test and delete Workflow.run_step() now. To keep supporting it after the refactor, we could either:

  1. Maintain an older version of _start() that doesn't contain the new stepwise logic and point Workflow.run_step() to it

  2. Update Workflow.run_step() with a refactor similar to the one in WorkflowHandler.run_step().

Collaborator

Maybe we should just delete run_step() on the workflow object? (Or at least, delete the implementation and point users towards the updated syntax?)

Contributor Author

I think we should just delete it at this point, tbh. Pointing it to the updated syntax means changing how Workflow.run_step() behaves, and I'm not sure we should do that since it's already deprecated.

If you and @masci are fine with updating the logic of this and still marking it as deprecated then I'd be happy to adjust accordingly.

Contributor Author (@nerdai, Oct 1, 2024)

Alright, I deleted Workflow.run_step. I don't think it was worth updating it to point to the new code, as that would drastically change the functionality of this already deprecated method. Keeping it would not really benefit users who were still relying on it.
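
For readers comparing the deleted method with the handler-based flow, the hold-and-release handshake that this PR introduces (two `asyncio.Condition` objects sharing one lock) can be pictured with plain asyncio. This is only a sketch under assumptions: `StepGate`, `step`, `run_step` and `demo` are hypothetical names that mimic the roles of `_step_condition`, `_step_event_written` and `_step_event_holding`, and strings stand in for events.

```python
import asyncio
from typing import List, Optional


class StepGate:
    """Toy stand-in for the stepwise fields added to Context in this PR."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        # Driver -> step: "I am ready to accept the next event".
        self.ready = asyncio.Condition(lock=self.lock)
        # Step -> driver: "the event has been written to `holding`".
        self.written = asyncio.Condition(lock=self.lock)
        self.holding: Optional[str] = None


async def step(gate: StepGate, produced: str) -> None:
    # Instead of dispatching the produced event, park until the driver asks.
    async with gate.ready:
        await gate.ready.wait()
        gate.holding = produced
        # Allowed because both conditions share the lock held here.
        gate.written.notify()


async def run_step(gate: StepGate) -> Optional[str]:
    # Yield once so the step can reach its wait() before we notify it.
    await asyncio.sleep(0)
    async with gate.ready:
        gate.ready.notify()
    # Wait until the step reports that the held event has been written.
    async with gate.written:
        await gate.written.wait()
        return gate.holding


async def demo() -> List[Optional[str]]:
    gate = StepGate()
    seen: List[Optional[str]] = []
    for produced in ("first", "second"):
        task = asyncio.create_task(step(gate, produced))
        seen.append(await run_step(gate))
        await task
    return seen


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The driver receives each produced value before anything else runs on it, which is what lets a caller inspect or modify an event and then decide when to send it on.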



@pytest.mark.asyncio()
async def test_workflow_run_step(workflow):
handler = workflow.run(stepwise=True)

result = await handler.run_step()
assert result is None
event = await handler.run_step()
assert isinstance(event, OneTestEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

result = await handler.run_step()
assert result is None
event = await handler.run_step()
assert isinstance(event, LastEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

result = await handler.run_step()
assert result is None
event = await handler.run_step()
assert isinstance(event, StopEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

result = await handler.run_step()
assert result is None
assert not handler.is_done()
event = await handler.run_step()
assert event is None

result = await handler.run_step()
assert result == "Workflow completed"
result = await handler
assert handler.is_done()
assert result == "Workflow completed"


@pytest.mark.asyncio()