docs/integrations/langgraph-integration.mdx (36 additions, 30 deletions)
@@ -137,7 +137,7 @@ if final_answer:
    correctness_judge_response = await judge_correctness(
        scenario, traj.final_answer.answer
    )
-    traj.metrics["correct"] = float(correctness_judge_response.accept)
+    traj.metrics["correct"] = correctness_judge_response.accept
```

### Training Loop with LangGraph Integration
@@ -179,7 +179,7 @@ for batch in training_iterator:
    # Apply RULER scoring
    judged_groups = []
    for group in finished_groups:
-        judged_group = await ruler_score_group(group, "openai/o4-mini", debug=True)
+        judged_group = await ruler_score_group(group, "openai/o4-mini")
        judged_groups.append(judged_group)

    # Train the model
@@ -326,7 +326,7 @@ async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudge
        You are given a question, the reference answer, and an answer generated by an AI assistant.
        Your task is to decide whether the AI answer is correct and should be accepted.
    """)

    messages = [
        {"role": "system", "content": system_prompt},
        {
@@ -338,13 +338,13 @@ async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudge
            ),
        },
    ]

    response = await acompletion(
        model="openai/gpt-4o-mini",
        messages=messages,
        response_format=CorrectnessJudgeResponse,
    )

    return CorrectnessJudgeResponse.model_validate_json(
        response.choices[0].message.content or "{}"
    )
@@ -354,7 +354,7 @@ async def judge_correctness(scenario: Scenario, answer: str) -> CorrectnessJudge
async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTrajectory:
    scenario = email_scenario.scenario
    MAX_TURNS = 10

    traj = ProjectTrajectory(
        reward=0.0,
        messages_and_choices=[],
@@ -363,62 +363,62 @@ async def rollout(model: art.Model, email_scenario: EmailScenario) -> ProjectTra
"step": email_scenario.step,
},
)

system_prompt = dedent(f"""
You are an email search agent. Use the tools to search emails and find answers.
User's email address: {scenario.inbox_address}
Today's date: {scenario.query_date}

When you find the answer, use return_final_answer_tool with the answer and source message IDs.
""")

final_answer = None

@tool
def search_inbox_tool(keywords: List[str]) -> List[dict]:
"""Search inbox for emails matching keywords"""
results = search_emails(scenario.inbox_address, keywords, scenario.query_date)
return [asdict(result) for result in results]

@tool
def read_email_tool(message_id: str) -> dict | None:
"""Read a specific email by message ID"""
email = read_email(message_id)
return email.model_dump() if email else None

@tool
def return_final_answer_tool(answer: str, reference_message_ids: List[str]) -> dict:
"""Return final answer with source message IDs"""
nonlocal final_answer
final_answer = FinalAnswer(answer=answer, source_ids=reference_message_ids)
return final_answer.model_dump()

tools = [search_inbox_tool, read_email_tool, return_final_answer_tool]
chat_model = init_chat_model(model.name, temperature=1.0)
react_agent = create_react_agent(chat_model, tools)

try:
config = {
"configurable": {"thread_id": str(uuid.uuid4())},
"recursion_limit": MAX_TURNS,
}

await react_agent.ainvoke({
"messages": [
SystemMessage(content=system_prompt),
HumanMessage(content=scenario.question),
]
}, config=config)

if final_answer:
traj.final_answer = final_answer
correctness_judge_response = await judge_correctness(scenario, final_answer.answer)
traj.metrics["correct"] = float(correctness_judge_response.accept)

except Exception as e:
print(f"Error running agent: {e}")
traj.messages_and_choices.append({"role": "assistant", "content": f"Error: {str(e)}"})

return traj

# Main training function
@@ -433,17 +433,17 @@ async def main():
            query_date="2024-01-20"
        ),
        Scenario(
            id="2",
            question="Look for urgent project updates",
            answer="Project deadline moved to next month",
            inbox_address="user@company.com",
            query_date="2024-01-20"
        ),
    ]

    # Register model with backend
    await model.register(backend)

    # Training configuration
    training_config = {
        "groups_per_step": 2,
@@ -452,19 +452,19 @@ async def main():
"learning_rate": 1e-5,
"max_steps": 5,
}

# Training iterator
training_iterator = iterate_dataset(
training_scenarios,
groups_per_step=training_config["groups_per_step"],
num_epochs=training_config["num_epochs"],
initial_step=await model.get_step(),
)

# Training loop
for batch in training_iterator:
print(f"Training step {batch.step}, epoch {batch.epoch}")

# Create trajectory groups
groups = []
for scenario in batch.items:
@@ -476,22 +476,28 @@ async def main():
                for _ in range(training_config["rollouts_per_group"])
            ])
        )

        # Gather trajectories
        finished_groups = await art.gather_trajectory_groups(
            groups,
            pbar_desc="gather",
            max_exceptions=training_config["rollouts_per_group"] * len(batch.items),
        )

+        # Apply RULER scoring
+        judged_groups = []
+        for group in finished_groups:
+            judged_group = await ruler_score_group(group, "openai/o4-mini")
+            judged_groups.append(judged_group)

        # Train model
        await model.train(
-            finished_groups,
+            judged_groups,
            config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
        )

        print(f"Completed training step {batch.step}")

        if batch.step >= training_config["max_steps"]:
            break

@@ -503,7 +509,7 @@ This complete example shows how to:

1. **Set up the environment** with model, backend, and data structures
2. **Define custom tools** for email search and retrieval
3. **Create a LangGraph ReAct agent** with proper configuration
4. **Implement trajectory tracking** with custom reward scoring
5. **Run the full training loop** with proper error handling
6. **Use wrap_rollout** to automatically capture agent interactions
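
The last item references `wrap_rollout`, whose call sites fall inside the collapsed regions of this diff. As a minimal sketch (an assumption based on ART's `art.langgraph` helpers, not the hidden diff content), the group-creation step would wrap each rollout so the agent's LangGraph interactions are recorded onto the trajectory automatically:

```python
# Minimal sketch, assuming art.langgraph exposes wrap_rollout(model, fn)
# returning a coroutine function that captures LangGraph messages onto
# the resulting trajectory; verify against your installed ART version.
from art.langgraph import wrap_rollout

groups = []
for scenario in batch.items:
    groups.append(
        art.TrajectoryGroup([
            # Each wrapped call yields one trajectory for the group.
            wrap_rollout(model, rollout)(
                model, EmailScenario(step=batch.step, scenario=scenario)
            )
            for _ in range(training_config["rollouts_per_group"])
        ])
    )
```

With the rollouts wrapped this way, `art.gather_trajectory_groups` can await the groups exactly as in the training loop above.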