@@ -582,14 +582,16 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
582
582
583
583
fields : dict [str , tuple [Any , FieldInfo ]] = {}
584
584
585
+ original_model_fields : dict [str , OriginalModelField ] = {}
586
+
585
587
for field_name , field_info in cls .model_fields .items ():
586
588
annotation = field_info .annotation
587
589
assert annotation is not None , f"{ field_name } on invocation { invocation_type } has no type annotation."
588
590
assert isinstance (field_info .json_schema_extra , dict ), (
589
591
f"{ field_name } on invocation { invocation_type } has a non-dict json_schema_extra, did you forget to use InputField?"
590
592
)
591
593
592
- cls . _original_model_fields [field_name ] = OriginalModelField (annotation = annotation , field_info = field_info )
594
+ original_model_fields [field_name ] = OriginalModelField (annotation = annotation , field_info = field_info )
593
595
594
596
validate_field_default (cls .__name__ , field_name , invocation_type , annotation , field_info )
595
597
@@ -676,6 +678,7 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
676
678
docstring = cls .__doc__
677
679
new_class = create_model (cls .__qualname__ , __base__ = cls , __module__ = cls .__module__ , ** fields ) # type: ignore
678
680
new_class .__doc__ = docstring
681
+ new_class ._original_model_fields = original_model_fields
679
682
680
683
InvocationRegistry .register_invocation (new_class )
681
684
0 commit comments