1+ from typing import Tuple , Union
2+
3+ import torch
4+ import torch_npu
5+
6+ from ..utils import log_replace_info
7+
8+
def npu_apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor on Ascend NPU.

    When `use_real` is True, `freqs_cis` is a `(cos, sin)` pair and the rotation is
    delegated to the fused `torch_npu.npu_rotary_mul` kernel. Otherwise `freqs_cis`
    is a complex tensor and the rotation is done via complex multiplication
    (path used by Lumina-style models).

    Args:
        x (`torch.Tensor`):
            Query or key tensor. For the real path, assumed `[B, H, S, D]` when
            `sequence_dim == 2` or `[B, S, H, D]` when `sequence_dim == 1`
            (TODO confirm against callers).
        freqs_cis (`Union[torch.Tensor, Tuple[torch.Tensor]]`):
            `(cos, sin)` pair of shape `[S, D]` when `use_real`, else a complex
            frequency tensor broadcastable against `x` viewed as complex pairs.
        use_real (`bool`, defaults to `True`):
            Select the real `(cos, sin)` path vs. the complex-multiplication path.
        use_real_unbind_dim (`int`, defaults to `-1`):
            `-1` selects interleaved pairing (flux, cogvideox, hunyuan-dit);
            `-2` selects half-split pairing (Stable Audio, OmniGen, CogView4, Cosmos).
        sequence_dim (`int`, defaults to `2`):
            Which axis of `x` is the sequence axis (1 or 2); controls cos/sin broadcasting.

    Returns:
        `torch.Tensor`: `x` with rotary embeddings applied, same dtype as `x`.

    Raises:
        ValueError: If `sequence_dim` is not 1 or 2, or `use_real_unbind_dim`
            is not -1 or -2 (real path only).
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        # Insert broadcast axes so cos/sin line up with x's layout.
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            rotary_mode = "interleave"
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            rotary_mode = "half"
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
        # Fused NPU kernel; cast back because the kernel may promote dtype.
        out = torch_npu.npu_rotary_mul(x, cos, sin, rotary_mode=rotary_mode).to(x.dtype)

        return out
    else:
        # used for lumina: pair up the last dim as (real, imag) and rotate by
        # complex multiplication with the precomputed frequencies.
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
61+
62+
def replace_func():
    """Swap diffusers' `apply_rotary_emb` for the NPU-backed implementation.

    Patches both call sites that bind the helper at import time: the shared
    embeddings module and the Flux transformer module.
    """
    from diffusers.models import embeddings
    from diffusers.models.transformers import transformer_flux

    for target_module in (embeddings, transformer_flux):
        target_module.apply_rotary_emb = npu_apply_rotary_emb
69+
70+
def replace_npu_rotary_mul():
    """Install the NPU rotary-embedding patch, then log the substitution."""
    replace_func()
    log_replace_info("apply_rotary_emb", "npu_rotary_mul")