@@ -98,21 +98,116 @@ def _set_conditional_node_edges(self):
98
98
except :
99
99
node .false_node_name = None
100
100
101
- def _execute_standard (self , initial_state : dict ) -> Tuple [dict , list ]:
102
- """
103
- Executes the graph by traversing nodes starting from the
104
- entry point using the standard method.
101
+ def _get_node_by_name (self , node_name : str ):
102
+ """Returns a node instance by its name."""
103
+ return next (node for node in self .nodes if node .node_name == node_name )
105
104
106
- Args:
107
- initial_state (dict): The initial state to pass to the entry point node.
105
+ def _update_source_info (self , current_node , state ):
106
+ """Updates source type and source information from FetchNode."""
107
+ source_type = None
108
+ source = []
109
+ prompt = None
110
+
111
+ if current_node .__class__ .__name__ == "FetchNode" :
112
+ source_type = list (state .keys ())[1 ]
113
+ if state .get ("user_prompt" , None ):
114
+ prompt = state ["user_prompt" ] if isinstance (state ["user_prompt" ], str ) else None
115
+
116
+ if source_type == "local_dir" :
117
+ source_type = "html_dir"
118
+ elif source_type == "url" :
119
+ if isinstance (state [source_type ], list ):
120
+ source .extend (url for url in state [source_type ] if isinstance (url , str ))
121
+ elif isinstance (state [source_type ], str ):
122
+ source .append (state [source_type ])
123
+
124
+ return source_type , source , prompt
125
+
126
+ def _get_model_info (self , current_node ):
127
+ """Extracts LLM and embedder model information from the node."""
128
+ llm_model = None
129
+ llm_model_name = None
130
+ embedder_model = None
108
131
109
- Returns:
110
- Tuple[dict, list]: A tuple containing the final state and a list of execution info.
132
+ if hasattr (current_node , "llm_model" ):
133
+ llm_model = current_node .llm_model
134
+ if hasattr (llm_model , "model_name" ):
135
+ llm_model_name = llm_model .model_name
136
+ elif hasattr (llm_model , "model" ):
137
+ llm_model_name = llm_model .model
138
+ elif hasattr (llm_model , "model_id" ):
139
+ llm_model_name = llm_model .model_id
140
+
141
+ if hasattr (current_node , "embedder_model" ):
142
+ embedder_model = current_node .embedder_model
143
+ if hasattr (embedder_model , "model_name" ):
144
+ embedder_model = embedder_model .model_name
145
+ elif hasattr (embedder_model , "model" ):
146
+ embedder_model = embedder_model .model
147
+
148
+ return llm_model , llm_model_name , embedder_model
149
+
150
+ def _get_schema (self , current_node ):
151
+ """Extracts schema information from the node configuration."""
152
+ if not hasattr (current_node , "node_config" ):
153
+ return None
154
+
155
+ if not isinstance (current_node .node_config , dict ):
156
+ return None
157
+
158
+ schema_config = current_node .node_config .get ("schema" )
159
+ if not schema_config or isinstance (schema_config , dict ):
160
+ return None
161
+
162
+ try :
163
+ return schema_config .schema ()
164
+ except Exception :
165
+ return None
166
+
167
+ def _execute_node (self , current_node , state , llm_model , llm_model_name ):
168
+ """Executes a single node and returns execution information."""
169
+ curr_time = time .time ()
170
+
171
+ with self .callback_manager .exclusive_get_callback (llm_model , llm_model_name ) as cb :
172
+ result = current_node .execute (state )
173
+ node_exec_time = time .time () - curr_time
174
+
175
+ cb_data = None
176
+ if cb is not None :
177
+ cb_data = {
178
+ "node_name" : current_node .node_name ,
179
+ "total_tokens" : cb .total_tokens ,
180
+ "prompt_tokens" : cb .prompt_tokens ,
181
+ "completion_tokens" : cb .completion_tokens ,
182
+ "successful_requests" : cb .successful_requests ,
183
+ "total_cost_USD" : cb .total_cost ,
184
+ "exec_time" : node_exec_time ,
185
+ }
186
+
187
+ return result , node_exec_time , cb_data
188
+
189
+ def _get_next_node (self , current_node , result ):
190
+ """Determines the next node to execute based on current node type and result."""
191
+ if current_node .node_type == "conditional_node" :
192
+ node_names = {node .node_name for node in self .nodes }
193
+ if result in node_names :
194
+ return result
195
+ elif result is None :
196
+ return None
197
+ raise ValueError (
198
+ f"Conditional Node returned a node name '{ result } ' that does not exist in the graph"
199
+ )
200
+
201
+ return self .edges .get (current_node .node_name )
202
+
203
+ def _execute_standard (self , initial_state : dict ) -> Tuple [dict , list ]:
204
+ """
205
+ Executes the graph by traversing nodes starting from the entry point using the standard method.
111
206
"""
112
207
current_node_name = self .entry_point
113
208
state = initial_state
114
-
115
- # variables for tracking execution info
209
+
210
+ # Tracking variables
116
211
total_exec_time = 0.0
117
212
exec_info = []
118
213
cb_total = {
@@ -134,104 +229,51 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
134
229
schema = None
135
230
136
231
while current_node_name :
137
- curr_time = time .time ()
138
- current_node = next (node for node in self .nodes if node .node_name == current_node_name )
139
-
140
- if current_node .__class__ .__name__ == "FetchNode" :
141
- source_type = list (state .keys ())[1 ]
142
- if state .get ("user_prompt" , None ):
143
- prompt = state ["user_prompt" ] if isinstance (state ["user_prompt" ], str ) else None
144
-
145
- if source_type == "local_dir" :
146
- source_type = "html_dir"
147
- elif source_type == "url" :
148
- if isinstance (state [source_type ], list ):
149
- for url in state [source_type ]:
150
- if isinstance (url , str ):
151
- source .append (url )
152
- elif isinstance (state [source_type ], str ):
153
- source .append (state [source_type ])
154
-
155
- if hasattr (current_node , "llm_model" ) and llm_model is None :
156
- llm_model = current_node .llm_model
157
- if hasattr (llm_model , "model_name" ):
158
- llm_model_name = llm_model .model_name
159
- elif hasattr (llm_model , "model" ):
160
- llm_model_name = llm_model .model
161
- elif hasattr (llm_model , "model_id" ):
162
- llm_model_name = llm_model .model_id
163
-
164
- if hasattr (current_node , "embedder_model" ) and embedder_model is None :
165
- embedder_model = current_node .embedder_model
166
- if hasattr (embedder_model , "model_name" ):
167
- embedder_model = embedder_model .model_name
168
- elif hasattr (embedder_model , "model" ):
169
- embedder_model = embedder_model .model
170
-
171
- if hasattr (current_node , "node_config" ):
172
- if isinstance (current_node .node_config ,dict ):
173
- if current_node .node_config .get ("schema" , None ) and schema is None :
174
- if not isinstance (current_node .node_config ["schema" ], dict ):
175
- try :
176
- schema = current_node .node_config ["schema" ].schema ()
177
- except Exception as e :
178
- schema = None
179
-
180
- with self .callback_manager .exclusive_get_callback (llm_model , llm_model_name ) as cb :
181
- try :
182
- result = current_node .execute (state )
183
- except Exception as e :
184
- error_node = current_node .node_name
185
- graph_execution_time = time .time () - start_time
186
- log_graph_execution (
187
- graph_name = self .graph_name ,
188
- source = source ,
189
- prompt = prompt ,
190
- schema = schema ,
191
- llm_model = llm_model_name ,
192
- embedder_model = embedder_model ,
193
- source_type = source_type ,
194
- execution_time = graph_execution_time ,
195
- error_node = error_node ,
196
- exception = str (e )
197
- )
198
- raise e
199
- node_exec_time = time .time () - curr_time
232
+ current_node = self ._get_node_by_name (current_node_name )
233
+
234
+ # Update source information if needed
235
+ if source_type is None :
236
+ source_type , source , prompt = self ._update_source_info (current_node , state )
237
+
238
+ # Get model information if needed
239
+ if llm_model is None :
240
+ llm_model , llm_model_name , embedder_model = self ._get_model_info (current_node )
241
+
242
+ # Get schema if needed
243
+ if schema is None :
244
+ schema = self ._get_schema (current_node )
245
+
246
+ try :
247
+ result , node_exec_time , cb_data = self ._execute_node (
248
+ current_node , state , llm_model , llm_model_name
249
+ )
200
250
total_exec_time += node_exec_time
201
251
202
- if cb is not None :
203
- cb_data = {
204
- "node_name" : current_node .node_name ,
205
- "total_tokens" : cb .total_tokens ,
206
- "prompt_tokens" : cb .prompt_tokens ,
207
- "completion_tokens" : cb .completion_tokens ,
208
- "successful_requests" : cb .successful_requests ,
209
- "total_cost_USD" : cb .total_cost ,
210
- "exec_time" : node_exec_time ,
211
- }
212
-
252
+ if cb_data :
213
253
exec_info .append (cb_data )
214
-
215
- cb_total ["total_tokens" ] += cb_data ["total_tokens" ]
216
- cb_total ["prompt_tokens" ] += cb_data ["prompt_tokens" ]
217
- cb_total ["completion_tokens" ] += cb_data ["completion_tokens" ]
218
- cb_total ["successful_requests" ] += cb_data ["successful_requests" ]
219
- cb_total ["total_cost_USD" ] += cb_data ["total_cost_USD" ]
220
-
221
- if current_node .node_type == "conditional_node" :
222
- node_names = {node .node_name for node in self .nodes }
223
- if result in node_names :
224
- current_node_name = result
225
- elif result is None :
226
- current_node_name = None
227
- else :
228
- raise ValueError (f"Conditional Node returned a node name '{ result } ' that does not exist in the graph" )
229
-
230
- elif current_node_name in self .edges :
231
- current_node_name = self .edges [current_node_name ]
232
- else :
233
- current_node_name = None
234
-
254
+ for key in cb_total :
255
+ cb_total [key ] += cb_data [key ]
256
+
257
+ current_node_name = self ._get_next_node (current_node , result )
258
+
259
+ except Exception as e :
260
+ error_node = current_node .node_name
261
+ graph_execution_time = time .time () - start_time
262
+ log_graph_execution (
263
+ graph_name = self .graph_name ,
264
+ source = source ,
265
+ prompt = prompt ,
266
+ schema = schema ,
267
+ llm_model = llm_model_name ,
268
+ embedder_model = embedder_model ,
269
+ source_type = source_type ,
270
+ execution_time = graph_execution_time ,
271
+ error_node = error_node ,
272
+ exception = str (e )
273
+ )
274
+ raise e
275
+
276
+ # Add total results to execution info
235
277
exec_info .append ({
236
278
"node_name" : "TOTAL RESULT" ,
237
279
"total_tokens" : cb_total ["total_tokens" ],
@@ -242,6 +284,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
242
284
"exec_time" : total_exec_time ,
243
285
})
244
286
287
+ # Log final execution results
245
288
graph_execution_time = time .time () - start_time
246
289
response = state .get ("answer" , None ) if source_type == "url" else None
247
290
content = state .get ("parsed_doc" , None ) if response is not None else None
@@ -300,3 +343,4 @@ def append_node(self, node):
300
343
self .raw_edges .append ((last_node , node ))
301
344
self .nodes .append (node )
302
345
self .edges = self ._create_edges ({e for e in self .raw_edges })
346
+
0 commit comments