Skip to content

Commit 8d012c5

Browse files
fix(app): address pydantic deprecation warning for accessing BaseModel.model_fields
1 parent 2b1e4b8 commit 8d012c5

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def get_type(cls) -> str:
177177
return cls.model_fields["type"].default
178178

179179
@classmethod
180-
def get_output_annotation(cls) -> BaseInvocationOutput:
180+
def get_output_annotation(cls) -> Type[BaseInvocationOutput]:
181181
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
182182
return signature(cls.invoke).return_annotation
183183

@@ -209,7 +209,7 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
209209
Internal invoke method, calls `invoke()` after some prep.
210210
Handles optional fields that are required to call `invoke()` and invocation cache.
211211
"""
212-
for field_name, field in self.model_fields.items():
212+
for field_name, field in type(self).model_fields.items():
213213
if not field.json_schema_extra or callable(field.json_schema_extra):
214214
# something has gone terribly awry, we should always have this and it should be a dict
215215
continue
@@ -224,9 +224,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
224224
setattr(self, field_name, orig_default)
225225
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
226226
if input_ == Input.Connection:
227-
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
227+
raise RequiredConnectionException(type(self).model_fields["type"].default, field_name)
228228
elif input_ == Input.Any:
229-
raise MissingInputException(self.model_fields["type"].default, field_name)
229+
raise MissingInputException(type(self).model_fields["type"].default, field_name)
230230

231231
# skip node cache codepath if it's disabled
232232
if services.configuration.node_cache_size == 0:

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def validate_batch_nodes_and_edges(cls, values):
148148
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
149149
except NodeNotFoundError:
150150
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
151-
if batch_data.field_name not in node.model_fields:
151+
if batch_data.field_name not in type(node).model_fields:
152152
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
153153
return values
154154

invokeai/app/services/shared/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def validate_self(self) -> None:
424424
)
425425

426426
# input fields are on the node
427-
if edge.destination.field not in destination_node.model_fields:
427+
if edge.destination.field not in type(destination_node).model_fields:
428428
raise NodeFieldNotFoundError(
429429
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
430430
)

0 commit comments

Comments
 (0)