
Commit 3375933

Author: davidb (committed)
Remove the einops dependency; PhotonTransformer2DModel now inherits from AttentionMixin
1 parent b531150 commit 3375933

1 file changed: +16 −66 lines

src/diffusers/models/transformers/transformer_photon.py

Lines changed: 16 additions & 66 deletions
@@ -16,13 +16,12 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
-from einops import rearrange
-from einops.layers.torch import Rearrange
 from torch import Tensor, nn
 from torch.nn.functional import fold, unfold
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
 from ..attention_processor import Attention, AttentionProcessor
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
@@ -134,6 +133,7 @@ def __call__(
         attn_output = attn.to_out[1](attn_output)  # dropout if present
 
         return attn_output
+# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
 class EmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.
@@ -155,15 +155,16 @@ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
         self.dim = dim
         self.theta = theta
         self.axes_dim = axes_dim
-        self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
 
     def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
         assert dim % 2 == 0
         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
         omega = 1.0 / (theta**scale)
         out = pos.unsqueeze(-1) * omega.unsqueeze(0)
         out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-        out = self.rope_rearrange(out)
+        # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+        # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
+        out = out.reshape(*out.shape[:-1], 2, 2)
         return out.float()
 
     def forward(self, ids: Tensor) -> Tensor:
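
For reference, a minimal sanity check that the plain reshape reproduces the removed einops layer. This snippet is not part of the commit; it imports einops only to verify the equivalence, and the tensor sizes are arbitrary:

import torch
from einops.layers.torch import Rearrange

out = torch.randn(1, 4, 8, 4)                     # dummy (b, n, d, 4) rope output
via_einops = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)(out)
via_reshape = out.reshape(*out.shape[:-1], 2, 2)  # (b, n, d, 2, 2)
assert torch.equal(via_einops, via_reshape)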
@@ -378,12 +379,20 @@ def _attn_forward(
         img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
 
         img_qkv = self.img_qkv_proj(img_mod)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        img_qkv = img_qkv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
         img_q, img_k = self.qk_norm(img_q, img_k, img_v)
 
         # txt tokens proj and norm
         txt_kv = self.txt_kv_proj(txt)
-        txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        B, L, _ = txt_kv.shape
+        txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        txt_kv = txt_kv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        txt_k, txt_v = txt_kv[0], txt_kv[1]
         txt_k = self.k_norm(txt_k)
 
         # compute attention
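
Likewise, a small check for the "B L (K H D) -> K B H L D" pattern replaced above. The sizes below are arbitrary stand-ins for the model's batch, sequence length, heads, and head_dim, and einops is imported only for the comparison:

import torch
from einops import rearrange

B, L, K, H, D = 2, 16, 3, 4, 8
qkv = torch.randn(B, L, K * H * D)

via_einops = rearrange(qkv, "B L (K H D) -> K B H L D", K=K, H=H)
via_torch = qkv.reshape(B, L, K, H, D).permute(2, 0, 3, 1, 4)
assert torch.equal(via_einops, via_torch)

q, k, v = via_torch.unbind(0)  # each (B, H, L, D)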
@@ -564,7 +573,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
     return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
 
 
-class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
+class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA
     scaling.
@@ -689,65 +698,6 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
         txt = self.txt_in(txt)
         img = img2seq(image_latent, self.patch_size)
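
The deleted attn_processors property and set_attn_processor method are the generic recursive implementations copied from UNet2DConditionModel; with AttentionMixin now among the base classes, that behavior is presumably supplied by the mixin instead. A hedged usage sketch, where the helper is illustrative rather than part of diffusers and `model` stands for any instantiated PhotonTransformer2DModel:

from typing import Dict, Union

from diffusers.models.attention_processor import AttentionProcessor


def swap_attn_processors(
    model, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
) -> Dict[str, AttentionProcessor]:
    """Return the current processors (keyed by weight name) and install new ones via the inherited API."""
    current = dict(model.attn_processors)  # inherited property, same contract as the removed code
    model.set_attn_processor(processor)    # inherited setter, same contract as the removed code
    return current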
