Skip to content

Commit

Permalink
update torch
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic committed Mar 30, 2023
1 parent 22bcc7b commit d5063e0
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
/cache.json
/cache.json*
11 changes: 6 additions & 5 deletions environment-wsl2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ channels:
- defaults
dependencies:
- python=3.10
- pip=22.2.2
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- pip=23.0
- cudatoolkit=11.8
- pytorch=2.0
- torchvision=0.15
- numpy=1.23

6 changes: 3 additions & 3 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ def run_extensions_installers(settings_file):
def prepare_environment():
global skip_install

torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
Expand Down Expand Up @@ -296,7 +296,7 @@ def prepare_environment():

if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
run_pip(f"install -r \"{requirements_file}\"", "requirements")

run_extensions_installers(settings_file=args.ui_settings_file)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
astunparse
blendmodes
accelerate
basicsr
Expand Down
4 changes: 2 additions & 2 deletions requirements_versions.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
blendmodes==2022
transformers==4.25.1
accelerate==0.12.0
accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.23
numpy==1.23.3
numpy==1.23.5
Pillow==9.4.0
realesrgan==0.3.0
torch
Expand Down
2 changes: 1 addition & 1 deletion webui-macos-env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fi

export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1
Expand Down
2 changes: 2 additions & 0 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import torch
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

startup_timer.record("import torch")

import gradio
Expand Down

1 comment on commit d5063e0

@zhengcr123
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good work,great!

Please sign in to comment.