Skip to content

Commit

Permalink
Add support for Lion8bit
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed May 3, 2023
1 parent 71f9181 commit 111527b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 20 deletions.
1 change: 1 addition & 0 deletions library/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ def gradio_training(
'Adafactor',
'DAdaptation',
'Lion',
'Lion8bit',
'SGDNesterov',
'SGDNesterov8bit',
],
Expand Down
11 changes: 11 additions & 0 deletions library/convert_model_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def convert_model(
target_model_name_input,
target_model_type,
target_save_precision_type,
unet_use_linear_projection,
):
# Check for caption_text_input
if source_model_type == '':
Expand Down Expand Up @@ -67,6 +68,14 @@ def convert_model(

if target_model_type == 'diffuser_safetensors':
run_cmd += ' --use_safetensors'

# Fix for stabilityAI diffusers format. When saving v2 models in Diffusers format in training scripts and conversion scripts,
# it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is
# "use_linear_projection": false, stabilityai is true). Please note that the weight shapes are different, so please be careful
# when using the weight files directly.

if unet_use_linear_projection:
run_cmd += ' --unet_use_linear_projection'

run_cmd += f' "{source_model_input}"'

Expand Down Expand Up @@ -230,6 +239,7 @@ def gradio_convert_model_tab():
choices=['unspecified', 'fp16', 'bf16', 'float'],
value='unspecified',
)
unet_use_linear_projection = gr.Checkbox(label="UNet linear projection", value=False, info="Enable for Hugging Face's stabilityai models")

convert_button = gr.Button('Convert model')

Expand All @@ -242,6 +252,7 @@ def gradio_convert_model_tab():
target_model_name_input,
target_model_type,
target_save_precision_type,
unet_use_linear_projection,
],
show_progress=False,
)
9 changes: 6 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
accelerate==0.15.0
# Some comments
accelerate==0.18.0
albumentations==1.3.0
altair==4.2.2
bitsandbytes==0.35.0; sys_platform == 'win32'
https://github.com/bmaltais/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl; sys_platform == 'win32'
# This next line is not an error but rather there to properly catch if the url based bitsandbytes was properly installed by the line above...
bitsandbytes==0.38.1; sys_platform == 'win32'
bitsandbytes==0.38.1; (sys_platform == "darwin" or sys_platform == "linux")
dadaptation==1.5
diffusers[torch]==0.10.2
easygui==0.98.3
einops==0.6.0
ftfy==6.1.1
gradio==3.27.0; sys_platform != 'darwin'
gradio==3.28.1; sys_platform != 'darwin'
gradio==3.23.0; sys_platform == 'darwin'
lion-pytorch==0.0.6
opencv-python==4.7.0.68
Expand Down
14 changes: 11 additions & 3 deletions setup.bat
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ echo [1] - v1 (torch 1.12.1) (Recommended)
echo [2] - v2 (torch 2.0.0) (Experimental)
set /p choice="Enter your choice (1 or 2): "


:: Only does this section to cleanup the old custom dll versions that we used to use. No longer needed now with the new bitsandbytes version
pip uninstall -y bitsandbytes
IF EXIST ".\venv\Lib\site-packages\bitsandbytes" (
rmdir .\venv\Lib\site-packages\bitsandbytes
)
:::::::::::::::::::::::::

if %choice%==1 (
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --use-pep517 --upgrade -r requirements.txt
Expand All @@ -41,8 +49,8 @@ if %choice%==1 (
pip install https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl
)

copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
@REM copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
@REM copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
@REM copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py

accelerate config
41 changes: 27 additions & 14 deletions tools/validate_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import sys
import pkg_resources
import argparse
from packaging.requirements import Requirement
from packaging.markers import default_environment
import re

# Parse command line arguments
parser = argparse.ArgumentParser(description="Validate that requirements are satisfied.")
Expand All @@ -17,27 +20,37 @@
# Check each requirement against the installed packages
missing_requirements = []
wrong_version_requirements = []

url_requirement_pattern = re.compile(r"(?P<url>https?://.+);?\s?(?P<marker>.+)?")

for requirement in requirements:
requirement = requirement.strip()
if requirement == ".":
# Skip the current requirement if it is a dot (.)
continue

url_match = url_requirement_pattern.match(requirement)
if url_match:
if url_match.group("marker"):
marker = url_match.group("marker")
parsed_marker = Marker(marker)
if not parsed_marker.evaluate(default_environment()):
continue
requirement = url_match.group("url")

try:
pkg_resources.require(requirement)
parsed_req = Requirement(requirement)

# Check if the requirement has an environment marker and if it evaluates to False
if parsed_req.marker and not parsed_req.marker.evaluate(default_environment()):
continue

pkg_resources.require(str(parsed_req))
except ValueError:
# This block will handle URL-based requirements
pass
except pkg_resources.DistributionNotFound:
# Check if the requirement contains a VCS URL
if "@" in requirement:
# If it does, split the requirement into two parts: the package name and the VCS URL
package_name, vcs_url = requirement.split("@", 1)
# Use pip to install the package from the VCS URL
os.system(f"pip install -e {vcs_url}")
# Try to require the package again
try:
pkg_resources.require(package_name)
except pkg_resources.DistributionNotFound:
missing_requirements.append(requirement)
else:
missing_requirements.append(requirement)
missing_requirements.append(requirement)
except pkg_resources.VersionConflict as e:
wrong_version_requirements.append((requirement, str(e.req), e.dist.version))

Expand Down

0 comments on commit 111527b

Please sign in to comment.