run basic torch calculation at startup in parallel to reduce the performance impact of first generation
AUTOMATIC1111 committed May 21, 2023
1 parent 1f31829 commit 8faac8b
Showing 2 changed files with 21 additions and 1 deletion.
18 changes: 18 additions & 0 deletions modules/devices.py
@@ -1,5 +1,7 @@
 import sys
 import contextlib
+from functools import lru_cache
+
 import torch
 from modules import errors
 
@@ -154,3 +156,19 @@ def test_for_nans(x, where):
     message += " Use --disable-nan-check commandline argument to disable this check."
 
     raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+    """
+    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
+    spends about 2.7 seconds doing that, at least with NVidia.
+    """
+
+    x = torch.zeros((1, 1)).to(device, dtype)
+    linear = torch.nn.Linear(1, 1).to(device, dtype)
+    linear(x)
+
+    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+    conv2d(x)
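
A note on the bare @lru_cache decorator used above: on a zero-argument function it effectively makes the body run-once, so any later call (for example, if another code path also triggers the warm-up) returns immediately from the cache. A minimal standalone sketch of that behavior, where warm_up and the sleep duration are illustrative placeholders, not from the commit:

import time
from functools import lru_cache

@lru_cache
def warm_up():
    # stands in for first_time_calculation: the body runs only on the first completed call
    time.sleep(0.5)

warm_up()                      # first call: executes the body (~0.5s here)
start = time.time()
warm_up()                      # second call: cache hit, body is skipped
print(f"cached call took {time.time() - start:.6f} seconds")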
4 changes: 3 additions & 1 deletion webui.py
@@ -20,7 +20,7 @@
 
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
-from modules import paths, timer, import_hook, errors # noqa: F401
+from modules import paths, timer, import_hook, errors, devices # noqa: F401
 
 startup_timer = timer.Timer()
 
@@ -295,6 +295,8 @@ def initialize_rest(*, reload_script_modules=False):
     # (when reloading, this does nothing)
     Thread(target=lambda: shared.sd_model).start()
 
+    Thread(target=devices.first_time_calculation).start()
+
     shared.reload_hypernetworks()
     startup_timer.record("reload hypernetworks")
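
The warm-up is started on a plain threading.Thread and never joined, so the rest of initialize_rest proceeds while the first torch allocation happens in the background. A rough sketch of the same fire-and-forget pattern; slow_warm_up and continue_startup are hypothetical placeholders standing in for devices.first_time_calculation and the remaining startup work:

import time
from threading import Thread

def slow_warm_up():
    # placeholder for devices.first_time_calculation
    time.sleep(2.7)
    print("warm-up done")

def continue_startup():
    # placeholder for the rest of initialize_rest
    print("startup continues without waiting")

Thread(target=slow_warm_up).start()  # fire-and-forget: no join()
continue_startup()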
