
Commit 83948c7

feat(app): move output annotation checking to run_app
Also change the import order to ensure CLI args are handled correctly. This was necessary because importing `InvocationRegistry` before parsing args caused the `--root` CLI arg to be ignored.
1 parent 90472c0 commit 83948c7
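
The ordering problem described in the commit message can be illustrated with a minimal, self-contained sketch. The names below are hypothetical and not taken from InvokeAI; the point is only that an object created at import time snapshots a setting before the CLI is parsed, so a later `--root`-style override never reaches it.

import argparse
import os

# Hypothetical module-level state, standing in for settings loaded from a
# YAML file or environment variables before any CLI parsing happens.
SETTINGS = {"root": os.environ.get("MYAPP_ROOT", "/default/root")}


class Registry:
    def __init__(self) -> None:
        # Snapshot taken when the module is imported, before parse_args().
        self.root = SETTINGS["root"]


# Creating this at module level is the analogue of importing InvocationRegistry
# at the top of run_app.py: the snapshot happens before the CLI is parsed.
REGISTRY = Registry()


def run_app() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", default=None)
    args = parser.parse_args()
    if args.root is not None:
        SETTINGS["root"] = args.root  # too late for REGISTRY, its value is already fixed

    print("registry root:", REGISTRY.root)      # still the import-time default
    print("settings root:", SETTINGS["root"])   # reflects the CLI override


if __name__ == "__main__":
    run_app()

Moving the registry creation (or the import that triggers it) to after parse_args(), as the run_app.py change below does with its function-local imports, keeps the CLI override effective.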

File tree

invokeai/app/invocations/baseinvocation.py
invokeai/app/run_app.py
invokeai/app/util/custom_openapi.py

3 files changed: +31 -16 lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 10 additions & 0 deletions

@@ -643,6 +643,16 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
 
         fields["type"] = (invocation_type_annotation, invocation_type_field_info)
 
+        # Invocation outputs must be registered using the @invocation_output decorator, but it is possible that the
+        # output is registered _after_ this invocation is registered. It depends on module import ordering.
+        #
+        # We can only confirm the output for an invocation is registered after all modules are imported. There's
+        # only really one good time to do that - during application startup, in `run_app.py`, after loading all
+        # custom nodes.
+        #
+        # We can still do some basic validation here - ensure the invoke method is defined and returns an instance
+        # of BaseInvocationOutput.
+
         # Validate the `invoke()` method is implemented
         if "invoke" in cls.__abstractmethods__:
             raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
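
The new comment refers to output registration via the @invocation_output decorator. As a rough sketch of what the startup check in run_app.py verifies, here is a node whose output class is registered before use; the import paths, decorator arguments, and field helpers are assumptions about InvokeAI's custom-node API rather than details taken from this commit, and may differ between versions.

# Sketch only: import paths, decorator signatures, and field helpers are
# assumed, not taken from this commit.
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import InputField, OutputField


@invocation_output("echo_output")  # registers EchoOutput with InvocationRegistry
class EchoOutput(BaseInvocationOutput):
    value: str = OutputField()


@invocation("echo", title="Echo", version="1.0.0")  # registers the invocation itself
class EchoInvocation(BaseInvocation):
    value: str = InputField(default="")

    def invoke(self, context) -> EchoOutput:
        # The return annotation is what get_output_annotation() reports. If
        # EchoOutput were never wrapped in @invocation_output, the startup loop
        # in run_app.py would log the "unregistered output class" warning.
        return EchoOutput(value=self.value)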

invokeai/app/run_app.py

Lines changed: 21 additions & 10 deletions

@@ -1,12 +1,3 @@
-import uvicorn
-
-from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
-from invokeai.backend.util.logging import InvokeAILogger
-from invokeai.frontend.cli.arg_parser import InvokeAIArgs
-
-
 def get_app():
     """Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
     importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
@@ -18,9 +9,20 @@ def get_app():
 
 def run_app() -> None:
     """The main entrypoint for the app."""
-    # Parse the CLI arguments.
+    from invokeai.frontend.cli.arg_parser import InvokeAIArgs
+
+    # Parse the CLI arguments before doing anything else, which ensures CLI args correctly override settings from other
+    # sources like `invokeai.yaml` or env vars.
     InvokeAIArgs.parse_args()
 
+    import uvicorn
+
+    from invokeai.app.invocations.baseinvocation import InvocationRegistry
+    from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
+    from invokeai.app.services.config.config_default import get_config
+    from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
+    from invokeai.backend.util.logging import InvokeAILogger
+
     # Load config.
     app_config = get_config()
 
@@ -66,6 +68,15 @@ def run_app() -> None:
     # core nodes have been imported so that we can catch when a custom node clobbers a core node.
     load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
 
+    # Check all invocations and ensure their outputs are registered.
+    for invocation in InvocationRegistry.get_invocation_classes():
+        invocation_type = invocation.get_type()
+        output_annotation = invocation.get_output_annotation()
+        if output_annotation not in InvocationRegistry.get_output_classes():
+            logger.warning(
+                f'Invocation "{invocation_type}" has unregistered output class "{output_annotation.__name__}"'
+            )
+
     if app_config.dev_reload:
         # load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
         # imported.

invokeai/app/util/custom_openapi.py

Lines changed: 0 additions & 6 deletions

@@ -88,12 +88,6 @@ def openapi() -> dict[str, Any]:
         invocation_output_map_properties[invocation_type] = json_schema["output"]
         invocation_output_map_required.append(invocation_type)
 
-        output_annotation = invocation.get_output_annotation()
-        if output_annotation not in InvocationRegistry.get_output_classes():
-            logger.warning(
-                f'Invocation "{invocation_type}" has unregistered output class {output_annotation.__name__} (did you forget @invocation_output?)'
-            )
-
     # Add the output map to the schema
     openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
         "type": "object",
