Pipeline state dump and load #352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft pull request: wants to merge 20 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@

- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
- Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call).
- Added pipeline execution control: state management (`dump_state()` and `load_state()` methods) and partial execution support in the `run()` method (via the `until` and `from_` parameters), enabling state persistence, resumption of long-running pipelines, debugging workflows, and incremental processing.

### Fixed

30 changes: 30 additions & 0 deletions docs/source/user_guide_pipeline.rst
@@ -204,3 +204,33 @@ This will send a `TASK_PROGRESS` event to the pipeline callback.
.. note::

In a future release, the `context_` parameter will be added to the `run` method.


*************************
Pipeline State Management
*************************

Pipelines support state management to enable saving and restoring execution state, which is useful for debugging, resuming long-running pipelines, or incremental processing workflows.

Saving and Loading State
========================

You can save the current state of a pipeline execution using the `dump_state()` method and restore it with `load_state()`. The pipeline also supports partial execution using the `until` and `from_` parameters:

- **`until`**: Stop execution after a specific component completes
- **`from_`**: Start execution from a specific component instead of from the beginning

.. code:: python

# Run pipeline and save state
result = await pipeline.run(..., until="a")
state = pipeline.dump_state(result.run_id)
# The user could save the state to a JSON file (see the sketch below)

# Resume the pipeline, possibly from a different run or process
loaded_run_id = pipeline.load_state(state)
new_result = await pipeline.run(..., from_="b", previous_run_id=loaded_run_id)
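
For example, the dictionary returned by `dump_state()` can be written to disk and read back later, assuming the stored component results are JSON-serializable (the file name below is illustrative):

.. code:: python

import json

# Run up to component "a" and dump the state to a JSON file
result = await pipeline.run(..., until="a")
state = pipeline.dump_state(result.run_id)
with open("pipeline_state.json", "w") as f:
    json.dump(state, f)

# Later, possibly in a different process: reload the state and resume from "b"
with open("pipeline_state.json") as f:
    state = json.load(f)
loaded_run_id = pipeline.load_state(state)
new_result = await pipeline.run(..., from_="b", previous_run_id=loaded_run_id)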

.. warning:: State Compatibility

When loading state, the current pipeline must have at least all the components that were present when the state was saved. Additional components are allowed, but missing components will cause a validation error.
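
For instance, loading a state into a pipeline that no longer defines one of the saved components raises a `ValueError` (the pipeline name below is hypothetical):

.. code:: python

# Hypothetical: `smaller_pipeline` is missing component "b",
# which had results stored when the state was dumped
try:
    smaller_pipeline.load_state(state)
except ValueError as err:
    print(err)  # lists the missing components and the current pipeline components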
48 changes: 43 additions & 5 deletions src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -19,7 +19,7 @@
import uuid
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.exceptions import (
@@ -46,16 +46,31 @@ class Orchestrator:
- finding the next tasks to execute
- building the inputs for each task
- calling the run method on each task
- optionally stopping after a specified component
- optionally starting from a specified component

Once a TaskNode is done, it calls the `on_task_complete` callback
that will save the results, find the next tasks to be executed
(checking that all dependencies are met), and run them.

Partial execution is supported through:
- stop_after: Stop execution after this component completes
- start_from: Start execution from this component instead of roots
"""

def __init__(self, pipeline: Pipeline):
def __init__(
self,
pipeline: Pipeline,
stop_after: Optional[str] = None,
start_from: Optional[str] = None,
previous_run_id: Optional[str] = None,
):
self.pipeline = pipeline
self.event_notifier = EventNotifier(pipeline.callbacks)
self.run_id = str(uuid.uuid4())
self.previous_run_id = previous_run_id # useful for pipeline resumption
self.stop_after = stop_after
self.start_from = start_from

async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
"""Get inputs and run a specific task. Once the task is done,
@@ -129,7 +144,10 @@ async def on_task_complete(
await self.add_result_for_component(
task.name, res_to_save, is_final=task.is_leaf()
)
# then get the next tasks to be executed
# stop if this is the stop_after node
if self.stop_after and task.name == self.stop_after:
return
# otherwise, get the next tasks to be executed
# and run them in parallel
await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])

@@ -252,10 +270,24 @@ async def add_result_for_component(
)

async def get_results_for_component(self, name: str) -> Any:
# when resuming, check previous run_id, otherwise check current run_id
if self.previous_run_id:
return await self.pipeline.store.get_result_for_component(
self.previous_run_id, name
)
return await self.pipeline.store.get_result_for_component(self.run_id, name)

async def get_status_for_component(self, name: str) -> RunStatus:
status = await self.pipeline.store.get_status_for_component(self.run_id, name)
# when resuming, check previous run_id, otherwise check current run_id
if self.previous_run_id:
status = await self.pipeline.store.get_status_for_component(
self.previous_run_id, name
)
else:
status = await self.pipeline.store.get_status_for_component(
self.run_id, name
)

if status is None:
return RunStatus.UNKNOWN
return RunStatus(status)
@@ -266,7 +298,13 @@ async def run(self, data: dict[str, Any]) -> None:
will handle the task dependencies.
"""
await self.event_notifier.notify_pipeline_started(self.run_id, data)
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
# start from a specific node if requested, otherwise from roots
if self.start_from:
start_nodes = [self.pipeline.get_node_by_name(self.start_from)]
else:
start_nodes = self.pipeline.roots()

tasks = [self.run_task(root, data) for root in start_nodes]
await asyncio.gather(*tasks)
await self.event_notifier.notify_pipeline_finished(
self.run_id, await self.pipeline.get_final_results(self.run_id)
165 changes: 160 additions & 5 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -18,7 +18,7 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional, AsyncGenerator
from typing import Any, Optional, AsyncGenerator, Dict
import asyncio

from neo4j_graphrag.utils.logging import prettify
@@ -390,22 +390,34 @@ def validate_parameter_mapping(self) -> None:
self.validate_parameter_mapping_for_task(task)
self.is_validated = True

def validate_input_data(self, data: dict[str, Any]) -> bool:
def validate_input_data(
self, data: dict[str, Any], from_: Optional[str] = None
) -> bool:
"""Performs parameter and data validation before running the pipeline:
- Check parameters defined in the connect method
- Make sure the missing parameters are present in the input `data` dict.

Args:
data (dict[str, Any]): input data to use for validation
(usually from Pipeline.run)
from_ (Optional[str]): If provided, only validate components that will actually execute
starting from this component

Raises:
PipelineDefinitionError if any parameter mapping is invalid or if a
parameter is missing.
"""
if not self.is_validated:
self.validate_parameter_mapping()

# determine which components need validation
components_to_validate = self._get_components_to_validate(from_)

for task in self._nodes.values():
# skip validation for components that won't execute
if task.name not in components_to_validate:
continue

if task.name not in self.param_mapping:
self.validate_parameter_mapping_for_task(task)
missing_params = self.missing_inputs[task.name]
Expand All @@ -417,6 +429,37 @@ def validate_input_data(self, data: dict[str, Any]) -> bool:
)
return True

def _get_components_to_validate(self, from_: Optional[str] = None) -> set[str]:
"""Determine which components need validation based on execution context.

Args:
from_ (Optional[str]): Starting component for execution

Returns:
set[str]: Set of component names that need validation
"""
if from_ is None:
# no from_ specified, validate all components
return set(self._nodes.keys())

# when from_ is specified, only validate components that will actually execute
# this includes the from_ component and all its downstream dependencies
components_to_validate = set()

def add_downstream_components(component_name: str) -> None:
"""Recursively add a component and all its downstream dependencies"""
if component_name in components_to_validate:
return # Already processed
components_to_validate.add(component_name)

# add all components that depend on this one
for edge in self.next_edges(component_name):
add_downstream_components(edge.end)

# start from the specified component and add all downstream
add_downstream_components(from_)
return components_to_validate

def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
"""Make sure that all the parameter mapping for a given task are valid.
Does not consider user input yet.
@@ -563,19 +606,131 @@ async def event_stream(event: Event) -> None:
if event_queue_getter_task and not event_queue_getter_task.done():
event_queue_getter_task.cancel()

async def run(self, data: dict[str, Any]) -> PipelineResult:
async def run(
self,
data: dict[str, Any],
from_: Optional[str] = None,
until: Optional[str] = None,
previous_run_id: Optional[str] = None,
) -> PipelineResult:
"""Run the pipeline, optionally from a specific component or until a specific component.

Args:
data (dict[str, Any]): The input data for the pipeline
from_ (str | None, optional): If provided, start execution from this component. Defaults to None.
until (str | None, optional): If provided, stop execution after this component. Defaults to None.
previous_run_id (str | None, optional): If provided, resume from this previous run_id. Defaults to None.

Returns:
PipelineResult: The result of the pipeline execution
"""
logger.debug("PIPELINE START")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self)
self.validate_input_data(data, from_)

# create orchestrator with appropriate start_from and stop_after params
orchestrator = Orchestrator(
self, stop_after=until, start_from=from_, previous_run_id=previous_run_id
)

logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)

end_time = default_timer()
logger.debug(
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
)

return PipelineResult(
run_id=orchestrator.run_id,
result=await self.get_final_results(orchestrator.run_id),
)

def dump_state(self, run_id: str) -> Dict[str, Any]:
"""Dump the current state of the pipeline to a serializable dictionary.

Args:
run_id (str): The run_id to dump state for

Returns:
Dict[str, Any]: A serializable dictionary containing the pipeline state

Raises:
ValueError: If run_id is None or empty
"""
if not run_id:
raise ValueError("run_id cannot be None or empty")

pipeline_state: Dict[str, Any] = {
"run_id": run_id,
"store": self.store.dump(run_id),
}
return pipeline_state

def load_state(self, state: Dict[str, Any]) -> str:
"""Load pipeline state from a serialized dictionary.

Args:
state (dict[str, Any]): Previously serialized pipeline state

Returns:
str: The run_id from the loaded state

Raises:
ValueError: If the state is invalid or incompatible with current pipeline
"""
if "run_id" not in state:
raise ValueError("Invalid state: missing run_id")

run_id: str = state["run_id"]

# validate pipeline compatibility
self._validate_state_compatibility(state)

# load store data
if "store" in state:
self.store.load(state["store"])

return run_id

def _validate_state_compatibility(self, state: Dict[str, Any]) -> None:
"""Validate that the loaded state is compatible with the current pipeline.

This checks that the components defined in the pipeline match those
that were present when the state was saved.

Args:
state (dict[str, Any]): The state to validate

Raises:
ValueError: If the state is incompatible with the current pipeline
"""
if "store" not in state:
return # no store data to validate

store_data = state["store"]
if not store_data:
return # empty store, nothing to validate

# extract component names from the store keys
# keys are in format: "run_id:component_name" or "run_id:component_name:suffix"
stored_components = set()
for key in store_data.keys():
parts = key.split(":")
if len(parts) >= 2:
component_name = parts[1]
stored_components.add(component_name)

# get current pipeline component names
current_components = set(self._nodes.keys())

# check if stored components are a subset of current components
# this allows for the pipeline to have additional components, but not missing ones
missing_components = stored_components - current_components
if missing_components:
raise ValueError(
f"State is incompatible with current pipeline. "
f"Missing components: {sorted(missing_components)}. "
f"Current pipeline components: {sorted(current_components)}"
)