Skip to content

Commit 6e320fa

Browse files
committed
Add tests for all LangGraph samples and fix serialization bugs
Tests:
- Added tests for all LangGraph samples (hello_world, approval_workflow, react_agent, supervisor, agentic_rag, deep_research, plan_and_execute, reflection).
- Added conftest.py with shared fixtures (clear_registry, requires_openai).
- Tests requiring the OpenAI API are skipped when OPENAI_API_KEY is not set.

Bug fixes:
- Fixed serialization bugs in the reflection and plan_and_execute samples: after Temporal serialization, Pydantic models become plain dicts.
- Added helper functions that handle both object and dict access for:
  - Critique objects in the reflection sample.
  - Plan, PlanStep, and StepResult objects in the plan_and_execute sample.
1 parent 9b17b34 commit 6e320fa

File tree

11 files changed

+527
-26
lines changed

11 files changed

+527
-26
lines changed

langgraph_samples/plan_and_execute/graph.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,30 @@ class PlanExecuteState(TypedDict):
8484
needs_replan: bool
8585

8686

87+
# Helper functions to handle dict/object access after Temporal serialization.
# After a Temporal round-trip, Pydantic models may arrive as plain dicts, so
# each helper accepts both forms.
def _get_plan_steps(plan: Plan | dict[str, Any]) -> list[Any]:
    """Return the plan's steps as a fresh list.

    Handles both a ``Plan`` model and its serialized dict form (produced by
    Temporal serialization).

    Args:
        plan: A ``Plan`` model or its serialized dict representation.

    Returns:
        A new list of the plan's steps; an empty list when the dict form has
        no "steps" key.
    """
    if isinstance(plan, dict):
        # Copy here as well, for consistency with the object branch below —
        # returning the stored list directly would alias workflow state and
        # let callers mutate it in place.
        return list(plan.get("steps", []))
    return list(plan.steps)
93+
94+
95+
def _get_step_attr(
    step: PlanStep | dict[str, Any], attr: str, default: Any = ""
) -> Any:
    """Read *attr* from a ``PlanStep``, tolerating its serialized dict form.

    After Temporal serialization, Pydantic models may arrive as plain dicts,
    so fall back to key lookup in that case. Returns *default* when the
    attribute or key is absent.
    """
    if not isinstance(step, dict):
        return getattr(step, attr, default)
    return step.get(attr, default)
102+
103+
104+
def _get_result_success(result: StepResult | dict[str, Any]) -> bool:
    """Report whether a step result succeeded, in object or dict form.

    ``StepResult`` instances may have been round-tripped through Temporal
    serialization and arrive as plain dicts; the "success" key defaults to
    False when missing.
    """
    if not isinstance(result, dict):
        return result.success
    return result.get("success", False)
109+
110+
87111
# Define tools for the executor agent
88112
@tool
89113
def calculate(expression: str) -> str:
@@ -220,21 +244,25 @@ def execute_step(state: PlanExecuteState) -> dict[str, Any]:
220244
plan = state.get("plan")
221245
current_step = state.get("current_step", 0)
222246

223-
if not plan or current_step >= len(plan.steps):
247+
steps = _get_plan_steps(plan) if plan else []
248+
if not plan or current_step >= len(steps):
224249
return {"needs_replan": False}
225250

226-
step = plan.steps[current_step]
251+
step = steps[current_step]
252+
step_number = _get_step_attr(step, "step_number", 0)
253+
step_desc = _get_step_attr(step, "description", "")
254+
tool_hint = _get_step_attr(step, "tool_hint", "")
227255

228256
# Run the executor agent for this step
229257
result = executor_agent.invoke(
230258
{
231259
"messages": [
232260
SystemMessage(
233-
content=f"You are executing step {step.step_number} of a plan. "
234-
f"Complete this step: {step.description}\n"
235-
f"Suggested tool: {step.tool_hint}"
261+
content=f"You are executing step {step_number} of a plan. "
262+
f"Complete this step: {step_desc}\n"
263+
f"Suggested tool: {tool_hint}"
236264
),
237-
HumanMessage(content=step.description),
265+
HumanMessage(content=step_desc),
238266
]
239267
}
240268
)
@@ -248,8 +276,8 @@ def execute_step(state: PlanExecuteState) -> dict[str, Any]:
248276
)
249277

250278
step_result = StepResult(
251-
step_number=step.step_number,
252-
description=step.description,
279+
step_number=step_number,
280+
description=step_desc,
253281
result=result_content,
254282
success=True,
255283
)
@@ -260,7 +288,7 @@ def execute_step(state: PlanExecuteState) -> dict[str, Any]:
260288
"messages": [
261289
{
262290
"role": "assistant",
263-
"content": f"Step {step.step_number} completed: {result_content[:200]}...",
291+
"content": f"Step {step_number} completed: {result_content[:200]}...",
264292
}
265293
],
266294
}
@@ -278,17 +306,18 @@ def evaluate_progress(state: PlanExecuteState) -> dict[str, Any]:
278306
return {"needs_replan": True}
279307

280308
# Check if all steps completed
281-
all_complete = current_step >= len(plan.steps)
309+
steps = _get_plan_steps(plan)
310+
all_complete = current_step >= len(steps)
282311

283312
# Check if any step failed
284-
failed_steps = [r for r in step_results if not r.success]
313+
failed_steps = [r for r in step_results if not _get_result_success(r)]
285314

286315
return {
287316
"needs_replan": len(failed_steps) > 0,
288317
"messages": [
289318
{
290319
"role": "assistant",
291-
"content": f"Progress: {current_step}/{len(plan.steps)} steps complete. "
320+
"content": f"Progress: {current_step}/{len(steps)} steps complete. "
292321
f"Failures: {len(failed_steps)}",
293322
}
294323
]
@@ -312,7 +341,7 @@ def should_continue(
312341
plan = state.get("plan")
313342
current_step = state.get("current_step", 0)
314343

315-
if plan and current_step < len(plan.steps):
344+
if plan and current_step < len(_get_plan_steps(plan)):
316345
return "execute"
317346

318347
return "respond"

langgraph_samples/reflection/graph.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,13 @@ def should_revise(state: ReflectionState) -> Literal["revise", "finalize"]:
196196
latest_critique = critiques[-1]
197197

198198
# Finalize if satisfactory or max iterations reached
199-
if latest_critique.is_satisfactory or iteration >= max_iterations:
199+
# Note: After Temporal serialization, Critique objects become dicts
200+
is_satisfactory = (
201+
latest_critique.get("is_satisfactory", False)
202+
if isinstance(latest_critique, dict)
203+
else latest_critique.is_satisfactory
204+
)
205+
if is_satisfactory or iteration >= max_iterations:
200206
return "finalize"
201207

202208
return "revise"
@@ -217,10 +223,20 @@ def revise(state: ReflectionState) -> dict[str, Any]:
217223
latest_critique = critiques[-1]
218224

219225
# Format the feedback
226+
# Note: After Temporal serialization, Critique objects become dicts
227+
if isinstance(latest_critique, dict):
228+
strengths = latest_critique.get("strengths", [])
229+
weaknesses = latest_critique.get("weaknesses", [])
230+
suggestions = latest_critique.get("suggestions", [])
231+
else:
232+
strengths = latest_critique.strengths
233+
weaknesses = latest_critique.weaknesses
234+
suggestions = latest_critique.suggestions
235+
220236
feedback = f"""
221-
Strengths: {', '.join(latest_critique.strengths)}
222-
Weaknesses: {', '.join(latest_critique.weaknesses)}
223-
Suggestions: {', '.join(latest_critique.suggestions)}
237+
Strengths: {', '.join(strengths)}
238+
Weaknesses: {', '.join(weaknesses)}
239+
Suggestions: {', '.join(suggestions)}
224240
"""
225241

226242
revise_prompt = ChatPromptTemplate.from_messages(
@@ -264,7 +280,16 @@ def finalize(state: ReflectionState) -> dict[str, Any]:
264280
iteration = state.get("iteration", 1)
265281

266282
# Get final score
267-
final_score = critiques[-1].quality_score if critiques else 0
283+
# Note: After Temporal serialization, Critique objects become dicts
284+
if critiques:
285+
last_critique = critiques[-1]
286+
final_score = (
287+
last_critique.get("quality_score", 0)
288+
if isinstance(last_critique, dict)
289+
else last_critique.quality_score
290+
)
291+
else:
292+
final_score = 0
268293

269294
summary = f"""
270295
Content finalized after {iteration} iteration(s).
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Tests for the agentic_rag LangGraph sample."""
2+
3+
import uuid
4+
5+
from temporalio.client import Client
6+
from temporalio.contrib.langgraph import LangGraphPlugin
7+
from temporalio.worker import Worker
8+
9+
from langgraph_samples.agentic_rag.graph import build_agentic_rag_graph
10+
from langgraph_samples.agentic_rag.workflow import AgenticRAGWorkflow
11+
12+
from .conftest import requires_openai
13+
14+
15+
@requires_openai
16+
async def test_agentic_rag_workflow(client: Client) -> None:
17+
"""Test agentic RAG workflow with a knowledge base query.
18+
19+
This test requires OPENAI_API_KEY to be set.
20+
"""
21+
task_queue = f"agentic-rag-test-{uuid.uuid4()}"
22+
23+
plugin = LangGraphPlugin(graphs={"agentic_rag": build_agentic_rag_graph})
24+
25+
async with Worker(
26+
client,
27+
task_queue=task_queue,
28+
workflows=[AgenticRAGWorkflow],
29+
plugins=[plugin],
30+
):
31+
result = await client.execute_workflow(
32+
AgenticRAGWorkflow.run,
33+
"What are AI agents?",
34+
id=f"agentic-rag-{uuid.uuid4()}",
35+
task_queue=task_queue,
36+
)
37+
38+
# The result should contain messages with retrieved context
39+
assert "messages" in result
40+
assert len(result["messages"]) > 0
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Tests for the approval_workflow LangGraph sample."""
2+
3+
import uuid
4+
5+
from temporalio.client import Client, WorkflowHandle
6+
from temporalio.contrib.langgraph import LangGraphPlugin
7+
from temporalio.worker import Worker
8+
9+
from langgraph_samples.approval_workflow.activities import notify_approver
10+
from langgraph_samples.approval_workflow.graph import build_approval_graph
11+
from langgraph_samples.approval_workflow.workflow import (
12+
ApprovalRequest,
13+
ApprovalWorkflow,
14+
)
15+
16+
17+
async def test_approval_workflow_approved(client: Client) -> None:
    """Approval path: a medium-risk expense request is approved and executed."""
    import asyncio

    queue = f"approval-test-{uuid.uuid4()}"
    lg_plugin = LangGraphPlugin(graphs={"approval_workflow": build_approval_graph})

    async with Worker(
        client,
        task_queue=queue,
        workflows=[ApprovalWorkflow],
        activities=[notify_approver],
        plugins=[lg_plugin],
    ):
        wf_handle: WorkflowHandle[ApprovalWorkflow, dict] = await client.start_workflow(
            ApprovalWorkflow.run,
            ApprovalRequest(request_type="expense", amount=500.0),
            id=f"approval-{uuid.uuid4()}",
            task_queue=queue,
        )

        # Poll until the workflow parks at the human-approval step.
        status = None
        for _ in range(20):
            status = await wf_handle.query(ApprovalWorkflow.get_status)
            if status == "waiting_for_approval":
                break
            await asyncio.sleep(0.1)

        assert status == "waiting_for_approval"

        # The pending approval should reflect the submitted request.
        pending = await wf_handle.query(ApprovalWorkflow.get_pending_approval)
        assert pending is not None
        assert pending["amount"] == 500.0
        assert pending["risk_level"] == "medium"

        # Approve and let the workflow run to completion.
        await wf_handle.signal(
            ApprovalWorkflow.provide_approval,
            {"approved": True, "reason": "Looks good", "approver": "manager"},
        )

        result = await wf_handle.result()

        assert result["approved"] is True
        assert result["executed"] is True
        assert "Successfully processed" in result["result"]
        assert "manager" in result["result"]
68+
69+
70+
async def test_approval_workflow_rejected(client: Client) -> None:
    """Rejection path: a high-risk purchase request is rejected and not executed."""
    import asyncio

    task_queue = f"approval-test-{uuid.uuid4()}"

    plugin = LangGraphPlugin(graphs={"approval_workflow": build_approval_graph})

    async with Worker(
        client,
        task_queue=task_queue,
        workflows=[ApprovalWorkflow],
        activities=[notify_approver],
        plugins=[plugin],
    ):
        handle: WorkflowHandle[ApprovalWorkflow, dict] = await client.start_workflow(
            ApprovalWorkflow.run,
            ApprovalRequest(request_type="purchase", amount=5000.0),
            id=f"approval-{uuid.uuid4()}",
            task_queue=task_queue,
        )

        # Wait for the workflow to reach the approval point.
        status = None
        for _ in range(20):
            status = await handle.query(ApprovalWorkflow.get_status)
            if status == "waiting_for_approval":
                break
            await asyncio.sleep(0.1)

        # Guard against a slow worker: without this, the queries/signals below
        # could race the workflow (the approved-path test already asserts this).
        assert status == "waiting_for_approval"

        # Verify high risk level for large amount.
        pending = await handle.query(ApprovalWorkflow.get_pending_approval)
        assert pending is not None
        assert pending["risk_level"] == "high"

        # Reject the request.
        await handle.signal(
            ApprovalWorkflow.provide_approval,
            {"approved": False, "reason": "Budget exceeded", "approver": "cfo"},
        )

        result = await handle.result()

        assert result["approved"] is False
        assert result["executed"] is False
        assert "rejected" in result["result"]
        assert "cfo" in result["result"]
116+
117+
118+
async def test_approval_workflow_low_risk(client: Client) -> None:
    """Low-risk path: a small supplies request is classified low risk and approved."""
    import asyncio

    task_queue = f"approval-test-{uuid.uuid4()}"

    plugin = LangGraphPlugin(graphs={"approval_workflow": build_approval_graph})

    async with Worker(
        client,
        task_queue=task_queue,
        workflows=[ApprovalWorkflow],
        activities=[notify_approver],
        plugins=[plugin],
    ):
        handle: WorkflowHandle[ApprovalWorkflow, dict] = await client.start_workflow(
            ApprovalWorkflow.run,
            ApprovalRequest(request_type="supplies", amount=25.0),
            id=f"approval-{uuid.uuid4()}",
            task_queue=task_queue,
        )

        # Wait for the workflow to reach the approval point.
        status = None
        for _ in range(20):
            status = await handle.query(ApprovalWorkflow.get_status)
            if status == "waiting_for_approval":
                break
            await asyncio.sleep(0.1)

        # Guard against a slow worker: without this, the queries/signals below
        # could race the workflow (the approved-path test already asserts this).
        assert status == "waiting_for_approval"

        # Verify low risk level.
        pending = await handle.query(ApprovalWorkflow.get_pending_approval)
        assert pending is not None
        assert pending["risk_level"] == "low"

        # Approve.
        await handle.signal(
            ApprovalWorkflow.provide_approval,
            {"approved": True, "reason": "Auto-approved", "approver": "system"},
        )

        result = await handle.result()
        assert result["approved"] is True
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Shared test fixtures for LangGraph samples."""
2+
3+
import os
4+
5+
# Disable LangSmith tracing for tests to avoid rate limit issues
6+
os.environ["LANGCHAIN_TRACING_V2"] = "false"
7+
8+
import pytest
9+
from temporalio.contrib.langgraph._graph_registry import get_global_registry
10+
11+
12+
@pytest.fixture(autouse=True)
13+
def clear_registry() -> None:
14+
"""Clear the global graph registry before each test."""
15+
get_global_registry().clear()
16+
17+
18+
def has_openai_api_key() -> bool:
19+
"""Check if OpenAI API key is available."""
20+
return bool(os.environ.get("OPENAI_API_KEY"))
21+
22+
23+
# Skip marker for tests that require OpenAI API
24+
requires_openai = pytest.mark.skipif(
25+
not has_openai_api_key(),
26+
reason="OPENAI_API_KEY environment variable not set",
27+
)

0 commit comments

Comments
 (0)