from dataclasses import dataclass, field
from pathlib import Path
- from typing import Annotated

import logfire
- from devtools import debug
- from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep
+ from pydantic import BaseModel
+ from pydantic_graph import (
+     BaseNode,
+     End,
+     Graph,
+     GraphRunContext,
+ )
+ from pydantic_graph.persistence.file import FileStatePersistence

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
@@ -41,22 +46,23 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
        )
        ctx.state.ask_agent_messages += result.all_messages()
        ctx.state.question = result.data
-         return Answer()
+         return Answer(result.data)


@dataclass
class Answer(BaseNode[QuestionState]):
-     answer: str | None = None
+     question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
-         assert self.answer is not None
-         return Evaluate(self.answer)
+         answer = input(f'{self.question}: ')
+         return Evaluate(answer)


- @dataclass
- class EvaluationResult:
+ class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
+     """Whether the answer is correct."""
    comment: str
+     """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
@@ -67,101 +73,76 @@ class EvaluationResult:


@dataclass
- class Evaluate(BaseNode[QuestionState]):
+ class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
-     ) -> Congratulate | Reprimand:
+     ) -> End[str] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.all_messages()
        if result.data.correct:
-             return Congratulate(result.data.comment)
+             return End(result.data.comment)
        else:
            return Reprimand(result.data.comment)


- @dataclass
- class Congratulate(BaseNode[QuestionState, None, None]):
-     comment: str
-
-     async def run(
-         self, ctx: GraphRunContext[QuestionState]
-     ) -> Annotated[End, Edge(label='success')]:
-         print(f'Correct answer! {self.comment}')
-         return End(None)
-
-
@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
-         # > Comment: Vichy is no longer the capital of France.
        ctx.state.question = None
        return Ask()


question_graph = Graph(
-     nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState
+     nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)


async def run_as_continuous():
    state = QuestionState()
    node = Ask()
-     history: list[HistoryStep[QuestionState, None]] = []
-     with logfire.span('run questions graph'):
-         while True:
-             node = await question_graph.next(node, history, state=state)
-             if isinstance(node, End):
-                 debug([e.data_snapshot() for e in history])
-                 break
-             elif isinstance(node, Answer):
-                 assert state.question
-                 node.answer = input(f'{state.question} ')
-             # otherwise just continue
+     end = await question_graph.run(node, state=state)
+     print('END:', end.output)


async def run_as_cli(answer: str | None):
-     history_file = Path('question_graph_history.json')
-     history = (
-         question_graph.load_history(history_file.read_bytes())
-         if history_file.exists()
-         else []
-     )
-
-     if history:
-         last = history[-1]
-         assert last.kind == 'node', 'expected last step to be a node'
-         state = last.state
-         assert answer is not None, 'answer is required to continue from history'
-         node = Answer(answer)
+     persistence = FileStatePersistence(Path('question_graph.json'))
+     persistence.set_graph_types(question_graph)
+
+     if snapshot := await persistence.load_next():
+         state = snapshot.state
+         assert answer is not None, (
+             'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
+         )
+         node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()
-     debug(state, node)
+     # debug(state, node)

-     with logfire.span('run questions graph'):
+     async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
-             node = await question_graph.next(node, history, state=state)
+             node = await run.next()
            if isinstance(node, End):
-                 debug([e.data_snapshot() for e in history])
+                 print('END:', node.data)
+                 history = await persistence.load_all()
+                 print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
                print('Finished!')
                break
            elif isinstance(node, Answer):
-                 print(state.question)
+                 print(node.question)
                break
            # otherwise just continue

-     history_file.write_bytes(question_graph.dump_history(history, indent=2))
-

if __name__ == '__main__':
    import asyncio
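
A minimal sketch of how the two entry points above might be dispatched. The hunk ends just inside the if __name__ == '__main__': block, so the subcommand names below are assumptions; only the "cli <answer>" form is confirmed by the assert message in run_as_cli. Because the run is checkpointed to question_graph.json, repeated cli invocations resume the same graph run: the first prints the question from the Answer node, the next feeds the answer into Evaluate.

import asyncio
import sys


def main() -> None:
    # Hypothetical dispatcher, for illustration only; subcommand names are
    # inferred from the coroutine names above, not taken from this commit.
    sub_command = sys.argv[1] if len(sys.argv) > 1 else 'continuous'
    if sub_command == 'cli':
        # e.g. uv run -m pydantic_ai_examples.question_graph cli <answer>
        answer = sys.argv[2] if len(sys.argv) > 2 else None
        asyncio.run(run_as_cli(answer))
    else:
        # single process: question_graph.run() loops Ask -> Answer -> Evaluate
        # (-> Reprimand -> Ask) until Evaluate returns End
        asyncio.run(run_as_continuous())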