|
22 | 22 | Literal,
|
23 | 23 | Optional,
|
24 | 24 | Type,
|
| 25 | + TypedDict, |
25 | 26 | TypeVar,
|
26 | 27 | Union,
|
27 | 28 | )
|
@@ -106,6 +107,11 @@ class UIConfigBase(BaseModel):
|
106 | 107 | )
|
107 | 108 |
|
108 | 109 |
|
| 110 | +class OriginalModelField(TypedDict): |
| 111 | + annotation: Any |
| 112 | + field_info: FieldInfo |
| 113 | + |
| 114 | + |
109 | 115 | class BaseInvocationOutput(BaseModel):
|
110 | 116 | """
|
111 | 117 | Base class for all invocation outputs.
|
@@ -134,6 +140,9 @@ def get_type(cls) -> str:
|
134 | 140 | """Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
|
135 | 141 | return cls.model_fields["type"].default
|
136 | 142 |
|
| 143 | + _original_model_fields: ClassVar[dict[str, OriginalModelField]] = {} |
| 144 | + """The original model fields, before any modifications were made by the @invocation_output decorator.""" |
| 145 | + |
137 | 146 | model_config = ConfigDict(
|
138 | 147 | protected_namespaces=(),
|
139 | 148 | validate_assignment=True,
|
@@ -266,6 +275,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
|
266 | 275 | coerce_numbers_to_str=True,
|
267 | 276 | )
|
268 | 277 |
|
| 278 | + _original_model_fields: ClassVar[dict[str, OriginalModelField]] = {} |
| 279 | + """The original model fields, before any modifications were made by the @invocation decorator.""" |
| 280 | + |
269 | 281 |
|
270 | 282 | TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
271 | 283 |
|
@@ -575,6 +587,8 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
575 | 587 | f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
576 | 588 | )
|
577 | 589 |
|
| 590 | + cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info) |
| 591 | + |
578 | 592 | validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
|
579 | 593 |
|
580 | 594 | if field_info.default is None and not is_optional(annotation):
|
@@ -686,6 +700,9 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
|
686 | 700 | assert isinstance(field_info.json_schema_extra, dict), (
|
687 | 701 | f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
688 | 702 | )
|
| 703 | + |
| 704 | + cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info) |
| 705 | + |
689 | 706 | if field_info.default is not PydanticUndefined and is_optional(annotation):
|
690 | 707 | annotation = annotation | None
|
691 | 708 | fields[field_name] = (annotation, field_info)
|
|
0 commit comments