@@ -289,6 +289,9 @@ def __init__(self):
         self.previous_residual = None
         # For models that cache both encoder and hidden_states residuals (e.g., CogVideoX)
         self.previous_residual_encoder = None
+        # For models with variable sequence lengths (e.g., Lumina2)
+        self.cache_dict = {}
+        self.uncond_seq_len = None
 
     def reset(self):
         """Reset all state variables to initial values for a new inference run."""
@@ -298,6 +301,8 @@ def reset(self):
         self.previous_modulated_input = None
         self.previous_residual = None
         self.previous_residual_encoder = None
+        self.cache_dict = {}
+        self.uncond_seq_len = None
 
     def __repr__(self) -> str:
         return (
@@ -634,7 +639,7 @@ def _handle_mochi_forward(
         should_calc = self._should_compute_full_transformer(state, modulated_inp)
 
         if not should_calc:
-            # Fast path: apply cached residual
+            # Fast path: apply cached residual (already includes norm_out)
             hidden_states = hidden_states + state.previous_residual
         else:
             # Slow path: full computation
@@ -660,14 +665,16 @@ def _handle_mochi_forward(
                     image_rotary_emb=image_rotary_emb,
                 )
 
-            # Cache the residual
+            # Apply norm_out before caching residual (matches reference implementation)
+            hidden_states = module.norm_out(hidden_states, temb)
+
+            # Cache the residual (includes norm_out transformation)
             state.previous_residual = hidden_states - ori_hidden_states
 
         state.previous_modulated_input = modulated_inp
         state.cnt += 1
 
-        # Apply final norm and projection
-        hidden_states = module.norm_out(hidden_states, temb)
+        # Apply projection
        hidden_states = module.proj_out(hidden_states)
 
         # Reshape output
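Note on the hunk above: the cached residual is now taken after norm_out, so on cached steps a single addition reproduces the post-norm activations and the handler only has to run proj_out. A minimal, self-contained sketch of that pattern, using stand-in tensors and a stand-in norm_out rather than the real Mochi modules:

import torch

def norm_out(x, temb):
    # Stand-in for module.norm_out(hidden_states, temb); the real layer is an AdaLayerNorm-style block.
    return torch.nn.functional.layer_norm(x, x.shape[-1:]) * (1.0 + temb)

hidden_states = torch.randn(1, 4, 8)   # hypothetical (batch, seq_len, dim)
temb = torch.randn(1, 1, 8)

# Slow path: run the blocks, apply norm_out, then cache the residual.
ori_hidden_states = hidden_states.clone()
hidden_states = hidden_states * 1.1                     # placeholder for the transformer blocks
hidden_states = norm_out(hidden_states, temb)
previous_residual = hidden_states - ori_hidden_states

# Fast path on a later step: one addition recovers the post-norm_out activations
# (exactly here; approximately in practice, under the hook's assumption that the
# residual changes slowly between adjacent denoising steps).
reused = ori_hidden_states + previous_residual
assert torch.allclose(reused, hidden_states)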
@@ -765,12 +772,57 @@ def _handle_lumina2_forward(
         # Extract modulated input (after preprocessing)
         modulated_inp = self.extractor_fn(module, input_to_main_loop, temb)
 
-        # Make caching decision
-        should_calc = self._should_compute_full_transformer(state, modulated_inp)
-
-        if not should_calc:
+        # Per-sequence-length cache for Lumina2 (handles variable sequence lengths)
+        cache_key = max_seq_len
+        if cache_key not in state.cache_dict:
+            state.cache_dict[cache_key] = {
+                "previous_modulated_input": None,
+                "previous_residual": None,
+                "accumulated_rel_l1_distance": 0.0,
+            }
+        current_cache = state.cache_dict[cache_key]
+
+        # Make caching decision using per-cache values
+        if state.cnt == 0 or state.cnt == state.num_steps - 1:
+            should_calc = True
+            current_cache["accumulated_rel_l1_distance"] = 0.0
+        else:
+            if current_cache["previous_modulated_input"] is not None:
+                prev_mod_input = current_cache["previous_modulated_input"]
+                prev_mean = prev_mod_input.abs().mean()
+
+                if prev_mean.item() > 1e-9:
+                    rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
+                else:
+                    rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
+
+                rescaled_distance = self.rescale_func(rel_l1_change)
+                current_cache["accumulated_rel_l1_distance"] += rescaled_distance
+
+                if current_cache["accumulated_rel_l1_distance"] < self.config.rel_l1_thresh:
+                    should_calc = False
+                else:
+                    should_calc = True
+                    current_cache["accumulated_rel_l1_distance"] = 0.0
+            else:
+                should_calc = True
+                current_cache["accumulated_rel_l1_distance"] = 0.0
+
+        current_cache["previous_modulated_input"] = modulated_inp.clone()
+
+        # Track unconditional sequence length for counter management
+        if state.uncond_seq_len is None:
+            state.uncond_seq_len = cache_key
+        # Only increment counter when not processing unconditional (different seq len)
+        if cache_key != state.uncond_seq_len:
+            state.cnt += 1
+            if state.cnt >= state.num_steps:
+                state.cnt = 0
+
+        # Fast or slow path with per-cache residual
+        if not should_calc and current_cache["previous_residual"] is not None:
             # Fast path: apply cached residual
-            processed_hidden_states = input_to_main_loop + state.previous_residual
+            processed_hidden_states = input_to_main_loop + current_cache["previous_residual"]
         else:
             # Slow path: full computation
             current_processing_states = input_to_main_loop
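The decision logic introduced above is an accumulated relative-L1 heuristic: compare the current modulated input against the previous one for the same sequence length, rescale the relative change with self.rescale_func, accumulate it, and only recompute once the total crosses rel_l1_thresh. A stripped-down sketch of just that rule, with illustrative names and an identity rescale function standing in for the fitted polynomial:

import torch

def should_recompute(modulated_inp, cache, rescale_func, rel_l1_thresh, cnt, num_steps):
    # Always recompute on the first and last step, or when no previous input is cached.
    if cnt == 0 or cnt == num_steps - 1 or cache["previous_modulated_input"] is None:
        cache["accumulated_rel_l1_distance"] = 0.0
        return True
    prev = cache["previous_modulated_input"]
    denom = prev.abs().mean()
    rel_l1 = ((modulated_inp - prev).abs().mean() / denom).item() if denom.item() > 1e-9 else 0.0
    cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1)
    if cache["accumulated_rel_l1_distance"] < rel_l1_thresh:
        return False          # drift still small: reuse the cached residual
    cache["accumulated_rel_l1_distance"] = 0.0
    return True               # drift too large: recompute and refresh the cache

# Hypothetical usage with an identity rescale and a small threshold:
cache = {"previous_modulated_input": None, "previous_residual": None, "accumulated_rel_l1_distance": 0.0}
for step in range(4):
    x = torch.randn(1, 16, 8)
    calc = should_recompute(x, cache, lambda v: v, rel_l1_thresh=0.3, cnt=step, num_steps=4)
    cache["previous_modulated_input"] = x.clone()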
@@ -779,11 +831,8 @@ def _handle_lumina2_forward(
                     current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb
                 )
             processed_hidden_states = current_processing_states
-            # Cache the residual
-            state.previous_residual = processed_hidden_states - input_to_main_loop
-
-        state.previous_modulated_input = modulated_inp
-        state.cnt += 1
+            # Cache the residual in per-cache storage
+            current_cache["previous_residual"] = processed_hidden_states - input_to_main_loop
 
         # Apply final norm and reshape
         output_after_norm = module.norm_out(processed_hidden_states, temb)
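On the counter bookkeeping above: with classifier-free guidance, Lumina2 runs two forward passes per denoising step with different prompt lengths, so the step counter advances only on the pass whose sequence length differs from the first one seen. A tiny simulation of that bookkeeping, using an illustrative standalone state object and made-up token counts rather than the PR's class:

class _State:
    def __init__(self, num_steps):
        self.cnt = 0
        self.num_steps = num_steps
        self.uncond_seq_len = None

state = _State(num_steps=3)
for step in range(3):
    for seq_len in (256, 77):                     # hypothetical cond / uncond token counts
        cache_key = seq_len
        if state.uncond_seq_len is None:
            state.uncond_seq_len = cache_key      # first sequence length seen defines the "skip" key
        if cache_key != state.uncond_seq_len:
            state.cnt += 1                        # advances once per denoising step
            if state.cnt >= state.num_steps:
                state.cnt = 0
print(state.cnt)  # 0 again after num_steps full steps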
@@ -881,13 +930,13 @@ def _handle_cogvideox_forward(
         # Make caching decision
         should_calc = self._should_compute_full_transformer(state, modulated_inp)
 
-        # Fast path: apply cached residuals (both encoder and hidden_states)
-        # Must have both residuals cached to use fast path
-        if not should_calc and state.previous_residual_encoder is not None:
+        # Fast or slow path based on caching decision
+        if not should_calc:
+            # Fast path: apply cached residuals (both encoder and hidden_states)
             hidden_states = hidden_states + state.previous_residual
             encoder_hidden_states = encoder_hidden_states + state.previous_residual_encoder
         else:
-            # Slow path: full computation (also runs when encoder residual not yet cached)
+            # Slow path: full computation
             ori_hidden_states = hidden_states.clone()
             ori_encoder_hidden_states = encoder_hidden_states.clone()
 
@@ -928,7 +977,22 @@ def _handle_cogvideox_forward(
         hidden_states = module.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        output = module.proj_out(hidden_states)
+        # Final block
+        hidden_states = module.norm_out(hidden_states, temb=emb)
+        hidden_states = module.proj_out(hidden_states)
+
+        # Unpatchify
+        p = module.config.patch_size
+        p_t = module.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
 
         if USE_PEFT_BACKEND:
             unscale_lora_layers(module, lora_scale)
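The unpatchify step added above turns the projected token sequence back into per-frame latents. A shape-only walkthrough of the spatial branch (p_t is None), with hypothetical sizes chosen only to make the reshape concrete:

import torch

batch_size, num_frames, height, width, p, channels = 1, 2, 16, 16, 2, 4
tokens = num_frames * (height // p) * (width // p)

hidden_states = torch.randn(batch_size, tokens, channels * p * p)   # output of proj_out

output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

print(output.shape)  # torch.Size([1, 2, 4, 16, 16]) -> (batch, frames, channels, height, width)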