Commit c4a6b38

Lower ltxv mem usage to what it was before the previous PR (#10643).
Bring back qwen behavior to what it was before the previous PR.
1 parent 4cd8818 commit c4a6b38

2 files changed: +12 −12 lines changed

comfy/ldm/lightricks/model.py

Lines changed: 11 additions & 11 deletions
@@ -291,17 +291,17 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None,
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

-        norm_x = comfy.ldm.common_dit.rms_norm(x)
-        attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
-        attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
-        x.addcmul_(attn1_result, gate_msa)
+        attn1_input = comfy.ldm.common_dit.rms_norm(x)
+        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
+        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+        x.addcmul_(attn1_input, gate_msa)
+        del attn1_input

         x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

-        norm_x = comfy.ldm.common_dit.rms_norm(x)
-        y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
-        ff_result = self.ff(y)
-        x.addcmul_(ff_result, gate_mlp)
+        y = comfy.ldm.common_dit.rms_norm(x)
+        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
+        x.addcmul_(self.ff(y), gate_mlp)

         return x

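For context on the hunk above: instead of keeping norm_x, attn1_result and ff_result alive at the same time, the code rebinds a single name (attn1_input, then y) and deletes the attention output before attn2 runs, so fewer full-size activations coexist and peak memory returns to its pre-PR level. Below is a minimal sketch of that effect with made-up shapes and stand-ins for rms_norm and attn1; none of these names or sizes come from the commit, and it needs a CUDA device only so the peak can be measured.

import torch

def peak_mem(fn):
    # Measure peak allocated CUDA memory for one call.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

def separate_intermediates(x, scale, shift, gate):
    # Old pattern: norm_x and attn_result both stay referenced until the end.
    norm_x = x / x.norm(dim=-1, keepdim=True)               # stand-in for rms_norm
    attn_input = torch.addcmul(norm_x, norm_x, scale).add_(shift)
    attn_result = attn_input * 2.0                          # stand-in for self.attn1
    x.addcmul_(attn_result, gate)

def reuse_and_del(x, scale, shift, gate):
    # New pattern: rebinding one name frees each intermediate as soon as the
    # next one exists; del drops the last one explicitly (in the real block
    # this matters because attn2 and the ff path still run afterwards).
    attn_input = x / x.norm(dim=-1, keepdim=True)
    attn_input = torch.addcmul(attn_input, attn_input, scale).add_(shift)
    attn_input = attn_input * 2.0
    x.addcmul_(attn_input, gate)
    del attn_input

if torch.cuda.is_available():
    x = torch.randn(2, 4096, 2048, device="cuda")
    scale, shift, gate = (torch.randn(2, 1, 2048, device="cuda") for _ in range(3))
    print(peak_mem(lambda: separate_intermediates(x.clone(), scale, shift, gate)))
    print(peak_mem(lambda: reuse_and_del(x.clone(), scale, shift, gate)))
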
@@ -336,16 +336,16 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2
     sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)

     # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
-    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0]  # [B, N, dim//2]
-    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0]  # [B, N, dim//2]
+    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype)  # [B, N, dim//2]
+    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype)  # [B, N, dim//2]

     # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
     freqs_cis = torch.stack([
         torch.stack([cos_vals, -sin_vals], dim=-1),
         torch.stack([sin_vals, cos_vals], dim=-1)
     ], dim=-2).unsqueeze(1)  # [B, 1, N, dim//2, 2, 2]

-    return freqs_cis.to(out_dtype)
+    return freqs_cis


 class LTXVModel(torch.nn.Module):

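The hunk above moves the cast to out_dtype from the final return onto cos_vals and sin_vals, so the stacked [B, 1, N, dim//2, 2, 2] tensor is allocated directly in the target dtype instead of being built in float32 and then copied. A small sketch of the two orderings with made-up shapes (only the torch.stack structure mirrors the code above); since rounding to bfloat16 is sign-symmetric, both orderings produce the same values.

import torch

def stack_then_cast(cos_vals, sin_vals, out_dtype):
    # Old order: the full rotation-matrix tensor exists in float32,
    # then .to() allocates a second, converted copy.
    freqs_cis = torch.stack([
        torch.stack([cos_vals, -sin_vals], dim=-1),
        torch.stack([sin_vals, cos_vals], dim=-1),
    ], dim=-2).unsqueeze(1)
    return freqs_cis.to(out_dtype)

def cast_then_stack(cos_vals, sin_vals, out_dtype):
    # New order: only the small cos/sin tensors are converted, and the
    # stacked result is built in out_dtype directly.
    cos_vals = cos_vals.to(out_dtype)
    sin_vals = sin_vals.to(out_dtype)
    return torch.stack([
        torch.stack([cos_vals, -sin_vals], dim=-1),
        torch.stack([sin_vals, cos_vals], dim=-1),
    ], dim=-2).unsqueeze(1)

cos_vals = torch.randn(1, 4096, 32)   # hypothetical [B, N, dim//2] in float32
sin_vals = torch.randn(1, 4096, 32)
a = stack_then_cast(cos_vals, sin_vals, torch.bfloat16)
b = cast_then_stack(cos_vals, sin_vals, torch.bfloat16)
print(a.shape, a.dtype, torch.equal(a, b))   # same values, fewer large temporaries
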
comfy/ldm/qwen_image/model.py

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@ def _forward(
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
+        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
         del ids, txt_ids, img_ids

         hidden_states = self.img_in(hidden_states)

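For the qwen_image change: the rotary embedding is cast to x.dtype instead of being forced to float32, restoring the earlier behavior where the positional tensor follows the activation dtype (and takes half the memory when the model runs in bf16/fp16). A tiny sketch with a hypothetical stand-in for pe_embedder (its shape and contents are invented; only the final .to(...) mirrors the line above).

import torch

def pe_embedder(ids):
    # Hypothetical stand-in: only the output dtype matters for this sketch.
    return torch.randn(ids.shape[0], 1, ids.shape[1], 64, 2, 2)

x = torch.zeros(2, 128, 3072, dtype=torch.bfloat16)   # activations in bf16
ids = torch.zeros(2, 128, 3, dtype=torch.long)

old = pe_embedder(ids).to(torch.float32).contiguous()  # always float32, regardless of x
new = pe_embedder(ids).to(x.dtype).contiguous()        # follows x: bfloat16 here
print(old.dtype, old.element_size(), new.dtype, new.element_size())
# torch.float32 4 torch.bfloat16 2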