Skip to content

Interruption issue #11

Closed
@hung-phan

Description

@hung-phan

Hi there,

I am trying to use the library but have run into an issue. My interrupt workflow works well with the in-memory checkpointer (`MemorySaver`), but when I switch to `AsyncRedisSaver`, LangGraph no longer returns the interrupt message.

My code looks something like this:

def get_checkpointer(checkpointer_type: CheckpointerType) -> Callable[[], Coroutine[None, None, Checkpointer]]:
    """Build a zero-argument factory for the checkpointer of ``checkpointer_type``.

    Calling the returned factory awaits nothing itself. It returns the
    coroutine produced by ``_get_checkpointer(checkpointer_type)``, which
    the caller is responsible for awaiting.
    """

    def factory():
        # checkpointer_type is captured from this call via the closure.
        return _get_checkpointer(checkpointer_type)

    return factory


def _get_redis_client():
    """Construct and return a new Redis client.

    A fresh client object is created on every call. The real connection
    arguments are elided in this excerpt.
    """
    client = Redis(...)
    return client


@cached(cache=Cache.MEMORY)
async def _get_checkpointer(checkpointer_type: CheckpointerType) -> Checkpointer:
    match checkpointer_type:
        case CheckpointerType.SHALLOW:
            async with AsyncShallowRedisSaver(redis_client=_get_redis_client()) as checkpointer:
                await checkpointer.asetup()
                return checkpointer

        case CheckpointerType.DEEP:
            async with AsyncRedisSaver(redis_client=_get_redis_client()) as checkpointer:
                await checkpointer.asetup()
                return checkpointer

        case CheckpointerType.MEMORY:
            if not is_local():
                raise ValueError("Memory checkpointer is not supported in production environment due to memory leak.")

            return MemorySaver()

@dataclass
class WorkflowBuilder:
    """Assemble a LangGraph ``StateGraph`` and compile it into a ``Workflow``.

    NOTE(review): this is an excerpt from an issue report. The private
    helpers (``__validate_nodes``, ``__add_nodes``, ``__add_*_edges``) and
    the ``except``/``finally`` clause of the ``try`` in ``build`` are not
    shown here.
    """

    # Name passed through to the resulting Workflow.
    name: str
    # State schema handed to StateGraph (passed through even when None).
    workflow_state_class: type[WorkflowState] | None = None
    # Async factory awaited in build() to obtain the checkpointer; when it is
    # not callable, build() falls back to an in-memory MemorySaver.
    checkpointer_func: Callable[..., Coroutine[None, None, Checkpointer]] | None = None

    async def build(self) -> "Workflow":
        """Validate the configuration, build the graph and compile it.

        Returns:
            A ``Workflow`` wrapping the compiled graph.
        """
        self.__validate_nodes()

        builder = StateGraph(self.workflow_state_class)

        self.__add_nodes(builder)

        try:
            self.__add_start_edges(builder)
            self.__add_intra_edges(builder)
            self.__add_end_edges(builder)

            return Workflow(
                name=self.name,
                compiled_workflow=builder.compile(
                    # NOTE(review): the awaited checkpointer must still be open/usable while the
                    # graph runs; one already torn down by its context manager would break persistence.
                    checkpointer=(await self.checkpointer_func() if callable(self.checkpointer_func) else MemorySaver())
                ),
            )

@dataclass
class Workflow:
    """Wrapper around a compiled LangGraph graph.

    NOTE(review): excerpt from an issue report. Other fields are not shown
    (the builder passes ``name=``), and the rest of
    ``stream_using_workflow_state`` (loop exit, ``except`` clause, yielding
    events) is elided.
    """

    # The compiled LangGraph graph whose events are streamed.
    compiled_workflow: CompiledStateGraph

    async def stream_using_workflow_state(
        self,
        workflow_state: MutableMapping[str, Any],
        *,
        config: RunnableConfig | None = None
    ) -> AsyncGenerator[WorkflowStreamEvent, Any]:
        """Stream events from the compiled graph for ``workflow_state``.

        Args:
            workflow_state: State update sent to the graph as ``Command(update=...)``.
            config: Optional run config. When it has no ``configurable.thread_id``,
                one is generated from a new ULID string.
        """
        # NOTE(review): a non-empty caller-supplied config is mutated in place by the
        # setdefault() calls below; an empty or None config is replaced by a fresh one.
        config = config or RunnableConfig()
        config.setdefault("configurable", {}).setdefault("thread_id", str(ULID()))

        # Not used in the visible excerpt.
        thread_id: str = config["configurable"]["thread_id"]
        input_data: Command = Command(update=workflow_state)

        try:
            # NOTE(review): no exit condition is visible here; presumably the elided code breaks
            # out of the loop or replaces input_data (e.g. to resume after an interrupt).
            while True:
                # In "updates" mode LangGraph is expected to surface interrupts as an
                # "__interrupt__" update — TODO confirm for the installed version.
                async for event in self.compiled_workflow.astream(
                    input=input_data,
                    config=config,
                    stream_mode=["custom", "updates"],
                    subgraphs=include_nested_workflows,  # NOTE(review): not defined in this excerpt
                ):
                  print(event)

Then I build a simple graph, but somehow nothing is returned when the graph hits an interrupt.

Metadata

Assignees

Labels

No labels

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions