
PAG variant for HunyuanDiT, PAG refactor #8936


Merged
33 commits merged into main from hunyuan-dit-pag on Aug 5, 2024
Commits (33)
2771b98  copy hunyuandit pipeline (a-r-r-o-w, Jul 22, 2024)
659310a  pag variant of hunyuan dit (a-r-r-o-w, Jul 22, 2024)
21af0a7  add tests (a-r-r-o-w, Jul 22, 2024)
e26bce5  update docs (a-r-r-o-w, Jul 22, 2024)
8c7ca9e  make style (a-r-r-o-w, Jul 22, 2024)
f83a463  make fix-copies (a-r-r-o-w, Jul 22, 2024)
196e687  Update src/diffusers/pipelines/pag/pag_utils.py (a-r-r-o-w, Jul 22, 2024)
7998502  remove incorrect copied from (a-r-r-o-w, Jul 23, 2024)
096ded3  remove pag hunyuan attn procs to resolve conflicts (a-r-r-o-w, Jul 26, 2024)
98b475d  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Jul 26, 2024)
d397ed4  add pag attn procs again (a-r-r-o-w, Jul 26, 2024)
4a06691  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Jul 26, 2024)
d9638d9  new implementation for pag_utils (a-r-r-o-w, Jul 26, 2024)
ab34d38  revert pag changes (a-r-r-o-w, Aug 2, 2024)
65f2f06  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Aug 2, 2024)
199aa10  add pag refactor back; update pixart sigma (a-r-r-o-w, Aug 2, 2024)
5768906  update pixart pag tests (a-r-r-o-w, Aug 2, 2024)
e754f50  apply suggestions from review (a-r-r-o-w, Aug 2, 2024)
b05da7c  make style (a-r-r-o-w, Aug 2, 2024)
0f3d232  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Aug 2, 2024)
6e38d4e  Merge branch 'main' into hunyuan-dit-pag (sayakpaul, Aug 4, 2024)
678dce9  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Aug 4, 2024)
f51424b  Merge branch 'main' into hunyuan-dit-pag (sayakpaul, Aug 4, 2024)
f36cc5c  Merge branch 'hunyuan-dit-pag' of https://github.com/huggingface/diff… (a-r-r-o-w, Aug 4, 2024)
69c3250  update docs, fix tests (a-r-r-o-w, Aug 4, 2024)
753a023  fix tests (a-r-r-o-w, Aug 4, 2024)
c0d6f72  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Aug 4, 2024)
6f549d1  fix test_components_function since list not accepted as valid __init_… (a-r-r-o-w, Aug 4, 2024)
713c2d1  Merge branch 'main' into hunyuan-dit-pag (sayakpaul, Aug 5, 2024)
de99a59  apply patch to fix broken tests (a-r-r-o-w, Aug 5, 2024)
473645d  make style (a-r-r-o-w, Aug 5, 2024)
29f54e4  fix hunyuan tests (a-r-r-o-w, Aug 5, 2024)
02f9b2e  Merge branch 'main' into hunyuan-dit-pag (a-r-r-o-w, Aug 5, 2024)
20 changes: 19 additions & 1 deletion docs/source/en/api/pipelines/pag.md
@@ -20,11 +20,29 @@ The abstract from the paper is:

*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*

PAG can be used by specifying `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression that matches one or more layers (see the example below).

- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`
- List of identifiers (can be a combination of strings and RegEx): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2|3)"]`

<Tip warning={true}>

Since RegEx is supported for matching layer identifiers, it is crucial to use it correctly; otherwise, there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results.

</Tip>
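
The example below is a minimal sketch of loading the new `HunyuanDiTPAGPipeline` with `pag_applied_layers` and passing a `pag_scale` at inference time; the checkpoint id, layer index, and guidance values are illustrative and may need to be adapted to your setup.

```python
import torch
from diffusers import HunyuanDiTPAGPipeline

# Load the PAG variant of HunyuanDiT; "blocks.16" is an example layer identifier
# and "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers" is an assumed checkpoint id.
pipe = HunyuanDiTPAGPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
    torch_dtype=torch.float16,
    pag_applied_layers=["blocks.16"],
).to("cuda")

# pag_scale controls the strength of perturbed-attention guidance,
# alongside the usual classifier-free guidance_scale.
image = pipe(
    prompt="an astronaut riding a horse on the moon",
    guidance_scale=6.0,
    pag_scale=3.0,
).images[0]
image.save("hunyuan_dit_pag.png")
```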

## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__

## HunyuanDiTPAGPipeline
[[autodoc]] HunyuanDiTPAGPipeline
- all
- __call__

## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -59,4 +77,4 @@ The abstract from the paper is:
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
- __call__
- __call__
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -252,6 +252,7 @@
"CycleDiffusionPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -675,6 +676,7 @@
CycleDiffusionPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
249 changes: 249 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2147,6 +2147,253 @@ def __call__(
return hidden_states


class PAGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embeddings to the query and key vectors. This
variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states_org)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

hidden_states_ptb = attn.to_v(hidden_states_ptb)
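# With the self-attention map replaced by an identity matrix, the attention output
# above reduces to the value projection of the perturbed hidden states (see the PAG paper).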
hidden_states_ptb = hidden_states_ptb.to(query.dtype)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class PAGCFGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embeddings to the query and key vectors. This
variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states_org)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

hidden_states_ptb = attn.to_v(hidden_states_ptb)
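# Identity self-attention for the perturbed branch: query and key are skipped entirely,
# which is equivalent to attention with an identity attention map, so only the value projection remains.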
hidden_states_ptb = hidden_states_ptb.to(query.dtype)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3468,4 +3715,6 @@ def __init__(self):
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
]
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -145,6 +145,7 @@
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
@@ -532,6 +533,7 @@
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -50,6 +50,7 @@
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
@@ -85,6 +86,7 @@
("stable-diffusion-3", StableDiffusion3Pipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
@@ -24,6 +24,7 @@
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
@@ -41,6 +42,7 @@
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline