Skip to content

Commit 8e5d219

Browse files
JoshuaC215peterkeppertpeterkeppert-anasoft
authored
Add a BG Task Agent to show CustomData usage (atlanhq#81)
* Implement bg-task-agent * fix bg tasks display logic (atlanhq#78) * fix bg tasks display logic * Bug fixes and cleanup * support parallel task runs * Clean up logic and docs for the TaskData handling --------- Co-authored-by: ANASOFT\keppert <peter.keppert@anasoft.com> Co-authored-by: Joshua Carroll <carroll.joshk@gmail.com> --------- Co-authored-by: peterkeppert <30529458+peterkeppert@users.noreply.github.com> Co-authored-by: ANASOFT\keppert <peter.keppert@anasoft.com>
1 parent acebd43 commit 8e5d219

File tree

5 files changed

+204
-0
lines changed

5 files changed

+204
-0
lines changed

src/agents/agents.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from langgraph.graph.state import CompiledStateGraph
22

3+
from agents.bg_task_agent.bg_task_agent import bg_task_agent
34
from agents.chatbot import chatbot
45
from agents.research_assistant import research_assistant
56

@@ -9,4 +10,5 @@
910
agents: dict[str, CompiledStateGraph] = {
1011
"chatbot": chatbot,
1112
"research-assistant": research_assistant,
13+
"bg-task-agent": bg_task_agent,
1214
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import asyncio
2+
3+
from langchain_core.language_models.chat_models import BaseChatModel
4+
from langchain_core.messages import AIMessage
5+
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
6+
from langgraph.checkpoint.memory import MemorySaver
7+
from langgraph.graph import END, MessagesState, StateGraph
8+
9+
from agents.bg_task_agent.task import Task
10+
from agents.models import models
11+
12+
13+
class AgentState(MessagesState, total=False):
14+
"""`total=False` is PEP589 specs.
15+
16+
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
17+
"""
18+
19+
20+
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
21+
preprocessor = RunnableLambda(
22+
lambda state: state["messages"],
23+
name="StateModifier",
24+
)
25+
return preprocessor | model
26+
27+
28+
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
29+
m = models[config["configurable"].get("model", "gpt-4o-mini")]
30+
model_runnable = wrap_model(m)
31+
response = await model_runnable.ainvoke(state, config)
32+
33+
# We return a list, because this will get added to the existing list
34+
return {"messages": [response]}
35+
36+
37+
async def bg_task(state: AgentState, config: RunnableConfig) -> AgentState:
38+
task1 = Task("Simple task 1...")
39+
task2 = Task("Simple task 2...")
40+
41+
await task1.start(config=config)
42+
await asyncio.sleep(2)
43+
await task2.start(config=config)
44+
await asyncio.sleep(2)
45+
await task1.write_data(config=config, data={"status": "Still running..."})
46+
await asyncio.sleep(2)
47+
await task2.finish(result="error", config=config, data={"output": 42})
48+
await asyncio.sleep(2)
49+
await task1.finish(result="success", config=config, data={"output": 42})
50+
return {"messages": []}
51+
52+
53+
# Define the graph
54+
agent = StateGraph(AgentState)
55+
agent.add_node("model", acall_model)
56+
agent.add_node("bg_task", bg_task)
57+
agent.set_entry_point("bg_task")
58+
59+
agent.add_edge("bg_task", "model")
60+
agent.add_edge("model", END)
61+
62+
bg_task_agent = agent.compile(
63+
checkpointer=MemorySaver(),
64+
)

src/agents/bg_task_agent/task.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Literal
2+
from uuid import uuid4
3+
4+
from langchain_core.messages import BaseMessage
5+
from langchain_core.runnables import RunnableConfig
6+
7+
from agents.utils import CustomData
8+
from schema.task_data import TaskData
9+
10+
11+
class Task:
12+
def __init__(self, task_name: str) -> None:
13+
self.name = task_name
14+
self.id = str(uuid4())
15+
self.state: Literal["new", "running", "complete"] = "new"
16+
self.result: Literal["success", "error"] | None = None
17+
18+
async def _generate_and_dispatch_message(self, config: RunnableConfig, data: dict):
19+
task_data = TaskData(name=self.name, run_id=self.id, state=self.state, data=data)
20+
if self.result:
21+
task_data.result = self.result
22+
task_custom_data = CustomData(
23+
type=self.name,
24+
data=task_data.model_dump(),
25+
)
26+
await task_custom_data.adispatch()
27+
return task_custom_data.to_langchain()
28+
29+
async def start(self, config: RunnableConfig, data: dict = {}) -> BaseMessage:
30+
self.state = "new"
31+
task_message = await self._generate_and_dispatch_message(config, data)
32+
return task_message
33+
34+
async def write_data(self, config: RunnableConfig, data: dict) -> BaseMessage:
35+
if self.state == "complete":
36+
raise ValueError("Only incomplete tasks can output data.")
37+
self.state = "running"
38+
task_message = await self._generate_and_dispatch_message(config, data)
39+
return task_message
40+
41+
async def finish(
42+
self, result: Literal["success", "error"], config: RunnableConfig, data: dict = {}
43+
) -> BaseMessage:
44+
self.state = "complete"
45+
self.result = result
46+
task_message = await self._generate_and_dispatch_message(config, data)
47+
return task_message

src/schema/task_data.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Any, Literal
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class TaskData(BaseModel):
7+
name: str | None = Field(
8+
description="Name of the task.", default=None, examples=["Check input safety"]
9+
)
10+
run_id: str = Field(
11+
description="ID of the task run to pair state updates to.",
12+
default="",
13+
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
14+
)
15+
state: Literal["new", "running", "complete"] | None = Field(
16+
description="Current state of given task instance.",
17+
default=None,
18+
examples=["running"],
19+
)
20+
result: Literal["success", "error"] | None = Field(
21+
description="Result of given task instance.",
22+
default=None,
23+
examples=["running"],
24+
)
25+
data: dict[str, Any] = Field(
26+
description="Additional data generated by the task.",
27+
default={},
28+
)
29+
30+
def completed(self) -> bool:
31+
return self.state == "complete"
32+
33+
def completed_with_error(self) -> bool:
34+
return self.state == "complete" and self.result == "error"

src/streamlit_app.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from client import AgentClient
99
from schema import ChatHistory, ChatMessage
10+
from schema.task_data import TaskData
1011

1112
# A Streamlit app for interacting with the langgraph agent via a simple chat interface.
1213
# The app has three main functions which are all run async:
@@ -82,6 +83,7 @@ async def main() -> None:
8283
options=[
8384
"research-assistant",
8485
"chatbot",
86+
"bg-task-agent",
8587
],
8688
)
8789
use_streaming = st.toggle("Stream results", value=True)
@@ -266,6 +268,61 @@ async def draw_messages(
266268
status.write(tool_result.content)
267269
status.update(state="complete")
268270

271+
case "custom":
272+
# This is an implementation of the TaskData example for CustomData.
273+
# An agent can write a CustomData object to the message stream, and
274+
# it's passed to the client for rendering. To see this in practice,
275+
# run the app with the `bg-task-agent` agent.
276+
277+
# This is provided as an example, you may want to write your own
278+
# CustomData types and handlers. This section will be skipped for
279+
# any other agents that don't send CustomData.
280+
task_data = TaskData.model_validate(msg.custom_data)
281+
282+
# If we're rendering new messages, store the message in session state
283+
if is_new:
284+
st.session_state.messages.append(msg)
285+
286+
# If the last message type was not Task, create a new chat message
287+
# and container for task messages
288+
if last_message_type != "task":
289+
last_message_type = "task"
290+
st.session_state.last_message = st.chat_message(
291+
name="task", avatar=":material/manufacturing:"
292+
)
293+
with st.session_state.last_message:
294+
status = st.status("")
295+
current_task_data: dict[str, TaskData] = {}
296+
297+
status_str = f"Task **{task_data.name}** "
298+
match task_data.state:
299+
case "new":
300+
status_str += "has :blue[started]. Input:"
301+
case "running":
302+
status_str += "wrote:"
303+
case "complete":
304+
if task_data.result == "success":
305+
status_str += ":green[completed successfully]. Output:"
306+
else:
307+
status_str += ":red[ended with error]. Output:"
308+
status.write(status_str)
309+
status.write(task_data.data)
310+
status.write("---")
311+
if task_data.run_id not in current_task_data:
312+
# Status label always shows the last newly started task
313+
status.update(label=f"""Task: {task_data.name}""")
314+
current_task_data[task_data.run_id] = task_data
315+
# Status is "running" until all tasks have completed
316+
if not any(entry.completed() for entry in current_task_data.values()):
317+
state = "running"
318+
# Status is "error" if any task has errored
319+
elif any(entry.completed_with_error() for entry in current_task_data.values()):
320+
state = "error"
321+
# Status is "complete" if all tasks have completed successfully
322+
else:
323+
state = "complete"
324+
status.update(state=state)
325+
269326
# In case of an unexpected message type, log an error and stop
270327
case _:
271328
st.error(f"Unexpected ChatMessage type: {msg.type}")

0 commit comments

Comments
 (0)