[Feature] Checkpointing with Workflows #17006

Merged: 69 commits, Nov 26, 2024

Commits
347be55
initial design of WorkflowProfile and auxiliary classes
nerdai Nov 19, 2024
9493f83
add more scaffolding
nerdai Nov 19, 2024
2df3bdc
rework into Workflow
nerdai Nov 20, 2024
5645b9e
move checkpointing outside of .send_event()
nerdai Nov 20, 2024
1af980b
unit test for create_checkpoint method
nerdai Nov 20, 2024
26f0bfd
rename incoming_event to input_event
nerdai Nov 20, 2024
166715a
add unit test for primitive filtering
nerdai Nov 20, 2024
a638027
add assertion for raising when no filter specified
nerdai Nov 20, 2024
da06c50
fix mypy issue
nerdai Nov 20, 2024
327bee6
more mypy issues
nerdai Nov 20, 2024
4d8db5a
add ability to run_from a given checkpoint
nerdai Nov 21, 2024
ae8673d
add sample nb
nerdai Nov 21, 2024
cc7e7b5
mypy nit
nerdai Nov 21, 2024
0c09793
cr
nerdai Nov 21, 2024
93f7b11
cr _broker_log to _checkpoints
nerdai Nov 21, 2024
6fb2657
update broker_log entry in serialization to checkpoints
nerdai Nov 21, 2024
569fdb8
add ability to turn off checkpoints
nerdai Nov 21, 2024
bb63157
add disabling checkpointing to notebook
nerdai Nov 21, 2024
7b0790d
add unit test for Workflow.run_from
nerdai Nov 21, 2024
4c2518b
fix broken tests part 1
nerdai Nov 21, 2024
eb4e3dc
add context_serializer param at init and private attr to Workflow, wh…
nerdai Nov 21, 2024
7d485d8
rename to checkpoint_serializer
nerdai Nov 21, 2024
ef73fbf
change store_checkpoints default to False
nerdai Nov 21, 2024
dba9a65
update nb
nerdai Nov 21, 2024
cfdefa5
move checkpointing to Workflow
nerdai Nov 22, 2024
dc7a5ea
move unit tests to test_workflow
nerdai Nov 22, 2024
e6b7757
revert Context back
nerdai Nov 22, 2024
7c75c81
lint, fix tests
nerdai Nov 22, 2024
c8a5048
improve error handling
nerdai Nov 22, 2024
7c932f6
cr
nerdai Nov 22, 2024
266b41d
add serializers to module top-level import
nerdai Nov 22, 2024
8d81fee
update nb and docstring for run_from
nerdai Nov 22, 2024
260c991
intro workflow checkpointer
nerdai Nov 22, 2024
42ec8ab
revert events back
nerdai Nov 22, 2024
79695bd
add WorkflowCheckpointer
nerdai Nov 22, 2024
000c0d2
fix test
nerdai Nov 23, 2024
48c9cd5
assert not done()
nerdai Nov 23, 2024
ce15053
mypy
nerdai Nov 23, 2024
3afb7fe
mypy
nerdai Nov 23, 2024
e89dbd7
rename to checkpoint_callback
nerdai Nov 23, 2024
5537e30
rename factory method
nerdai Nov 23, 2024
4ae44b7
proper placement of checkpoint_callback
nerdai Nov 23, 2024
0dfd23c
wip
nerdai Nov 23, 2024
5e5c05f
revert errors.py
nerdai Nov 23, 2024
87954f2
add unit test for checkpointing with stepwise
nerdai Nov 23, 2024
52cdd44
start implementing ability to control checkpointed steps
nerdai Nov 23, 2024
8be9c2d
start implementing ability to control checkpointed steps
nerdai Nov 23, 2024
615496d
impl enable disable for checkpoints
nerdai Nov 23, 2024
598c02c
add unit test for disable/enable
nerdai Nov 23, 2024
56fe232
refactor enable/disable impl
nerdai Nov 23, 2024
abdf6d5
pin pydantic for nebius
nerdai Nov 23, 2024
cdae078
parallel execution checkpointing
nerdai Nov 25, 2024
ec1dc82
wip
nerdai Nov 25, 2024
fbb3335
wip
nerdai Nov 25, 2024
37720a3
renames
nerdai Nov 25, 2024
7a1fae4
update parallel execution nb
nerdai Nov 25, 2024
3677390
test to check works across async loops
nerdai Nov 25, 2024
be86b68
simplify checkpointing nb
nerdai Nov 25, 2024
ffefa05
show how to enable/disable checkpoints
nerdai Nov 25, 2024
d6fb4b9
rm serializer from workflow init
nerdai Nov 25, 2024
31da494
cr
nerdai Nov 26, 2024
7594455
rename checkpoint_config to enabled_checkpoints
nerdai Nov 26, 2024
19c1165
doc strings
nerdai Nov 26, 2024
2bba5b8
cr
nerdai Nov 26, 2024
0b4577e
add run_id to Handler, move its creation to within _startc
nerdai Nov 26, 2024
baf0ae8
add section for checkpointing in module guides
nerdai Nov 26, 2024
0e5a94b
add Checkpointing to observability section
nerdai Nov 26, 2024
328b5a1
class string for WorkflowCheckpointer
nerdai Nov 26, 2024
6d4103e
wip
nerdai Nov 26, 2024
527 changes: 527 additions & 0 deletions docs/docs/examples/workflow/checkpointing_workflows.ipynb

Large diffs are not rendered by default.
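
Since the checkpointing_workflows.ipynb diff above is not rendered, here is a minimal sketch of the API this PR introduces, pieced together from the parallel_execution.ipynb changes shown below. It assumes a notebook-style async context and uses `parallel_workflow` as a stand-in for any `Workflow` instance; anything not visible in the diff below (defaults, extra parameters) is an assumption rather than the verified interface.

```python
# Minimal sketch of the WorkflowCheckpointer API added in this PR, based
# only on the calls visible in the parallel_execution.ipynb diff below.
from llama_index.core.workflow.checkpointer import WorkflowCheckpointer

# Wrap an existing Workflow; runs go through the wrapper, which stores a
# checkpoint after each completed step.
wflow_ckptr = WorkflowCheckpointer(workflow=parallel_workflow)

# Run as usual (top-level await assumes a notebook/async context).
handler = wflow_ckptr.run()
result = await handler

# Checkpoints are stored per run in a dict keyed by run_id.
for run_id, ckpts in wflow_ckptr.checkpoints.items():
    print(f"Run: {run_id} has {[c.last_completed_step for c in ckpts]}")

# Any stored checkpoint can seed a new run, which gets its own entry
# in `checkpoints`.
ckpt = wflow_ckptr.checkpoints[run_id][0]
handler = wflow_ckptr.run_from(checkpoint=ckpt)
result = await handler
```

The commit history also references per-step enabling/disabling of checkpoints (`enabled_checkpoints`) and a pluggable checkpoint serializer; those are presumably covered in the un-rendered checkpointing_workflows.ipynb.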

314 changes: 301 additions & 13 deletions docs/docs/examples/workflow/parallel_execution.ipynb
@@ -30,7 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install llama-index-core llama-index-utils-workflow"
"# %pip install llama-index-core llama-index-utils-workflow -q"
]
},
{
@@ -110,13 +110,14 @@
" data_list = [\"A\", \"B\", \"C\"]\n",
" await ctx.set(\"num_to_collect\", len(data_list))\n",
" for item in data_list:\n",
" self.send_event(ProcessEvent(data=item))\n",
" ctx.send_event(ProcessEvent(data=item))\n",
" return None\n",
"\n",
" @step\n",
" @step(num_workers=1)\n",
" async def process_data(self, ev: ProcessEvent) -> ResultEvent:\n",
" # Simulate some time-consuming processing\n",
" await asyncio.sleep(random.randint(1, 2))\n",
" processing_time = 2 + random.random()\n",
" await asyncio.sleep(processing_time)\n",
" result = f\"Processed: {ev.data}\"\n",
" print(f\"Completed processing: {ev.data}\")\n",
" return ResultEvent(result=result)\n",
@@ -140,13 +141,14 @@
" data_list = [\"A\", \"B\", \"C\"]\n",
" await ctx.set(\"num_to_collect\", len(data_list))\n",
" for item in data_list:\n",
" self.send_event(ProcessEvent(data=item))\n",
" ctx.send_event(ProcessEvent(data=item))\n",
" return None\n",
"\n",
" @step(num_workers=3)\n",
" async def process_data(self, ev: ProcessEvent) -> ResultEvent:\n",
" # Simulate some time-consuming processing\n",
" await asyncio.sleep(random.randint(1, 2))\n",
" processing_time = 2 + random.random()\n",
" await asyncio.sleep(processing_time)\n",
" result = f\"Processed: {ev.data}\"\n",
" print(f\"Completed processing: {ev.data}\")\n",
" return ResultEvent(result=result)\n",
@@ -194,14 +196,14 @@
"Completed processing: B\n",
"Completed processing: C\n",
"Workflow result: Processed: A, Processed: B, Processed: C\n",
"Time taken: 4.008663654327393 seconds\n",
"Time taken: 7.439495086669922 seconds\n",
"------------------------------\n",
"Start a parallel workflow with setting num_workers in the step of process_data\n",
"Completed processing: C\n",
"Completed processing: A\n",
"Completed processing: B\n",
"Workflow result: Processed: C, Processed: A, Processed: B\n",
"Time taken: 2.0040180683135986 seconds\n"
"Time taken: 2.5881590843200684 seconds\n"
]
}
],
@@ -238,7 +240,7 @@
"source": [
"# Note\n",
"\n",
"- Without setting num_workers, it might take 3 to 6 seconds. By setting num_workers, the processing occurs in parallel, handling 3 items at a time, and only takes 2 seconds.\n",
"- Without setting `num_workers=1`, it might take a total of 6-9 seconds. By setting `num_workers=3`, the processing occurs in parallel, handling 3 items at a time, and only takes 2-3 seconds total.\n",
"- In ParallelWorkflow, the order of the completed results may differ from the input order, depending on the completion time of the tasks.\n"
]
},
@@ -248,6 +250,283 @@
"source": [
"This example demonstrates the execution speed with and without using num_workers, and how to implement parallel processing in a workflow. By setting num_workers, we can control the degree of parallelism, which is very useful for scenarios that need to balance performance and resource usage."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Checkpointing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Checkpointing a parallel execution Workflow like the one defined above is also possible. To do so, we must wrap the `Workflow` with a `WorkflowCheckpointer` object and perfrom the runs with these instances. During the execution of the workflow, checkpoints are stored in this wrapper object and can be used for inspection and as starting points for run executions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core.workflow.checkpointer import WorkflowCheckpointer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completed processing: C\n",
"Completed processing: A\n",
"Completed processing: B\n"
]
},
{
"data": {
"text/plain": [
"'Processed: C, Processed: A, Processed: B'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"wflow_ckptr = WorkflowCheckpointer(workflow=parallel_workflow)\n",
"handler = wflow_ckptr.run()\n",
"await handler"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Checkpoints for the above run are stored in the `WorkflowCheckpointer.checkpoints` Dict attribute."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run: 90812bec-b571-4513-8ad5-aa957ad7d4fb has ['process_data', 'process_data', 'process_data', 'combine_results']\n"
]
}
],
"source": [
"for run_id, ckpts in wflow_ckptr.checkpoints.items():\n",
" print(f\"Run: {run_id} has {[c.last_completed_step for c in ckpts]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can run from any of the checkpoints stored, using `WorkflowCheckpointer.run_from(checkpoint=...)` method. Let's take the first checkpoint that was stored after the first completion of \"process_data\" and run from it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completed processing: B\n",
"Completed processing: A\n"
]
},
{
"data": {
"text/plain": [
"'Processed: C, Processed: B, Processed: A'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ckpt = wflow_ckptr.checkpoints[run_id][0]\n",
"ckpt\n",
"handler = wflow_ckptr.run_from(ckpt)\n",
"await handler"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Invoking a `run_from` or `run` will create a new run entry in the `checkpoints` attribute. In the latest run from the specified checkpoint, we can see that only two more \"process_data\" steps and the final \"combine_results\" steps were left to be completed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run: 90812bec-b571-4513-8ad5-aa957ad7d4fb has ['process_data', 'process_data', 'process_data', 'combine_results']\n",
"Run: 4e1d24cd-c672-4ed1-bb5b-b9f1a252abed has ['process_data', 'process_data', 'combine_results']\n"
]
}
],
"source": [
"for run_id, ckpts in wflow_ckptr.checkpoints.items():\n",
" print(f\"Run: {run_id} has {[c.last_completed_step for c in ckpts]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, if we use the checkpoint associated with the second completion of \"process_data\" of the same initial run as the starting point, then we should see a new entry that only has two steps: \"process_data\" and \"combine_results\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'90812bec-b571-4513-8ad5-aa957ad7d4fb'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# get the run_id of the first initial run\n",
"first_run_id = next(iter(wflow_ckptr.checkpoints.keys()))\n",
"first_run_id"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completed processing: B\n"
]
},
{
"data": {
"text/plain": [
"'Processed: C, Processed: A, Processed: B'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ckpt = wflow_ckptr.checkpoints[first_run_id][\n",
" 1\n",
"] # checkpoint after the second \"process_data\" step\n",
"ckpt\n",
"handler = wflow_ckptr.run_from(ckpt)\n",
"await handler"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run: 90812bec-b571-4513-8ad5-aa957ad7d4fb has ['process_data', 'process_data', 'process_data', 'combine_results']\n",
"Run: 4e1d24cd-c672-4ed1-bb5b-b9f1a252abed has ['process_data', 'process_data', 'combine_results']\n",
"Run: e4f94fcd-9b78-4e28-8981-e0232d068f6e has ['process_data', 'combine_results']\n"
]
}
],
"source": [
"for run_id, ckpts in wflow_ckptr.checkpoints.items():\n",
" print(f\"Run: {run_id} has {[c.last_completed_step for c in ckpts]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly, if we start with the checkpoint for the third completion of \"process_data\" of the initial run, then we should only see the final \"combine_results\" step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Processed: C, Processed: A, Processed: B'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ckpt = wflow_ckptr.checkpoints[first_run_id][\n",
" 2\n",
"] # checkpoint after the third \"process_data\" step\n",
"ckpt\n",
"handler = wflow_ckptr.run_from(ckpt)\n",
"await handler"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run: 90812bec-b571-4513-8ad5-aa957ad7d4fb has ['process_data', 'process_data', 'process_data', 'combine_results']\n",
"Run: 4e1d24cd-c672-4ed1-bb5b-b9f1a252abed has ['process_data', 'process_data', 'combine_results']\n",
"Run: e4f94fcd-9b78-4e28-8981-e0232d068f6e has ['process_data', 'combine_results']\n",
"Run: c498a1a0-cf4c-4d80-a1e2-a175bb90b66d has ['combine_results']\n"
]
}
],
"source": [
"for run_id, ckpts in wflow_ckptr.checkpoints.items():\n",
" print(f\"Run: {run_id} has {[c.last_completed_step for c in ckpts]}\")"
]
}
],
"metadata": {
@@ -256,13 +535,22 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
"display_name": "llama-index-core",
"language": "python",
"name": "llama-index-core"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat_minor": 4
}