@@ -299,9 +299,8 @@ class PhotonBlock(nn.Module):
             Produces scale/shift/gating parameters for modulated layers.
 
     Methods:
-        attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None):
-            Compute cross-attention between image and text tokens, with optional spatial conditioning and attention
-            masking.
+        attn_forward(img, txt, pe, modulation, attention_mask=None):
+            Compute cross-attention between image and text tokens, with optional attention masking.
 
             Parameters:
                 img (`torch.Tensor`):
@@ -312,8 +311,6 @@ class PhotonBlock(nn.Module):
                     Rotary positional embeddings to apply to queries and keys.
                 modulation (`ModulationOut`):
                     Scale and shift parameters for modulating image tokens.
-                spatial_conditioning (`torch.Tensor`, *optional*):
-                    Extra conditioning tokens of shape `(B, L_cond, hidden_size)`.
                 attention_mask (`torch.Tensor`, *optional*):
                     Boolean mask of shape `(B, L_txt)` where 0 marks padding.
 
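For context on the `attention_mask` parameter above: a boolean `(B, L_txt)` padding mask is usually broadcast over heads and query positions and converted into an additive bias before being handed to scaled-dot-product attention. A minimal sketch; the helper name and exact masking strategy are assumptions, not this module's code:

    import torch

    def make_attn_bias(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # attention_mask: (B, L_txt) bool, where 0/False marks padded text tokens.
        # Turn padding into a -inf additive bias broadcastable to (B, H, L_q, L_txt),
        # the form accepted by F.scaled_dot_product_attention's attn_mask argument.
        bias = torch.zeros(attention_mask.shape, dtype=dtype, device=attention_mask.device)
        bias = bias.masked_fill(~attention_mask.bool(), float("-inf"))
        return bias[:, None, None, :]  # (B, 1, 1, L_txt)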
@@ -372,7 +369,6 @@ def _attn_forward(
         txt: Tensor,
         pe: Tensor,
         modulation: ModulationOut,
-        spatial_conditioning: None | Tensor = None,
         attention_mask: None | Tensor = None,
     ) -> Tensor:
         # image tokens proj and norm
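The `pe` argument kept in this signature carries rotary positional embeddings applied to queries and keys inside `_attn_forward`. For readers unfamiliar with the format, this is a sketch of the Flux-style rotation such a tensor usually encodes; the `(..., L, D/2, 2, 2)` layout is an assumption about the positional embedder's output, not verified against Photon's helpers:

    import torch

    def apply_rope(q: torch.Tensor, k: torch.Tensor, pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # q, k: (B, H, L, D) with even D; pe: (..., L, D/2, 2, 2) rotation matrices.
        def rotate(x: torch.Tensor) -> torch.Tensor:
            x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)          # pair up feature dims
            out = pe[..., 0] * x_[..., 0] + pe[..., 1] * x_[..., 1]  # 2x2 rotation per pair
            return out.reshape(*x.shape).type_as(x)

        return rotate(q), rotate(k)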
@@ -444,7 +440,6 @@ def forward(
         txt: Tensor,
         vec: Tensor,
         pe: Tensor,
-        spatial_conditioning: Tensor | None = None,
         attention_mask: Tensor | None = None,
         **_: dict[str, Any],
     ) -> Tensor:
@@ -461,9 +456,6 @@ def forward(
                 broadcastable).
             pe (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
-            spatial_conditioning (`torch.Tensor`, *optional*):
-                Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only if spatial conditioning is
-                enabled in the block.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
             **_:
@@ -481,7 +473,6 @@ def forward(
             txt,
             pe,
             mod_attn,
-            spatial_conditioning=spatial_conditioning,
             attention_mask=attention_mask,
         )
         img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
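The gated residual on the last context line (`img + mod_mlp.gate * ...`) is the usual adaLN-style modulation, where `ModulationOut` carries shift/scale/gate vectors derived from the timestep embedding. A rough sketch of the pattern, with names and shapes assumed rather than taken from this file:

    import torch

    def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # shift/scale: (B, 1, hidden_size), broadcast over the token dimension.
        return x * (1 + scale) + shift

    # gated residual update, mirroring
    # img = img + mod.gate * sublayer(modulate(norm(img), mod.shift, mod.scale))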
@@ -698,14 +689,6 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
-        txt = self.txt_in(txt)
-        img = img2seq(image_latent, self.patch_size)
-        bs, _, h, w = image_latent.shape
-        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
-        pe = self.pe_embedder(img_ids)
-        return img, txt, pe
-
     def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
         return self.time_in(
             get_timestep_embedding(
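`_compute_timestep_embedding` (kept unchanged here) wraps diffusers' sinusoidal helper before projecting it through `time_in`. A standalone example of the underlying call; the embedding dimension and flag values below are illustrative, not necessarily what this model's config passes:

    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    timesteps = torch.tensor([0.0, 250.0, 999.0])  # one scalar timestep per batch element
    emb = get_timestep_embedding(timesteps, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    print(emb.shape)  # torch.Size([3, 256])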
@@ -717,43 +700,6 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
             ).to(dtype)
         )
 
-    def _forward_transformers(
-        self,
-        image_latent: Tensor,
-        cross_attn_conditioning: Tensor,
-        timestep: Optional[Tensor] = None,
-        time_embedding: Optional[Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        **block_kwargs: Any,
-    ) -> Tensor:
-        img = self.img_in(image_latent)
-
-        if time_embedding is not None:
-            vec = time_embedding
-        else:
-            if timestep is None:
-                raise ValueError("Please provide either a timestep or a timestep_embedding")
-            vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
-
-        for block in self.blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                img = self._gradient_checkpointing_func(
-                    block.__call__,
-                    img,
-                    cross_attn_conditioning,
-                    vec,
-                    block_kwargs.get("pe"),
-                    block_kwargs.get("spatial_conditioning"),
-                    attention_mask,
-                )
-            else:
-                img = block(
-                    img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs
-                )
-
-        img = self.final_layer(img, vec)
-        return img
-
     def forward(
         self,
         image_latent: Tensor,
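The removed `_forward_transformers` (and the inlined loop that replaces it further down) pass block arguments positionally to `self._gradient_checkpointing_func`, which is why dropping `spatial_conditioning` also changes that call site. Roughly, the checkpointed path behaves like the sketch below; the exact wrapper diffusers installs may differ:

    import torch
    from torch.utils.checkpoint import checkpoint

    def run_block_checkpointed(block, img, txt, vec, pe, attention_mask):
        # Arguments must line up with PhotonBlock.forward's positional order:
        # (img, txt, vec, pe, attention_mask).
        return checkpoint(block.__call__, img, txt, vec, pe, attention_mask, use_reentrant=False)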
@@ -797,6 +743,7 @@ def forward(
             lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
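For completeness, the `scale` entry popped from `attention_kwargs` above is what feeds `scale_lora_layers`/`unscale_lora_layers` when the PEFT backend is active. A hypothetical call shape; the variable names `transformer`, `latents`, `text_embeds`, `t`, and `mask` are placeholders, only the keyword names come from this forward signature:

    out = transformer(
        image_latent=latents,                 # (B, C, H, W) latent image
        cross_attn_conditioning=text_embeds,  # (B, L_txt, dim) text conditioning tokens
        timestep=t,
        cross_attn_mask=mask,                 # (B, L_txt) boolean padding mask
        attention_kwargs={"scale": 0.8},      # LoRA scale applied to PEFT layers
    )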
@@ -805,12 +752,50 @@ def forward(
                 logger.warning(
                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )
-        img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning)
-        img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask)
-        output = seq2img(img_seq, self.patch_size, image_latent.shape)
+
+        # Process text conditioning
+        txt = self.txt_in(cross_attn_conditioning)
+
+        # Convert image to sequence and embed
+        img = img2seq(image_latent, self.patch_size)
+        img = self.img_in(img)
+
+        # Generate positional embeddings
+        bs, _, h, w = image_latent.shape
+        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+        pe = self.pe_embedder(img_ids)
+
+        # Compute time embedding
+        vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                img = self._gradient_checkpointing_func(
+                    block.__call__,
+                    img,
+                    txt,
+                    vec,
+                    pe,
+                    cross_attn_mask,
+                )
+            else:
+                img = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    attention_mask=cross_attn_mask,
+                )
+
+        # Final layer and convert back to image
+        img = self.final_layer(img, vec)
+        output = seq2img(img, self.patch_size, image_latent.shape)
+
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
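The inlined forward now calls `img2seq`/`seq2img` directly instead of going through `_process_inputs`. As a reference for the shape contract these helpers typically implement (the exact channel ordering inside Photon's implementation may differ), a patchify/unpatchify sketch:

    import torch

    def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
        # (B, C, H, W) -> (B, (H//p)*(W//p), C*p*p); assumed contract of `img2seq`.
        b, c, h, w = x.shape
        x = x.reshape(b, c, h // p, p, w // p, p)
        return x.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // p) * (w // p), c * p * p)

    def unpatchify(x: torch.Tensor, p: int, shape: torch.Size) -> torch.Tensor:
        # Inverse mapping, mirroring how `seq2img` restores (B, C, H, W) from the token sequence.
        b, c, h, w = shape
        x = x.reshape(b, h // p, w // p, c, p, p).permute(0, 3, 1, 4, 2, 5)
        return x.reshape(b, c, h, w)

    # round-trip check
    x = torch.randn(1, 4, 8, 8)
    assert torch.equal(unpatchify(patchify(x, 2), 2, x.shape), x)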