Skip to content

Commit 1618002

Browse files
Revert "Use rope functions from comfy kitchen. (Comfy-Org#11647)" (Comfy-Org#11648)
This reverts commit 6ef85c4.
1 parent 6ef85c4 commit 1618002

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

comfy/ldm/flux/math.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from comfy.ldm.modules.attention import optimized_attention
66
import comfy.model_management
7-
import logging
87

98

109
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@@ -14,6 +13,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
1413
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
1514
return x
1615

16+
1717
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
1818
assert dim % 2 == 0
1919
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
@@ -28,20 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
2828
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
2929
return out.to(dtype=torch.float32, device=pos.device)
3030

31+
def apply_rope1(x: Tensor, freqs_cis: Tensor):
32+
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
3133

32-
try:
33-
import comfy.quant_ops
34-
apply_rope = comfy.quant_ops.ck.apply_rope
35-
apply_rope1 = comfy.quant_ops.ck.apply_rope1
36-
except:
37-
logging.warning("No comfy kitchen, using old apply_rope functions.")
38-
def apply_rope1(x: Tensor, freqs_cis: Tensor):
39-
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
40-
41-
x_out = freqs_cis[..., 0] * x_[..., 0]
42-
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
34+
x_out = freqs_cis[..., 0] * x_[..., 0]
35+
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
4336

44-
return x_out.reshape(*x.shape).type_as(x)
37+
return x_out.reshape(*x.shape).type_as(x)
4538

46-
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
47-
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
39+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
40+
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

0 commit comments

Comments
 (0)