Skip to content

Commit 12a6c18

Browse files
committed
feat: refactor base_graph
1 parent 3b2cadc commit 12a6c18

File tree

1 file changed

+149
-105
lines changed

1 file changed

+149
-105
lines changed

scrapegraphai/graphs/base_graph.py

Lines changed: 149 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -98,21 +98,116 @@ def _set_conditional_node_edges(self):
9898
except:
9999
node.false_node_name = None
100100

101-
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
102-
"""
103-
Executes the graph by traversing nodes starting from the
104-
entry point using the standard method.
101+
def _get_node_by_name(self, node_name: str):
102+
"""Returns a node instance by its name."""
103+
return next(node for node in self.nodes if node.node_name == node_name)
105104

106-
Args:
107-
initial_state (dict): The initial state to pass to the entry point node.
105+
def _update_source_info(self, current_node, state):
106+
"""Updates source type and source information from FetchNode."""
107+
source_type = None
108+
source = []
109+
prompt = None
110+
111+
if current_node.__class__.__name__ == "FetchNode":
112+
source_type = list(state.keys())[1]
113+
if state.get("user_prompt", None):
114+
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
115+
116+
if source_type == "local_dir":
117+
source_type = "html_dir"
118+
elif source_type == "url":
119+
if isinstance(state[source_type], list):
120+
source.extend(url for url in state[source_type] if isinstance(url, str))
121+
elif isinstance(state[source_type], str):
122+
source.append(state[source_type])
123+
124+
return source_type, source, prompt
125+
126+
def _get_model_info(self, current_node):
127+
"""Extracts LLM and embedder model information from the node."""
128+
llm_model = None
129+
llm_model_name = None
130+
embedder_model = None
108131

109-
Returns:
110-
Tuple[dict, list]: A tuple containing the final state and a list of execution info.
132+
if hasattr(current_node, "llm_model"):
133+
llm_model = current_node.llm_model
134+
if hasattr(llm_model, "model_name"):
135+
llm_model_name = llm_model.model_name
136+
elif hasattr(llm_model, "model"):
137+
llm_model_name = llm_model.model
138+
elif hasattr(llm_model, "model_id"):
139+
llm_model_name = llm_model.model_id
140+
141+
if hasattr(current_node, "embedder_model"):
142+
embedder_model = current_node.embedder_model
143+
if hasattr(embedder_model, "model_name"):
144+
embedder_model = embedder_model.model_name
145+
elif hasattr(embedder_model, "model"):
146+
embedder_model = embedder_model.model
147+
148+
return llm_model, llm_model_name, embedder_model
149+
150+
def _get_schema(self, current_node):
151+
"""Extracts schema information from the node configuration."""
152+
if not hasattr(current_node, "node_config"):
153+
return None
154+
155+
if not isinstance(current_node.node_config, dict):
156+
return None
157+
158+
schema_config = current_node.node_config.get("schema")
159+
if not schema_config or isinstance(schema_config, dict):
160+
return None
161+
162+
try:
163+
return schema_config.schema()
164+
except Exception:
165+
return None
166+
167+
def _execute_node(self, current_node, state, llm_model, llm_model_name):
168+
"""Executes a single node and returns execution information."""
169+
curr_time = time.time()
170+
171+
with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
172+
result = current_node.execute(state)
173+
node_exec_time = time.time() - curr_time
174+
175+
cb_data = None
176+
if cb is not None:
177+
cb_data = {
178+
"node_name": current_node.node_name,
179+
"total_tokens": cb.total_tokens,
180+
"prompt_tokens": cb.prompt_tokens,
181+
"completion_tokens": cb.completion_tokens,
182+
"successful_requests": cb.successful_requests,
183+
"total_cost_USD": cb.total_cost,
184+
"exec_time": node_exec_time,
185+
}
186+
187+
return result, node_exec_time, cb_data
188+
189+
def _get_next_node(self, current_node, result):
190+
"""Determines the next node to execute based on current node type and result."""
191+
if current_node.node_type == "conditional_node":
192+
node_names = {node.node_name for node in self.nodes}
193+
if result in node_names:
194+
return result
195+
elif result is None:
196+
return None
197+
raise ValueError(
198+
f"Conditional Node returned a node name '{result}' that does not exist in the graph"
199+
)
200+
201+
return self.edges.get(current_node.node_name)
202+
203+
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
204+
"""
205+
Executes the graph by traversing nodes starting from the entry point using the standard method.
111206
"""
112207
current_node_name = self.entry_point
113208
state = initial_state
114-
115-
# variables for tracking execution info
209+
210+
# Tracking variables
116211
total_exec_time = 0.0
117212
exec_info = []
118213
cb_total = {
@@ -134,104 +229,51 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
134229
schema = None
135230

136231
while current_node_name:
137-
curr_time = time.time()
138-
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
139-
140-
if current_node.__class__.__name__ == "FetchNode":
141-
source_type = list(state.keys())[1]
142-
if state.get("user_prompt", None):
143-
prompt = state["user_prompt"] if isinstance(state["user_prompt"], str) else None
144-
145-
if source_type == "local_dir":
146-
source_type = "html_dir"
147-
elif source_type == "url":
148-
if isinstance(state[source_type], list):
149-
for url in state[source_type]:
150-
if isinstance(url, str):
151-
source.append(url)
152-
elif isinstance(state[source_type], str):
153-
source.append(state[source_type])
154-
155-
if hasattr(current_node, "llm_model") and llm_model is None:
156-
llm_model = current_node.llm_model
157-
if hasattr(llm_model, "model_name"):
158-
llm_model_name = llm_model.model_name
159-
elif hasattr(llm_model, "model"):
160-
llm_model_name = llm_model.model
161-
elif hasattr(llm_model, "model_id"):
162-
llm_model_name = llm_model.model_id
163-
164-
if hasattr(current_node, "embedder_model") and embedder_model is None:
165-
embedder_model = current_node.embedder_model
166-
if hasattr(embedder_model, "model_name"):
167-
embedder_model = embedder_model.model_name
168-
elif hasattr(embedder_model, "model"):
169-
embedder_model = embedder_model.model
170-
171-
if hasattr(current_node, "node_config"):
172-
if isinstance(current_node.node_config,dict):
173-
if current_node.node_config.get("schema", None) and schema is None:
174-
if not isinstance(current_node.node_config["schema"], dict):
175-
try:
176-
schema = current_node.node_config["schema"].schema()
177-
except Exception as e:
178-
schema = None
179-
180-
with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
181-
try:
182-
result = current_node.execute(state)
183-
except Exception as e:
184-
error_node = current_node.node_name
185-
graph_execution_time = time.time() - start_time
186-
log_graph_execution(
187-
graph_name=self.graph_name,
188-
source=source,
189-
prompt=prompt,
190-
schema=schema,
191-
llm_model=llm_model_name,
192-
embedder_model=embedder_model,
193-
source_type=source_type,
194-
execution_time=graph_execution_time,
195-
error_node=error_node,
196-
exception=str(e)
197-
)
198-
raise e
199-
node_exec_time = time.time() - curr_time
232+
current_node = self._get_node_by_name(current_node_name)
233+
234+
# Update source information if needed
235+
if source_type is None:
236+
source_type, source, prompt = self._update_source_info(current_node, state)
237+
238+
# Get model information if needed
239+
if llm_model is None:
240+
llm_model, llm_model_name, embedder_model = self._get_model_info(current_node)
241+
242+
# Get schema if needed
243+
if schema is None:
244+
schema = self._get_schema(current_node)
245+
246+
try:
247+
result, node_exec_time, cb_data = self._execute_node(
248+
current_node, state, llm_model, llm_model_name
249+
)
200250
total_exec_time += node_exec_time
201251

202-
if cb is not None:
203-
cb_data = {
204-
"node_name": current_node.node_name,
205-
"total_tokens": cb.total_tokens,
206-
"prompt_tokens": cb.prompt_tokens,
207-
"completion_tokens": cb.completion_tokens,
208-
"successful_requests": cb.successful_requests,
209-
"total_cost_USD": cb.total_cost,
210-
"exec_time": node_exec_time,
211-
}
212-
252+
if cb_data:
213253
exec_info.append(cb_data)
214-
215-
cb_total["total_tokens"] += cb_data["total_tokens"]
216-
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
217-
cb_total["completion_tokens"] += cb_data["completion_tokens"]
218-
cb_total["successful_requests"] += cb_data["successful_requests"]
219-
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
220-
221-
if current_node.node_type == "conditional_node":
222-
node_names = {node.node_name for node in self.nodes}
223-
if result in node_names:
224-
current_node_name = result
225-
elif result is None:
226-
current_node_name = None
227-
else:
228-
raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")
229-
230-
elif current_node_name in self.edges:
231-
current_node_name = self.edges[current_node_name]
232-
else:
233-
current_node_name = None
234-
254+
for key in cb_total:
255+
cb_total[key] += cb_data[key]
256+
257+
current_node_name = self._get_next_node(current_node, result)
258+
259+
except Exception as e:
260+
error_node = current_node.node_name
261+
graph_execution_time = time.time() - start_time
262+
log_graph_execution(
263+
graph_name=self.graph_name,
264+
source=source,
265+
prompt=prompt,
266+
schema=schema,
267+
llm_model=llm_model_name,
268+
embedder_model=embedder_model,
269+
source_type=source_type,
270+
execution_time=graph_execution_time,
271+
error_node=error_node,
272+
exception=str(e)
273+
)
274+
raise e
275+
276+
# Add total results to execution info
235277
exec_info.append({
236278
"node_name": "TOTAL RESULT",
237279
"total_tokens": cb_total["total_tokens"],
@@ -242,6 +284,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
242284
"exec_time": total_exec_time,
243285
})
244286

287+
# Log final execution results
245288
graph_execution_time = time.time() - start_time
246289
response = state.get("answer", None) if source_type == "url" else None
247290
content = state.get("parsed_doc", None) if response is not None else None
@@ -300,3 +343,4 @@ def append_node(self, node):
300343
self.raw_edges.append((last_node, node))
301344
self.nodes.append(node)
302345
self.edges = self._create_edges({e for e in self.raw_edges})
346+

0 commit comments

Comments
 (0)