|
7 | 7 | from stagehand.a11y.utils import get_accessibility_tree |
8 | 8 | from stagehand.llm.inference import extract as extract_inference |
9 | 9 | from stagehand.metrics import StagehandFunctionName # Changed import location |
10 | | -from stagehand.schemas import DEFAULT_EXTRACT_SCHEMA as DefaultExtractSchema, ExtractOptions, ExtractResult |
| 10 | +from stagehand.schemas import ( |
| 11 | + DEFAULT_EXTRACT_SCHEMA, |
| 12 | + ExtractOptions, |
| 13 | + ExtractResult, |
| 14 | +) |
11 | 15 | from stagehand.utils import inject_urls, transform_url_strings_to_ids |
12 | 16 |
|
13 | 17 | T = TypeVar("T", bound=BaseModel) |
@@ -93,7 +97,7 @@ async def extract( |
93 | 97 | # TODO: Remove this once we have a better way to handle URLs |
94 | 98 | transformed_schema, url_paths = transform_url_strings_to_ids(schema) |
95 | 99 | else: |
96 | | - transformed_schema = DefaultExtractSchema |
| 100 | + transformed_schema = DEFAULT_EXTRACT_SCHEMA |
97 | 101 |
|
98 | 102 | # Use inference to call the LLM |
99 | 103 | extraction_response = extract_inference( |
@@ -149,15 +153,15 @@ async def extract( |
149 | 153 | validated_model_instance = schema.model_validate(raw_data_dict) |
150 | 154 | processed_data_payload = validated_model_instance # Payload is now the Pydantic model instance |
151 | 155 | except Exception as e: |
152 | | - schema_name = getattr(schema, '__name__', str(schema)) |
| 156 | + schema_name = getattr(schema, "__name__", str(schema)) |
153 | 157 | self.logger.error( |
154 | 158 | f"Failed to validate extracted data against schema {schema_name}: {e}. Keeping raw data dict in .data field." |
155 | 159 | ) |
156 | 160 |
|
157 | 161 | # Create ExtractResult object with extracted data as fields |
158 | 162 | if isinstance(processed_data_payload, dict): |
159 | 163 | result = ExtractResult(**processed_data_payload) |
160 | | - elif hasattr(processed_data_payload, 'model_dump'): |
| 164 | + elif hasattr(processed_data_payload, "model_dump"): |
161 | 165 | # For Pydantic models, convert to dict and spread as fields |
162 | 166 | result = ExtractResult(**processed_data_payload.model_dump()) |
163 | 167 | else: |
|
0 commit comments