Skip to content

Commit

Permalink
add a message about new torch/xformers version and a way to upgrade b…
Browse files Browse the repository at this point in the history
…y specifying a commandline flag
  • Loading branch information
AUTOMATIC1111 committed Jan 23, 2023
1 parent 56f63cd commit 7ff1ef7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
3 changes: 2 additions & 1 deletion launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
Expand All @@ -219,7 +220,7 @@ def prepare_environment():
print(f"Python {sys.version}")
print(f"Commit hash: {commit}")

if not is_installed("torch") or not is_installed("torchvision"):
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

if not skip_torch_cuda_test:
Expand Down
26 changes: 26 additions & 0 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version

from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
Expand Down Expand Up @@ -49,7 +50,32 @@
server_name = "0.0.0.0" if cmd_opts.listen else None


def check_versions():
expected_torch_version = "1.13.1"

if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded.
""".strip())

expected_xformers_version = "0.0.16rc425"
if shared.xformers_available:
import xformers

if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
""".strip())


def initialize():
check_versions()

extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)

Expand Down

0 comments on commit 7ff1ef7

Please sign in to comment.