Skip to content

Commit d6f8c82

Browse files
feat(nodes): store original field annotation & FieldInfo in invocations
1 parent 76a0798 commit d6f8c82

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Literal,
2323
Optional,
2424
Type,
25+
TypedDict,
2526
TypeVar,
2627
Union,
2728
)
@@ -106,6 +107,11 @@ class UIConfigBase(BaseModel):
106107
)
107108

108109

110+
class OriginalModelField(TypedDict):
111+
annotation: Any
112+
field_info: FieldInfo
113+
114+
109115
class BaseInvocationOutput(BaseModel):
110116
"""
111117
Base class for all invocation outputs.
@@ -134,6 +140,9 @@ def get_type(cls) -> str:
134140
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
135141
return cls.model_fields["type"].default
136142

143+
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
144+
"""The original model fields, before any modifications were made by the @invocation_output decorator."""
145+
137146
model_config = ConfigDict(
138147
protected_namespaces=(),
139148
validate_assignment=True,
@@ -266,6 +275,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
266275
coerce_numbers_to_str=True,
267276
)
268277

278+
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
279+
"""The original model fields, before any modifications were made by the @invocation decorator."""
280+
269281

270282
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
271283

@@ -575,6 +587,8 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
575587
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
576588
)
577589

590+
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
591+
578592
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
579593

580594
if field_info.default is None and not is_optional(annotation):
@@ -686,6 +700,9 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
686700
assert isinstance(field_info.json_schema_extra, dict), (
687701
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
688702
)
703+
704+
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
705+
689706
if field_info.default is not PydanticUndefined and is_optional(annotation):
690707
annotation = annotation | None
691708
fields[field_name] = (annotation, field_info)

0 commit comments

Comments
 (0)