Commit
Add xformers to training scripts (#103)
hafriedlander authored Dec 30, 2022
1 parent 7dd0467 commit 4936d0f
Showing 3 changed files with 88 additions and 4 deletions.
70 changes: 70 additions & 0 deletions lora_diffusion/xformers_utils.py
@@ -0,0 +1,70 @@
import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


@functools.cache
def test_xformers_backwards(size):
    @torch.enable_grad()
    def _grad(size):
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception as e:
        print(size, "fail")
        return False


def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored anywhere, so back-calculate
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)

            # If dim_head > dim_head_max, turn xformers off
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    if valid:
        fn_test_dim_head(module)
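
For orientation, a minimal usage sketch of the new helper (the checkpoint id below is only an illustrative placeholder, not something this commit pins; the training scripts in this diff call the helper the same way on the unet and vae they have already loaded):

# Minimal sketch: load a diffusers UNet/VAE and toggle xformers attention with the new helper.
# The checkpoint id is a placeholder; substitute whatever model the training script points at.
from diffusers import AutoencoderKL, UNet2DConditionModel

from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers

pretrained = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint

unet = UNet2DConditionModel.from_pretrained(pretrained, subfolder="unet")
vae = AutoencoderKL.from_pretrained(pretrained, subfolder="vae")

# Enables xformers on every attention block; any block whose head dimension fails the
# backwards-pass probe in test_xformers_backwards is switched back to standard attention.
set_use_memory_efficient_attention_xformers(unet, True)
set_use_memory_efficient_attention_xformers(vae, True)
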
11 changes: 9 additions & 2 deletions train_lora_dreambooth.py
@@ -36,9 +36,9 @@
    save_lora_weight,
    save_safeloras,
)

from torch.utils.data import Dataset
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from pathlib import Path
@@ -450,6 +450,9 @@ def parse_args(input_args=None):
        required=False,
        help="Should images be resized to --resolution before training?",
    )
    parser.add_argument(
        "--use_xformers", action="store_true", help="Whether or not to use xformers"
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -615,6 +618,10 @@ def main(args):
)
break

    if args.use_xformers:
        set_use_memory_efficient_attention_xformers(unet, True)
        set_use_memory_efficient_attention_xformers(vae, True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
11 changes: 9 additions & 2 deletions train_lora_w_ti.py
@@ -36,9 +36,9 @@
    save_lora_weight,
    extract_lora_ups_down,
)

from torch.utils.data import Dataset
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from pathlib import Path
@@ -575,6 +575,9 @@ def parse_args(input_args=None):
        required=False,
        help="Should images be resized to --resolution before training?",
    )
    parser.add_argument(
        "--use_xformers", action="store_true", help="Whether or not to use xformers"
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -774,6 +777,10 @@ def main(args):
print("Before training: text encoder First Layer lora down", _down.weight.data)
break

    if args.use_xformers:
        set_use_memory_efficient_attention_xformers(unet, True)
        set_use_memory_efficient_attention_xformers(vae, True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:

1 comment on commit 4936d0f

@hdon96 (Contributor) commented on 4936d0f, Jan 4, 2023

I'm having trouble running with --use_xformers (Windows).
The error seems to be with how set_use_memory_efficient_attention_xformers is set up.

Traceback (most recent call last):
  File "C:\Users\<user>\lora\train_lora_dreambooth.py", line 1039, in <module>
    main(args)
  File "C:\Users\<user>\lora\train_lora_dreambooth.py", line 655, in main
    set_use_memory_efficient_attention_xformers(unet, True)
  File "C:\Users\<user>\lora\lora_diffusion\xformers_utils.py", line 67, in set_use_memory_efficient_attention_xformers
    module.set_use_memory_efficient_attention_xformers(valid)
  File "F:\ANACONDA\envs\sd\lib\site-packages\torch\nn\modules\module.py", line 1207, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'UNet2DConditionModel' object has no attribute 'set_use_memory_efficient_attention_xformers'
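
The traceback points at an installed diffusers build that does not expose set_use_memory_efficient_attention_xformers on the model. A minimal defensive sketch, purely illustrative and not a fix shipped in this commit (the helper name maybe_enable_xformers is hypothetical), would probe for that method before toggling xformers:

# Illustrative sketch only, not part of this commit: skip the xformers toggle when the
# installed diffusers version does not expose the method the wrapper relies on.
import torch

from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers


def maybe_enable_xformers(model: torch.nn.Module) -> bool:
    # Hypothetical helper: only call the wrapper when the method exists on the model,
    # otherwise fall back to standard attention instead of crashing.
    if not hasattr(model, "set_use_memory_efficient_attention_xformers"):
        print("Installed diffusers does not expose set_use_memory_efficient_attention_xformers; skipping.")
        return False
    set_use_memory_efficient_attention_xformers(model, True)
    return True

Called as maybe_enable_xformers(unet) and maybe_enable_xformers(vae) in place of the unconditional calls in main(), this would degrade gracefully instead of raising the AttributeError above.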
