44
55from comfy .ldm .modules .attention import optimized_attention
66import comfy .model_management
7- import logging
87
98
109def attention (q : Tensor , k : Tensor , v : Tensor , pe : Tensor , mask = None , transformer_options = {}) -> Tensor :
@@ -14,6 +13,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
1413 x = optimized_attention (q , k , v , heads , skip_reshape = True , mask = mask , transformer_options = transformer_options )
1514 return x
1615
16+
1717def rope (pos : Tensor , dim : int , theta : int ) -> Tensor :
1818 assert dim % 2 == 0
1919 if comfy .model_management .is_device_mps (pos .device ) or comfy .model_management .is_intel_xpu () or comfy .model_management .is_directml_enabled ():
@@ -28,20 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
2828 out = rearrange (out , "b n d (i j) -> b n d i j" , i = 2 , j = 2 )
2929 return out .to (dtype = torch .float32 , device = pos .device )
3030
31+ def apply_rope1 (x : Tensor , freqs_cis : Tensor ):
32+ x_ = x .to (dtype = freqs_cis .dtype ).reshape (* x .shape [:- 1 ], - 1 , 1 , 2 )
3133
32- try :
33- import comfy .quant_ops
34- apply_rope = comfy .quant_ops .ck .apply_rope
35- apply_rope1 = comfy .quant_ops .ck .apply_rope1
36- except :
37- logging .warning ("No comfy kitchen, using old apply_rope functions." )
38- def apply_rope1 (x : Tensor , freqs_cis : Tensor ):
39- x_ = x .to (dtype = freqs_cis .dtype ).reshape (* x .shape [:- 1 ], - 1 , 1 , 2 )
40-
41- x_out = freqs_cis [..., 0 ] * x_ [..., 0 ]
42- x_out .addcmul_ (freqs_cis [..., 1 ], x_ [..., 1 ])
34+ x_out = freqs_cis [..., 0 ] * x_ [..., 0 ]
35+ x_out .addcmul_ (freqs_cis [..., 1 ], x_ [..., 1 ])
4336
44- return x_out .reshape (* x .shape ).type_as (x )
37+ return x_out .reshape (* x .shape ).type_as (x )
4538
46- def apply_rope (xq : Tensor , xk : Tensor , freqs_cis : Tensor ):
47- return apply_rope1 (xq , freqs_cis ), apply_rope1 (xk , freqs_cis )
39+ def apply_rope (xq : Tensor , xk : Tensor , freqs_cis : Tensor ):
40+ return apply_rope1 (xq , freqs_cis ), apply_rope1 (xk , freqs_cis )
0 commit comments