Skip to content

Commit aa9d759

Browse files
Switch ltxv to use the pytorch RMSNorm. (Comfy-Org#7897)
1 parent c6c19e9 commit aa9d759

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

comfy/ldm/lightricks/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from torch import nn
33
import comfy.ldm.modules.attention
4-
from comfy.ldm.genmo.joint_model.layers import RMSNorm
54
import comfy.ldm.common_dit
65
from einops import rearrange
76
import math
@@ -262,8 +261,8 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
262261
self.heads = heads
263262
self.dim_head = dim_head
264263

265-
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
266-
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
264+
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
265+
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
267266

268267
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
269268
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

0 commit comments

Comments
 (0)