 from typing import Any, Dict, Optional, Tuple, Union

 import torch
-from einops import rearrange
-from einops.layers.torch import Rearrange
 from torch import Tensor, nn
 from torch.nn.functional import fold, unfold

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
 from ..attention_processor import Attention, AttentionProcessor
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
@@ -134,6 +133,7 @@ def __call__(
         attn_output = attn.to_out[1](attn_output)  # dropout if present

         return attn_output
+# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
 class EmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.
@@ -155,15 +155,16 @@ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
         self.dim = dim
         self.theta = theta
         self.axes_dim = axes_dim
-        self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)

     def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
         assert dim % 2 == 0
         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
         omega = 1.0 / (theta**scale)
         out = pos.unsqueeze(-1) * omega.unsqueeze(0)
         out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-        out = self.rope_rearrange(out)
+        # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+        # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
+        out = out.reshape(*out.shape[:-1], 2, 2)
         return out.float()

     def forward(self, ids: Tensor) -> Tensor:
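Note: the reshape above is a drop-in replacement for the removed einops `Rearrange` layer. A minimal sanity check of the equivalence (einops imported only for this comparison, not needed at runtime; shapes are arbitrary):

import torch
from einops import rearrange

out = torch.randn(2, 3, 8, 4)  # (b, n, d, 4): last axis is the flattened 2x2 rotation matrix
ref = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
new = out.reshape(*out.shape[:-1], 2, 2)
assert torch.equal(ref, new)  # identical layout, so the RoPE values are unchanged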
@@ -378,12 +379,20 @@ def _attn_forward(
         img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift

         img_qkv = self.img_qkv_proj(img_mod)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        img_qkv = img_qkv.reshape(B, L, 3, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        img_qkv = img_qkv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
         img_q, img_k = self.qk_norm(img_q, img_k, img_v)

         # txt tokens proj and norm
         txt_kv = self.txt_kv_proj(txt)
-        txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        # Native PyTorch equivalent of: rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+        B, L, _ = txt_kv.shape
+        txt_kv = txt_kv.reshape(B, L, 2, self.num_heads, self.head_dim)  # (B, L, K, H, D)
+        txt_kv = txt_kv.permute(2, 0, 3, 1, 4)  # (K, B, H, L, D)
+        txt_k, txt_v = txt_kv[0], txt_kv[1]
         txt_k = self.k_norm(txt_k)

         # compute attention
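Note: the same kind of check covers the QKV split, assuming `self.head_dim` is the per-head width of the projection (einops again used only for the comparison):

import torch
from einops import rearrange

B, L, H, D = 2, 7, 4, 8  # batch, sequence length, heads, head dim (arbitrary values)
qkv = torch.randn(B, L, 3 * H * D)
ref = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
new = qkv.reshape(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
assert torch.equal(ref, new)  # reshape + permute reproduces the einops split exactly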
@@ -564,7 +573,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
     return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)


-class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
+class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA
     scaling.
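Note: `seq2img` is the inverse of the patchification step; assuming `img2seq` is the `unfold` counterpart used in `_process_inputs` below, a small round-trip sketch with non-overlapping patches reconstructs the latent exactly:

import torch
from torch.nn.functional import fold, unfold

patch_size = 2
img = torch.randn(1, 16, 8, 8)  # (B, C, H, W) latent
seq = unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)  # (B, L, C * p * p)
out = fold(seq.transpose(1, 2), img.shape[-2:], kernel_size=patch_size, stride=patch_size)
assert torch.equal(out, img)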
@@ -689,65 +698,6 @@ def __init__(

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
         txt = self.txt_in(txt)
         img = img2seq(image_latent, self.patch_size)
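Note: the deleted `attn_processors` / `set_attn_processor` definitions are assumed to be provided by the shared `AttentionMixin` base imported above, so the public API of the model should be unchanged. A hedged sketch of what call sites can still rely on:

from diffusers.models.attention import AttentionMixin

# The model now inherits these from the mixin instead of defining them itself
# (assumption based on the new import and base class in this diff).
assert hasattr(AttentionMixin, "attn_processors")
assert hasattr(AttentionMixin, "set_attn_processor")

# Typical usage on an instantiated model stays the same, e.g.:
# model.set_attn_processor(some_processor)  # apply one processor to every Attention layer
# procs = model.attn_processors             # dict keyed by module path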