
Commit 8646908

fixed counter management, added missing CogVideoX norm_out/proj_out
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
1 parent 4a6afef commit 8646908

File tree

2 files changed: +85 −20 lines changed

src/diffusers/hooks/teacache.py

Lines changed: 83 additions & 19 deletions
@@ -289,6 +289,9 @@ def __init__(self):
         self.previous_residual = None
         # For models that cache both encoder and hidden_states residuals (e.g., CogVideoX)
         self.previous_residual_encoder = None
+        # For models with variable sequence lengths (e.g., Lumina2)
+        self.cache_dict = {}
+        self.uncond_seq_len = None
 
     def reset(self):
         """Reset all state variables to initial values for a new inference run."""
@@ -298,6 +301,8 @@ def reset(self):
         self.previous_modulated_input = None
         self.previous_residual = None
         self.previous_residual_encoder = None
+        self.cache_dict = {}
+        self.uncond_seq_len = None
 
     def __repr__(self) -> str:
         return (
@@ -634,7 +639,7 @@ def _handle_mochi_forward(
         should_calc = self._should_compute_full_transformer(state, modulated_inp)
 
         if not should_calc:
-            # Fast path: apply cached residual
+            # Fast path: apply cached residual (already includes norm_out)
             hidden_states = hidden_states + state.previous_residual
         else:
             # Slow path: full computation
@@ -660,14 +665,16 @@ def _handle_mochi_forward(
                     image_rotary_emb=image_rotary_emb,
                 )
 
-            # Cache the residual
+            # Apply norm_out before caching residual (matches reference implementation)
+            hidden_states = module.norm_out(hidden_states, temb)
+
+            # Cache the residual (includes norm_out transformation)
             state.previous_residual = hidden_states - ori_hidden_states
 
         state.previous_modulated_input = modulated_inp
         state.cnt += 1
 
-        # Apply final norm and projection
-        hidden_states = module.norm_out(hidden_states, temb)
+        # Apply projection
         hidden_states = module.proj_out(hidden_states)
 
         # Reshape output
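Note on the Mochi hunks above: norm_out now runs before the residual is captured, so the cached residual already contains the normalization and the fast path only adds it back before proj_out. A minimal sketch of that invariant, using stand-in LayerNorm/Linear modules and illustrative shapes (not Mochi's actual norm_out, which also takes temb):

import torch
import torch.nn as nn

# Stand-ins for the transformer's final layers (illustrative only).
norm_out = nn.LayerNorm(64)
proj_out = nn.Linear(64, 16)

hidden_states = torch.randn(2, 128, 64)                        # activations entering the blocks
block_output = hidden_states + 0.1 * torch.randn(2, 128, 64)   # pretend output of the blocks

# Slow path: apply norm_out, then cache the residual relative to the block input.
normed = norm_out(block_output)
cached_residual = normed - hidden_states
slow_out = proj_out(normed)

# Fast path: reuse the cached residual (which already includes norm_out)
# and go straight to proj_out.
fast_out = proj_out(hidden_states + cached_residual)

assert torch.allclose(slow_out, fast_out, atol=1e-6)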
@@ -765,12 +772,57 @@ def _handle_lumina2_forward(
         # Extract modulated input (after preprocessing)
         modulated_inp = self.extractor_fn(module, input_to_main_loop, temb)
 
-        # Make caching decision
-        should_calc = self._should_compute_full_transformer(state, modulated_inp)
-
-        if not should_calc:
+        # Per-sequence-length cache for Lumina2 (handles variable sequence lengths)
+        cache_key = max_seq_len
+        if cache_key not in state.cache_dict:
+            state.cache_dict[cache_key] = {
+                "previous_modulated_input": None,
+                "previous_residual": None,
+                "accumulated_rel_l1_distance": 0.0,
+            }
+        current_cache = state.cache_dict[cache_key]
+
+        # Make caching decision using per-cache values
+        if state.cnt == 0 or state.cnt == state.num_steps - 1:
+            should_calc = True
+            current_cache["accumulated_rel_l1_distance"] = 0.0
+        else:
+            if current_cache["previous_modulated_input"] is not None:
+                prev_mod_input = current_cache["previous_modulated_input"]
+                prev_mean = prev_mod_input.abs().mean()
+
+                if prev_mean.item() > 1e-9:
+                    rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
+                else:
+                    rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
+
+                rescaled_distance = self.rescale_func(rel_l1_change)
+                current_cache["accumulated_rel_l1_distance"] += rescaled_distance
+
+                if current_cache["accumulated_rel_l1_distance"] < self.config.rel_l1_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    current_cache["accumulated_rel_l1_distance"] = 0.0
+            else:
+                should_calc = True
+                current_cache["accumulated_rel_l1_distance"] = 0.0
+
+        current_cache["previous_modulated_input"] = modulated_inp.clone()
+
+        # Track unconditional sequence length for counter management
+        if state.uncond_seq_len is None:
+            state.uncond_seq_len = cache_key
+        # Only increment counter when not processing unconditional (different seq len)
+        if cache_key != state.uncond_seq_len:
+            state.cnt += 1
+            if state.cnt >= state.num_steps:
+                state.cnt = 0
+
+        # Fast or slow path with per-cache residual
+        if not should_calc and current_cache["previous_residual"] is not None:
             # Fast path: apply cached residual
-            processed_hidden_states = input_to_main_loop + state.previous_residual
+            processed_hidden_states = input_to_main_loop + current_cache["previous_residual"]
         else:
             # Slow path: full computation
             current_processing_states = input_to_main_loop
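Note on the Lumina2 hunk above: each sequence length gets its own cache entry (conditional and unconditional prompts generally differ in max_seq_len), a rescaled relative-L1 change of the modulated input is accumulated per entry, and the full transformer runs only when the accumulated value crosses rel_l1_thresh. A self-contained sketch of that decision rule; the threshold and rescale_func below are illustrative stand-ins, not the hook's actual configuration:

import torch

REL_L1_THRESH = 0.3                      # illustrative threshold
rescale_func = lambda x: 2.0 * x         # stand-in for the rescaling polynomial

cache_dict = {}  # one entry per sequence length, as in the hook's state

def should_compute(modulated_inp: torch.Tensor, cache_key: int, step: int, num_steps: int) -> bool:
    cache = cache_dict.setdefault(cache_key, {
        "previous_modulated_input": None,
        "previous_residual": None,        # residual handling omitted in this sketch
        "accumulated_rel_l1_distance": 0.0,
    })

    # Always recompute on the first and last step, or when nothing is cached yet.
    if step == 0 or step == num_steps - 1 or cache["previous_modulated_input"] is None:
        should_calc = True
        cache["accumulated_rel_l1_distance"] = 0.0
    else:
        prev = cache["previous_modulated_input"]
        prev_mean = prev.abs().mean()
        rel_l1 = ((modulated_inp - prev).abs().mean() / prev_mean).item() if prev_mean.item() > 1e-9 else 0.0
        cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1)
        should_calc = cache["accumulated_rel_l1_distance"] >= REL_L1_THRESH
        if should_calc:
            cache["accumulated_rel_l1_distance"] = 0.0

    cache["previous_modulated_input"] = modulated_inp.clone()
    return should_calc

# Conditional and unconditional passes use different keys, so they never evict each other.
for step in range(4):
    x_cond = torch.randn(1, 256, 64)
    x_uncond = torch.randn(1, 32, 64)
    print(step, should_compute(x_cond, 256, step, 4), should_compute(x_uncond, 32, step, 4))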
@@ -779,11 +831,8 @@ def _handle_lumina2_forward(
                     current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb
                 )
             processed_hidden_states = current_processing_states
-            # Cache the residual
-            state.previous_residual = processed_hidden_states - input_to_main_loop
-
-            state.previous_modulated_input = modulated_inp
-            state.cnt += 1
+            # Cache the residual in per-cache storage
+            current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop
 
         # Apply final norm and reshape
         output_after_norm = module.norm_out(processed_hidden_states, temb)
@@ -881,13 +930,13 @@ def _handle_cogvideox_forward(
         # Make caching decision
         should_calc = self._should_compute_full_transformer(state, modulated_inp)
 
-        # Fast path: apply cached residuals (both encoder and hidden_states)
-        # Must have both residuals cached to use fast path
-        if not should_calc and state.previous_residual_encoder is not None:
+        # Fast or slow path based on caching decision
+        if not should_calc:
+            # Fast path: apply cached residuals (both encoder and hidden_states)
             hidden_states = hidden_states + state.previous_residual
             encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder
         else:
-            # Slow path: full computation (also runs when encoder residual not yet cached)
+            # Slow path: full computation
             ori_hidden_states = hidden_states.clone()
             ori_encoder_hidden_states = encoder_hidden_states.clone()
 
@@ -928,7 +977,22 @@ def _handle_cogvideox_forward(
         hidden_states = module.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        output = module.proj_out(hidden_states)
+        # Final block
+        hidden_states = module.norm_out(hidden_states, temb=emb)
+        hidden_states = module.proj_out(hidden_states)
+
+        # Unpatchify
+        p = module.config.patch_size
+        p_t = module.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
 
         if USE_PEFT_BACKEND:
             unscale_lora_layers(module, lora_scale)
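Note on the CogVideoX hunk above: the hook now mirrors the model's own tail, namely norm_out, proj_out, and an unpatchify that reverses the spatial (and, when patch_size_t is set, temporal) patch embedding. A quick shape check of both branches on dummy tensors; the toy dimensions and the assumption that proj_out emits channels * p * p (or channels * p_t * p * p) features per token are illustrative:

import torch

batch_size, num_frames, height, width = 1, 8, 32, 32
channels, p, p_t = 16, 2, 2

# Branch 1: spatial-only patches (patch_size_t is None).
tokens = torch.randn(batch_size, num_frames * (height // p) * (width // p), channels * p * p)
out = tokens.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
out = out.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
assert out.shape == (batch_size, num_frames, channels, height, width)

# Branch 2: spatio-temporal patches (patch_size_t set).
frame_groups = (num_frames + p_t - 1) // p_t
tokens = torch.randn(batch_size, frame_groups * (height // p) * (width // p), channels * p_t * p * p)
out = tokens.reshape(batch_size, frame_groups, height // p, width // p, -1, p_t, p, p)
out = out.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
assert out.shape == (batch_size, frame_groups * p_t, channels, height, width)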

src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import LuminaFeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -322,7 +323,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
 
 
-class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     Lumina2NextDiT: Diffusion model with a Transformer backbone.
 
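Note on the transformer_lumina2.py hunk above: with CacheMixin in the base list, Lumina2Transformer2DModel picks up the generic enable_cache/disable_cache helpers used by other cached transformers in diffusers. A hypothetical usage sketch; the pipeline and checkpoint names are illustrative, and the TeaCache config object is an assumption not confirmed by this diff:

import torch
from diffusers import Lumina2Pipeline

# Illustrative pipeline/checkpoint; swap in whatever Lumina2 setup you already use.
pipe = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

# CacheMixin exposes enable_cache()/disable_cache() on the transformer.
# The config class below is assumed, not defined in this diff:
# pipe.transformer.enable_cache(TeaCacheConfig(rel_l1_thresh=0.3))

image = pipe("a photo of a cat", num_inference_steps=30).images[0]

# pipe.transformer.disable_cache()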