
Commit

fix flash_attn, make sdpa default
kijai committed Dec 6, 2024
1 parent d0b98c8 commit 2e8f1f2
Showing 6 changed files with 35 additions and 30 deletions.
2 changes: 1 addition & 1 deletion examples/hyvideo_lowvram_blockswap_test.json
@@ -159,7 +159,7 @@
 "bf16",
 "fp8_e4m3fn",
 "offload_device",
-"sageattn_varlen"
+"sdpa"
 ]
 },
 {
2 changes: 1 addition & 1 deletion examples/hyvideo_t2v_example_01.json
@@ -48,7 +48,7 @@
 "bf16",
 "fp8_e4m3fn",
 "offload_device",
-"sageattn_varlen"
+"sdpa"
 ]
 },
 {
2 changes: 1 addition & 1 deletion examples/hyvideo_v2v_example_01.json
@@ -672,7 +672,7 @@
 "bf16",
 "fp8_e4m3fn",
 "offload_device",
-"sageattn_varlen"
+"sdpa"
 ]
 }
 ],
4 changes: 2 additions & 2 deletions hyvideo/modules/attention.py
@@ -81,7 +81,7 @@ def attention(
     k,
     v,
     heads,
-    mode="flash_attn",
+    mode="sdpa",
     drop_rate=0,
     attn_mask=None,
     causal=False,
@@ -142,7 +142,7 @@ def attention(
     elif mode == "comfy":
         x = optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True)

-    elif mode == "flash_attn":
+    elif mode == "flash_attn_varlen":
         x = flash_attn_varlen_func(
             q,
             k,
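For context, a minimal sketch of how an "sdpa" branch of this dispatcher can be built on torch.nn.functional.scaled_dot_product_attention. The (batch, seq_len, heads, head_dim) input layout and the head merge at the end are assumptions for illustration, not the repository's exact code:

```python
# Hypothetical sketch of an "sdpa" attention path (not the repository's exact code).
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, attn_mask=None, drop_rate=0.0, causal=False):
    # Assumes q, k, v arrive as (batch, seq_len, heads, head_dim);
    # SDPA expects (batch, heads, seq_len, head_dim).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    x = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attn_mask,   # boolean mask: True marks positions allowed to attend
        dropout_p=drop_rate,
        is_causal=causal,
    )
    # Merge heads back to (batch, seq_len, heads * head_dim) for the output projection.
    b, h, s, d = x.shape
    return x.transpose(1, 2).reshape(b, s, h * d)
```

In this sketch the dense boolean attn_mask built in models.py (see below) is passed straight through, while the flash_attn_varlen path relies on cu_seqlens instead.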
47 changes: 25 additions & 22 deletions hyvideo/modules/models.py
@@ -36,7 +36,7 @@ def __init__(
         qkv_bias: bool = False,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
-        attention_mode: str = "flash_attn",
+        attention_mode: str = "sdpa",
     ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -199,9 +199,9 @@ def forward(
         q = torch.cat((img_q, txt_q), dim=1)
         k = torch.cat((img_k, txt_k), dim=1)
         v = torch.cat((img_v, txt_v), dim=1)
-        assert (
-            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
-        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
+        #assert (
+        #    cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
+        #), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
         attn = attention(
             q,
             k,
@@ -262,7 +262,7 @@ def __init__(
         qk_scale: float = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
-        attention_mode: str = "flash_attn",
+        attention_mode: str = "sdpa",
     ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -353,9 +353,9 @@ def forward(
         k = torch.cat((img_k, txt_k), dim=1)

         # Compute attention.
-        assert (
-            cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
-        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
+        #assert (
+        #    cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
+        #), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
         attn = attention(
             q,
             k,
@@ -452,7 +452,7 @@ def __init__(
         device: Optional[torch.device] = None,
         main_device: Optional[torch.device] = None,
         offload_device: Optional[torch.device] = None,
-        attention_mode: str = "flash_attn",
+        attention_mode: str = "sdpa",
     ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -466,6 +466,7 @@ def __init__(

         self.main_device = main_device
         self.offload_device = offload_device
+        self.attention_mode = attention_mode

         # Text projection. Default to linear projection.
         # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
@@ -651,22 +652,24 @@ def forward(

         txt_seq_len = txt.shape[1]
         img_seq_len = img.shape[1]
+        max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len

-        # Compute cu_squlens and max_seqlen for flash attention
-        cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
-        cu_seqlens_kv = cu_seqlens_q
-        max_seqlen_q = img_seq_len + txt_seq_len
-        max_seqlen_kv = max_seqlen_q
-
-        # Create a square boolean mask filled with False
-        attn_mask = torch.zeros((1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
+        if self.attention_mode == "sdpa" or self.attention_mode == "comfy":
+            cu_seqlens_q, cu_seqlens_kv = None, None
+            # Create a square boolean mask filled with False
+            attn_mask = torch.zeros((1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)

-        # Calculate the valid attention regions
-        text_len = text_mask[0].sum().item()
-        total_len = text_len + img_seq_len
+            # Calculate the valid attention regions
+            text_len = text_mask[0].sum().item()
+            total_len = text_len + img_seq_len

-        # Allow attention to all tokens up to total_len
-        attn_mask[0, :total_len, :total_len] = True
+            # Allow attention to all tokens up to total_len
+            attn_mask[0, :total_len, :total_len] = True
+        else:
+            attn_mask = None
+            # Compute cu_squlens for flash attention
+            cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
+            cu_seqlens_kv = cu_seqlens_q

         freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
         # --------------------- Pass through DiT blocks ------------------------
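To illustrate why this hunk switches between two different inputs, here is a hedged sketch of the two regimes: a dense boolean mask for the sdpa/comfy paths (mirroring the branch added above) versus cumulative sequence lengths for the varlen kernels. The segment layout assumed for the cu_seqlens helper (valid tokens followed by padding per sample, giving 2 * batch + 1 entries, consistent with the commented-out asserts) is an assumption, not the repository's exact get_cu_seqlens:

```python
# Hypothetical illustration of the two mask regimes (not the repository's exact helpers).
import torch

def build_cu_seqlens(text_mask: torch.Tensor, img_seq_len: int) -> torch.Tensor:
    # For the varlen kernels: cumulative lengths of concatenated segments.
    # Assumed layout: per sample, one segment of valid tokens (img + valid text)
    # followed by one segment of padded text, giving 2 * batch + 1 entries.
    batch, txt_seq_len = text_mask.shape
    max_len = img_seq_len + txt_seq_len
    cu = [0]
    for i in range(batch):
        valid = img_seq_len + int(text_mask[i].sum())
        cu.append(cu[-1] + valid)        # end of this sample's valid tokens
        cu.append((i + 1) * max_len)     # end of this sample's padded remainder
    return torch.tensor(cu, dtype=torch.int32, device=text_mask.device)

def build_dense_mask(text_mask: torch.Tensor, img_seq_len: int) -> torch.Tensor:
    # For sdpa / comfy: a square boolean mask, True where attention is allowed,
    # mirroring the branch added in this commit (which uses text_mask[0] only).
    txt_seq_len = text_mask.shape[1]
    max_len = img_seq_len + txt_seq_len
    mask = torch.zeros((1, max_len, max_len), dtype=torch.bool, device=text_mask.device)
    total_len = img_seq_len + int(text_mask[0].sum())
    mask[0, :total_len, :total_len] = True
    return mask
```

A (1, L, L) boolean mask is broadcast across batch and heads by scaled_dot_product_attention, which is why the dense path can pass a single mask while the varlen path needs per-sample offsets.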
8 changes: 5 additions & 3 deletions readme.md
@@ -2,6 +2,10 @@

 ## WORK IN PROGRESS

+# Update
+
+Scaled dot product attention (sdpa) should now be working. sageattention is still recommended for speed, but it should no longer be necessary, which makes installation much easier.
+
 Vid2vid test:
 [source video](https://www.pexels.com/video/a-4x4-vehicle-speeding-on-a-dirt-road-during-a-competition-15604814/)
@@ -13,8 +17,6 @@ text2vid (old test):
 https://github.com/user-attachments/assets/3750da65-9753-4bd2-aae2-a688d2b86115


-**Currently seems to require flash_attn (default) or sageattn, spda is not working**
-
 Transformer and VAE (single files, no autodownload):

 https://huggingface.co/Kijai/HunyuanVideo_comfy/tree/main
@@ -29,7 +31,7 @@ Files go to `ComfyUI/models/LLM/llava-llama-3-8b-text-encoder-tokenizer`

 Clip text encoder (has autodownload)

-For now using the original https://huggingface.co/openai/clip-vit-large-patch14, files (only need the .safetensor from the weights) go to:
+For now using the original https://huggingface.co/openai/clip-vit-large-patch14, files (only need the .safetensor from the weights and all the config files) go to:

 `ComfyUI/models/clip/clip-vit-large-patch14`
