Skip to content

Commit

Permalink
Merge branch 'master' into remove-watermark-option
Browse files Browse the repository at this point in the history
  • Loading branch information
space-nuko authored Mar 27, 2023
2 parents d86beb8 + 955df77 commit 0826130
Show file tree
Hide file tree
Showing 48 changed files with 1,332 additions and 1,063 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
cache-dependency-path: |
**/requirements*txt
- name: Run tests
run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()
Expand Down
21 changes: 10 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ A browser interface based on Gradio library for Stable Diffusion.
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
- a man in a ((tuxedo)) - will pay more attention to tuxedo
- a man in a (tuxedo:1.21) - alternative syntax
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- a man in a `((tuxedo))` - will pay more attention to tuxedo
- a man in a `(tuxedo:1.21)` - alternative syntax
- select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
Expand All @@ -28,7 +28,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler
- ESRGAN, neural network upscaler with a lot of third party models
- SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
- SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
- LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
Expand All @@ -46,7 +46,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- drag and drop an image/text-parameters to promptbox
- Read Generation Parameters Button, loads parameters in promptbox to UI
- Settings page
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Running arbitrary python code from UI (must run with `--allow-code` to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Tiling support, a checkbox to create images that can be tiled like textures
Expand All @@ -69,7 +69,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add `--xformers` to commandline args)
- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
- Generate forever option
- Training tab
Expand All @@ -78,11 +78,11 @@ A browser interface based on Gradio library for Stable Diffusion.
- Clip skip
- Hypernetworks
- Loras (same as Hypernetworks but more pretty)
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
Expand All @@ -91,7 +91,6 @@ A browser interface based on Gradio library for Stable Diffusion.
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
-

## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
Expand All @@ -101,7 +100,7 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)

### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
Expand Down Expand Up @@ -159,4 +158,4 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
- (You)
202 changes: 169 additions & 33 deletions extensions-builtin/Lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,34 @@
import os
import re
import torch
from typing import Union

from modules import shared, devices, sd_models, errors

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
}


def convert_diffusers_name_to_compvis(key, is_sd2):
def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex


def convert_diffusers_name_to_compvis(key):
def match(match_list, regex):
r = re.match(regex, key)
if not r:
return False
Expand All @@ -26,16 +40,33 @@ def match(match_list, regex):

m = []

if match(m, re_unet_down_blocks):
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

if match(m, re_unet_mid_blocks):
return f"diffusion_model_middle_block_1_{m[1]}"
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

if match(m, re_unet_up_blocks):
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
if is_sd2:
if 'mlp_fc1' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
elif 'mlp_fc2' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
else:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

if match(m, re_text_block):
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

return key
Expand Down Expand Up @@ -101,15 +132,22 @@ def load_lora(name, filename):

sd = sd_models.read_state_dict(filename)

keys_failed_to_match = []
keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

for key_diffusers, weight in sd.items():
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
key, lora_key = fullkey.split(".", 1)
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)

sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

if sd_module is None:
keys_failed_to_match.append(key_diffusers)
m = re_x_proj.match(key)
if m:
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

if sd_module is None:
keys_failed_to_match[key_diffusers] = key
continue

lora_module = lora.modules.get(key, None)
Expand All @@ -123,15 +161,21 @@ def load_lora(name, filename):

if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

with torch.no_grad():
module.weight.copy_(weight)

module.to(device=devices.device, dtype=devices.dtype)
module.to(device=devices.cpu, dtype=devices.dtype)

if lora_key == "lora_up.weight":
lora_module.up = module
Expand Down Expand Up @@ -177,28 +221,120 @@ def load_loras(names, multipliers=None):
loaded_loras.append(lora)


def lora_forward(module, input, res):
if len(loaded_loras) == 0:
return res
def lora_calc_updown(lora, module, target):
with torch.no_grad():
up = module.up.weight.to(target.device, dtype=target.dtype)
down = module.down.weight.to(target.device, dtype=target.dtype)

lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
updown = up @ down

updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

return updown


def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
If weights already have this particular set of loras applied, does nothing.
If not, restores orginal weights from backup and alters weights according to loras.
"""

lora_layer_name = getattr(self, 'lora_layer_name', None)
if lora_layer_name is None:
return

current_names = getattr(self, "lora_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)

self.lora_weights_backup = weights_backup

if current_names != wanted_names:
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
self.weight.copy_(weights_backup)

return res
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None and hasattr(self, 'weight'):
self.weight += lora_calc_updown(lora, module, self.weight)
continue

module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)

if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

self.in_proj_weight += updown_qkv
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
continue

if module is None:
continue

print(f'failed to calculate lora weights for layer {lora_layer_name}')

setattr(self, "lora_current_names", wanted_names)


def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)


def lora_Linear_forward(self, input):
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
lora_apply_weights(self)

return torch.nn.Linear_forward_before_lora(self, input)


def lora_Linear_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)

return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)


def lora_Conv2d_forward(self, input):
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
lora_apply_weights(self)

return torch.nn.Conv2d_forward_before_lora(self, input)


def lora_Conv2d_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)

return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_forward(self, *args, **kwargs):
lora_apply_weights(self)

return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)

return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)


def list_available_loras():
Expand All @@ -211,7 +347,7 @@ def list_available_loras():
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

for filename in sorted(candidates):
for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue

Expand Down
Loading

0 comments on commit 0826130

Please sign in to comment.