
Commit b50dbe2

feat: support ascend moe_gating_topk_softmax
1 parent eec97b1 commit b50dbe2

7 files changed (+39, −19 lines)


lmdeploy/pytorch/kernels/ascend/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
 from .fill_kv_cache import fill_kv_cache
 from .fused_rotary_emb import fused_rotary_emb
+from .moe_gating_topk_softmax import moe_gating_topk_softmax
 from .paged_attention_fwd import paged_attention_fwd
 from .rms_norm import rms_norm

@@ -12,5 +13,6 @@
     'fused_rotary_emb',
     'fill_kv_cache',
     'paged_attention_fwd',
+    'moe_gating_topk_softmax',
     'multinomial_sampling',
 ]

lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py

Lines changed: 2 additions & 4 deletions
@@ -26,10 +26,8 @@ def apply_rotary_pos_emb(
         setattr(context, 'sin', sin)
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
-    ext_ops.apply_rotary_pos_emb(
-        query_states_reshaped, key_states_reshaped, cached_cos, cached_sin,
-        None, None, None
-    )
+    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
+                                 cached_cos, cached_sin, None, None, None)
     if q_embed is None:
         q_embed = query_states
     else:

lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -16,5 +16,5 @@ def fill_kv_cache(
     context: None,
 ):
     """fill key/value state to cache for paged attention."""
-    ext_ops.fill_kv_cache(key_states, value_states, key_caches,
-                          value_caches, context.kv_start_indices)
+    ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches,
+                          context.kv_start_indices)

lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py

Lines changed: 6 additions & 4 deletions
@@ -21,10 +21,12 @@ def fused_rotary_emb(
     position_ids = position_ids.squeeze(0).unsqueeze(-1)
     pos_freq = position_ids / scaling_factor * inv_freq
     if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1)
-               .repeat(1, 1, 1, 2).to(query_states.dtype))
-        sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1)
-               .repeat(1, 1, 1, 2).to(query_states.dtype))
+        cos = (torch.cos(pos_freq).view(batch, seqlen, 1,
+                                        -1).repeat(1, 1, 1,
+                                                   2).to(query_states.dtype))
+        sin = (torch.sin(pos_freq).view(batch, seqlen, 1,
+                                        -1).repeat(1, 1, 1,
+                                                   2).to(query_states.dtype))
     if context:
         setattr(context, 'cos', cos)
         setattr(context, 'sin', sin)

lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import infer_ext.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def moe_gating_topk_softmax(router_logits: Tensor, topk: int):
+    routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(
+        router_logits, topk)
+    return routing_weights.to(torch.float32), selected_experts.to(torch.int64)
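
For context, the new ascend wrapper above just forwards to infer_ext.ops.moe_gating_topk_softmax and casts its outputs to float32 weights and int64 expert indices. A minimal pure-PyTorch sketch of what a MoE gating top-k softmax is commonly assumed to compute follows; the softmax-before-top-k ordering and the absence of weight re-normalization are illustrative assumptions, not something this commit specifies, and moe_gating_topk_softmax_ref is a hypothetical name.

# Hypothetical pure-PyTorch reference (not part of this commit): softmax over
# the expert dimension followed by top-k selection, matching the output dtypes
# of the ascend wrapper above (float32 weights, int64 expert ids).
import torch


def moe_gating_topk_softmax_ref(router_logits: torch.Tensor, topk: int):
    # [num_tokens, num_experts] logits -> per-token expert probabilities
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    # keep the k largest probabilities and the ids of the chosen experts
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    return routing_weights, selected_experts.to(torch.int64)

If the ascend op additionally re-normalizes the selected weights to sum to one, as some MoE routers do, that step would need to be added to this sketch.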

lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py

Lines changed: 12 additions & 9 deletions
@@ -21,7 +21,7 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]
+    batch = q_start_loc.shape[0]

     for i in range(batch):
         if torch.equal(q_seq_len[i], kv_seq_len[i]):
@@ -30,30 +30,32 @@ def flash_context_attention(
                 query_states,
                 key_states,
                 value_states,
-                q_start_loc[i:i+1],
-                q_seq_len[i:i+1],
+                q_start_loc[i:i + 1],
+                q_seq_len[i:i + 1],
                 num_q_heads,
                 num_kv_heads,
-                context.attention_mask[i:i+1],
+                context.attention_mask[i:i + 1],
             )
         else:
             key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+            value_cache = value_cache.reshape(1, kv_cache_len,
+                                              num_kv_heads * dim)
             ext_ops.paged_prefill_attention(
                 attn_output,
                 query_states,
                 key_cache,
                 value_cache,
                 block_offsets,
                 block_size,
-                q_start_loc[i:i+1],
-                q_seq_len[i:i+1],
-                kv_seq_len[i:i+1],
+                q_start_loc[i:i + 1],
+                q_seq_len[i:i + 1],
+                kv_seq_len[i:i + 1],
                 num_q_heads,
                 num_kv_heads,
-                context.attention_mask[i:i+1],
+                context.attention_mask[i:i + 1],
             )

+
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
@@ -69,6 +71,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         num_kv_heads,
     )

+
 def paged_attention_fwd(
     query_states: Tensor,
     key_states: torch.Tensor,
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dispatcher import FunctionDispatcher
+
+moe_gating_topk_softmax = FunctionDispatcher(
+    'moe_gating_topk_softmax').make_caller()
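
This second new file wires the kernel into the dispatcher layer: FunctionDispatcher('moe_gating_topk_softmax').make_caller() returns a caller that resolves to a backend-specific implementation (such as the ascend kernel added above) when invoked. Below is a deliberately simplified, hypothetical sketch of that dispatch-by-name pattern, not lmdeploy's actual FunctionDispatcher; the _REGISTRY, register, and make_caller names are illustrative only.

# Hypothetical illustration of dispatch-by-name (not lmdeploy's FunctionDispatcher).
# A registry maps (device, op name) to an implementation; the caller returned by
# make_caller looks the implementation up and forwards all arguments to it.
from typing import Callable, Dict, Tuple

_REGISTRY: Dict[Tuple[str, str], Callable] = {}


def register(device: str, name: str):
    def wrap(fn: Callable) -> Callable:
        _REGISTRY[(device, name)] = fn
        return fn

    return wrap


def make_caller(name: str, device: str) -> Callable:
    def caller(*args, **kwargs):
        # resolve lazily so registration order does not matter
        return _REGISTRY[(device, name)](*args, **kwargs)

    return caller

In the real code the active device is determined by the runtime backend rather than being passed explicitly; that detail is omitted here.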
