
Commit fafd774

davidb committed: unify the structure of the forward block
1 parent 3375933 commit fafd774

File tree: 2 files changed (+54, -74 lines)


src/diffusers/models/transformers/transformer_photon.py

Lines changed: 44 additions & 59 deletions
@@ -299,9 +299,8 @@ class PhotonBlock(nn.Module):
             Produces scale/shift/gating parameters for modulated layers.

     Methods:
-        attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None):
-            Compute cross-attention between image and text tokens, with optional spatial conditioning and attention
-            masking.
+        attn_forward(img, txt, pe, modulation, attention_mask=None):
+            Compute cross-attention between image and text tokens, with optional attention masking.

     Parameters:
         img (`torch.Tensor`):
@@ -312,8 +311,6 @@ class PhotonBlock(nn.Module):
             Rotary positional embeddings to apply to queries and keys.
         modulation (`ModulationOut`):
             Scale and shift parameters for modulating image tokens.
-        spatial_conditioning (`torch.Tensor`, *optional*):
-            Extra conditioning tokens of shape `(B, L_cond, hidden_size)`.
         attention_mask (`torch.Tensor`, *optional*):
             Boolean mask of shape `(B, L_txt)` where 0 marks padding.

@@ -372,7 +369,6 @@ def _attn_forward(
         txt: Tensor,
         pe: Tensor,
         modulation: ModulationOut,
-        spatial_conditioning: None | Tensor = None,
         attention_mask: None | Tensor = None,
     ) -> Tensor:
         # image tokens proj and norm
@@ -444,7 +440,6 @@ def forward(
         txt: Tensor,
         vec: Tensor,
         pe: Tensor,
-        spatial_conditioning: Tensor | None = None,
         attention_mask: Tensor | None = None,
         **_: dict[str, Any],
     ) -> Tensor:
@@ -461,9 +456,6 @@ def forward(
                 broadcastable).
             pe (`torch.Tensor`):
                 Rotary positional embeddings applied inside attention.
-            spatial_conditioning (`torch.Tensor`, *optional*):
-                Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only if spatial conditioning is
-                enabled in the block.
             attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
             **_:
@@ -481,7 +473,6 @@ def forward(
             txt,
             pe,
             mod_attn,
-            spatial_conditioning=spatial_conditioning,
             attention_mask=attention_mask,
         )
         img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
@@ -698,14 +689,6 @@ def __init__(

         self.gradient_checkpointing = False

-    def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
-        txt = self.txt_in(txt)
-        img = img2seq(image_latent, self.patch_size)
-        bs, _, h, w = image_latent.shape
-        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
-        pe = self.pe_embedder(img_ids)
-        return img, txt, pe
-
     def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
         return self.time_in(
             get_timestep_embedding(
@@ -717,43 +700,6 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
             ).to(dtype)
         )

-    def _forward_transformers(
-        self,
-        image_latent: Tensor,
-        cross_attn_conditioning: Tensor,
-        timestep: Optional[Tensor] = None,
-        time_embedding: Optional[Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        **block_kwargs: Any,
-    ) -> Tensor:
-        img = self.img_in(image_latent)
-
-        if time_embedding is not None:
-            vec = time_embedding
-        else:
-            if timestep is None:
-                raise ValueError("Please provide either a timestep or a timestep_embedding")
-            vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
-
-        for block in self.blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                img = self._gradient_checkpointing_func(
-                    block.__call__,
-                    img,
-                    cross_attn_conditioning,
-                    vec,
-                    block_kwargs.get("pe"),
-                    block_kwargs.get("spatial_conditioning"),
-                    attention_mask,
-                )
-            else:
-                img = block(
-                    img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs
-                )
-
-        img = self.final_layer(img, vec)
-        return img
-
     def forward(
         self,
         image_latent: Tensor,
@@ -797,6 +743,7 @@ def forward(
             lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
+
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
@@ -805,12 +752,50 @@ def forward(
             logger.warning(
                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
             )
-        img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning)
-        img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask)
-        output = seq2img(img_seq, self.patch_size, image_latent.shape)
+
+        # Process text conditioning
+        txt = self.txt_in(cross_attn_conditioning)
+
+        # Convert image to sequence and embed
+        img = img2seq(image_latent, self.patch_size)
+        img = self.img_in(img)
+
+        # Generate positional embeddings
+        bs, _, h, w = image_latent.shape
+        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+        pe = self.pe_embedder(img_ids)
+
+        # Compute time embedding
+        vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                img = self._gradient_checkpointing_func(
+                    block.__call__,
+                    img,
+                    txt,
+                    vec,
+                    pe,
+                    cross_attn_mask,
+                )
+            else:
+                img = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    attention_mask=cross_attn_mask,
+                )
+
+        # Final layer and convert back to image
+        img = self.final_layer(img, vec)
+        output = seq2img(img, self.patch_size, image_latent.shape)
+
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
             unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
            return (output,)
         return Transformer2DModelOutput(sample=output)
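
With `_process_inputs` and `_forward_transformers` removed, `forward` is now the single entry point for the transformer. Below is a minimal call sketch, not taken from the repository: it assumes an already-constructed `PhotonTransformer2DModel` instance named `model`, and the tensor shapes are purely illustrative (the real latent channel count and text-embedding width depend on the VAE and text encoder paired with Photon).

    import torch

    # Illustrative shapes only (assumption): actual latent channels, spatial size,
    # and text-embedding width depend on the model config.
    B, C, H, W = 1, 4, 32, 32
    L_TXT, D_TXT = 77, 768

    image_latent = torch.randn(B, C, H, W)
    cross_attn_conditioning = torch.randn(B, L_TXT, D_TXT)
    cross_attn_mask = torch.ones(B, L_TXT, dtype=torch.bool)
    timestep = torch.tensor([0.5])  # normalized timestep, as in the pipeline call below

    # `model` is assumed to be a PhotonTransformer2DModel constructed elsewhere.
    noise_pred = model(
        image_latent=image_latent,
        timestep=timestep,
        cross_attn_conditioning=cross_attn_conditioning,
        micro_conditioning=None,
        cross_attn_mask=cross_attn_mask,
        return_dict=False,
    )[0]  # same (B, C, H, W) shape as image_latent

This mirrors the call the pipeline now makes in the second file of this commit.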

src/diffusers/pipelines/photon/pipeline_photon.py

Lines changed: 10 additions & 15 deletions
@@ -31,7 +31,7 @@
 from diffusers.image_processor import PixArtImageProcessor
 from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderDC, AutoencoderKL
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
 from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
@@ -572,20 +572,15 @@ def __call__(
                 # Normalize timestep for the transformer
                 t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)

-                # Process inputs for transformer
-                img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed)
-
-                # Forward through transformer layers
-                img_seq = self.transformer._forward_transformers(
-                    img_seq,
-                    txt,
-                    time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype),
-                    pe=pe,
-                    attention_mask=ca_mask,
-                )
-
-                # Convert back to image format
-                noise_pred = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
+                # Forward through transformer
+                noise_pred = self.transformer(
+                    image_latent=latents_in,
+                    timestep=t_cont,
+                    cross_attn_conditioning=ca_embed,
+                    micro_conditioning=None,
+                    cross_attn_mask=ca_mask,
+                    return_dict=False,
+                )[0]

                 # Apply CFG
                 if self.do_classifier_free_guidance:
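
With the transformer's `forward` doing its own sequence/image conversion, the pipeline no longer imports `seq2img` and only deals with `(B, C, H, W)` latents. For orientation, here is a rough, self-contained sketch of the patchify/unpatchify round trip that helpers like `img2seq`/`seq2img` typically perform; the function names below are hypothetical and the actual helpers in `transformer_photon.py` may order dimensions differently.

    import torch

    def img2seq_sketch(x: torch.Tensor, patch_size: int) -> torch.Tensor:
        # (B, C, H, W) -> (B, (H/p)*(W/p), C*p*p): non-overlapping patches become tokens
        b, c, h, w = x.shape
        p = patch_size
        x = x.reshape(b, c, h // p, p, w // p, p)
        return x.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // p) * (w // p), c * p * p)

    def seq2img_sketch(x: torch.Tensor, patch_size: int, shape: torch.Size) -> torch.Tensor:
        # Inverse of img2seq_sketch: tokens back to (B, C, H, W)
        b, c, h, w = shape
        p = patch_size
        x = x.reshape(b, h // p, w // p, c, p, p)
        return x.permute(0, 3, 1, 4, 2, 5).reshape(b, c, h, w)

    latent = torch.randn(1, 4, 32, 32)
    seq = img2seq_sketch(latent, patch_size=2)  # (1, 256, 16)
    assert torch.equal(seq2img_sketch(seq, 2, latent.shape), latent)

The point of the commit is that this round trip now lives entirely inside the transformer's forward pass, which is why the pipeline call above is a single `self.transformer(...)` invocation.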
