Skip to content

Commit 79f7b8e

Browse files
committed
Split run_app.py and api_app.py so that api_app.py is more narrowly responsible for just initializing the FastAPI app. This also gives clearer control over the order of the initialization steps, which will be important as we add planned torch configurations that must be applied before torch is imported.
1 parent f7d15c8 commit 79f7b8e

File tree

2 files changed

+62
-53
lines changed

2 files changed

+62
-53
lines changed

invokeai/app/api_app.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from contextlib import asynccontextmanager
44
from pathlib import Path
55

6-
import uvicorn
76
from fastapi import FastAPI, Request
87
from fastapi.middleware.cors import CORSMiddleware
98
from fastapi.middleware.gzip import GZipMiddleware
@@ -31,26 +30,11 @@
3130
from invokeai.app.api.sockets import SocketIO
3231
from invokeai.app.services.config.config_default import get_config
3332
from invokeai.app.util.custom_openapi import get_openapi_func
34-
from invokeai.app.util.startup_utils import (
35-
apply_monkeypatches,
36-
check_cudnn,
37-
enable_dev_reload,
38-
find_open_port,
39-
register_mime_types,
40-
)
41-
from invokeai.backend.util.devices import TorchDevice
4233
from invokeai.backend.util.logging import InvokeAILogger
4334

4435
app_config = get_config()
45-
46-
apply_monkeypatches()
47-
register_mime_types()
48-
4936
logger = InvokeAILogger.get_logger(config=app_config)
5037

51-
torch_device_name = TorchDevice.get_torch_device_name()
52-
logger.info(f"Using torch device: {torch_device_name}")
53-
5438
loop = asyncio.new_event_loop()
5539

5640

@@ -176,34 +160,3 @@ def overridden_redoc() -> HTMLResponse:
176160
app.mount(
177161
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
178162
) # docs favicon is in here
179-
180-
181-
def invoke_api() -> None:
182-
if app_config.dev_reload:
183-
enable_dev_reload()
184-
185-
orig_config_port = app_config.port
186-
app_config.port = find_open_port(app_config.port)
187-
if orig_config_port != app_config.port:
188-
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
189-
190-
check_cudnn(logger)
191-
192-
config = uvicorn.Config(
193-
app=app,
194-
host=app_config.host,
195-
port=app_config.port,
196-
loop="asyncio",
197-
log_level=app_config.log_level_network,
198-
ssl_certfile=app_config.ssl_certfile,
199-
ssl_keyfile=app_config.ssl_keyfile,
200-
)
201-
server = uvicorn.Server(config)
202-
203-
# replace uvicorn's loggers with InvokeAI's for consistent appearance
204-
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
205-
uvicorn_logger.handlers.clear()
206-
for hdlr in logger.handlers:
207-
uvicorn_logger.addHandler(hdlr)
208-
209-
loop.run_until_complete(server.serve())

invokeai/app/run_app.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,68 @@
1-
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
1+
import uvicorn
22

3+
from invokeai.app.services.config.config_default import get_config
4+
from invokeai.app.util.startup_utils import (
5+
apply_monkeypatches,
6+
check_cudnn,
7+
enable_dev_reload,
8+
find_open_port,
9+
register_mime_types,
10+
)
11+
from invokeai.backend.util.logging import InvokeAILogger
12+
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
313

4-
def run_app() -> None:
5-
# Before doing _anything_, parse CLI args!
6-
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
714

15+
def get_app():
16+
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
17+
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
18+
"""
19+
from invokeai.app.api_app import app, loop
20+
21+
return app, loop
22+
23+
24+
def run_app() -> None:
25+
"""The main entrypoint for the app."""
26+
# Parse the CLI arguments.
827
InvokeAIArgs.parse_args()
928

10-
from invokeai.app.api_app import invoke_api
29+
# Load config.
30+
app_config = get_config()
31+
32+
logger = InvokeAILogger.get_logger(config=app_config)
33+
34+
# Find an open port, and modify the config accordingly.
35+
orig_config_port = app_config.port
36+
app_config.port = find_open_port(app_config.port)
37+
if orig_config_port != app_config.port:
38+
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
39+
40+
# Miscellaneous startup tasks.
41+
apply_monkeypatches()
42+
register_mime_types()
43+
if app_config.dev_reload:
44+
enable_dev_reload()
45+
check_cudnn(logger)
46+
47+
# Initialize the app and event loop.
48+
app, loop = get_app()
49+
50+
# Start the server.
51+
config = uvicorn.Config(
52+
app=app,
53+
host=app_config.host,
54+
port=app_config.port,
55+
loop="asyncio",
56+
log_level=app_config.log_level_network,
57+
ssl_certfile=app_config.ssl_certfile,
58+
ssl_keyfile=app_config.ssl_keyfile,
59+
)
60+
server = uvicorn.Server(config)
61+
62+
# replace uvicorn's loggers with InvokeAI's for consistent appearance
63+
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
64+
uvicorn_logger.handlers.clear()
65+
for hdlr in logger.handlers:
66+
uvicorn_logger.addHandler(hdlr)
1167

12-
invoke_api()
68+
loop.run_until_complete(server.serve())

0 commit comments

Comments
 (0)