@@ -291,17 +291,17 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None,
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
    """Run one transformer block: modulated self-attention, then cross-attention,
    then a modulated feed-forward, with gated residual connections.

    NOTE(review): this deliberately mutates ``x`` in place (``addcmul_`` / ``+=``)
    to reduce peak memory -- callers must not rely on the input tensor being
    preserved.

    Args:
        x: hidden states, shape (batch, ..., dim) -- assumed; confirm with caller.
        context: cross-attention context passed to ``self.attn2``.
        attention_mask: mask forwarded to the cross-attention.
        timestep: conditioning embedding combined with ``self.scale_shift_table``.
        pe: positional encoding forwarded to the self-attention.
        transformer_options: opaque options dict forwarded to the attentions.

    Returns:
        The updated hidden states (the same tensor object as ``x``).
    """
    # Broadcast the learned AdaLN table against the timestep embedding and
    # split into the six modulation tensors along dim 2.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
        + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)
    ).unbind(dim=2)

    # Self-attention branch: h = norm(x) + norm(x)*scale + shift, result gated
    # back into x in place.
    h = comfy.ldm.common_dit.rms_norm(x)
    h = torch.addcmul(h, h, scale_msa).add_(shift_msa)
    h = self.attn1(h, pe=pe, transformer_options=transformer_options)
    x.addcmul_(h, gate_msa)
    del h  # drop the attention activation before allocating the next branch

    # Cross-attention branch (no AdaLN modulation), in-place residual add.
    x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

    # Feed-forward branch with its own modulation and gate, fused into x in place.
    h = comfy.ldm.common_dit.rms_norm(x)
    h = torch.addcmul(h, h, scale_mlp).add_(shift_mlp)
    x.addcmul_(self.ff(h), gate_mlp)

    return x
307307
@@ -336,16 +336,16 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2
336336 sin_vals = torch .cat ([torch .zeros_like (sin_vals [:, :, :padding_size ]), sin_vals ], dim = - 1 )
337337
338338 # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
339- cos_vals = cos_vals .reshape (* cos_vals .shape [:2 ], - 1 , 2 )[..., 0 ] # [B, N, dim//2]
340- sin_vals = sin_vals .reshape (* sin_vals .shape [:2 ], - 1 , 2 )[..., 0 ] # [B, N, dim//2]
339+ cos_vals = cos_vals .reshape (* cos_vals .shape [:2 ], - 1 , 2 )[..., 0 ]. to ( out_dtype ) # [B, N, dim//2]
340+ sin_vals = sin_vals .reshape (* sin_vals .shape [:2 ], - 1 , 2 )[..., 0 ]. to ( out_dtype ) # [B, N, dim//2]
341341
342342 # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
343343 freqs_cis = torch .stack ([
344344 torch .stack ([cos_vals , - sin_vals ], dim = - 1 ),
345345 torch .stack ([sin_vals , cos_vals ], dim = - 1 )
346346 ], dim = - 2 ).unsqueeze (1 ) # [B, 1, N, dim//2, 2, 2]
347347
348- return freqs_cis . to ( out_dtype )
348+ return freqs_cis
349349
350350
351351class LTXVModel (torch .nn .Module ):
0 commit comments