Skip to content

opt flashinfer mla cat #5822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,22 +332,38 @@ def forward_extend(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):

cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

# Save kv cache
if save_kv_cache and k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if k_rope is not None:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, cache_loc, k, k_rope
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if q_rope is not None:
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)

if self.forward_metadata.use_ragged:
# ragged prefill
if q_rope is not None:
q = torch.cat([q, q_rope], dim=-1)
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
if k_rope is not None:
k = torch.cat([k, k_rope], dim=-1)
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
qall,
k.view(-1, layer.tp_k_head_num, layer.head_dim),
Expand All @@ -358,11 +374,19 @@ def forward_extend(
)
else:
# mla paged prefill
if q_rope is None:
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
q, q_rope = (
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
)
o = q.new_empty(q.shape)
o = prefill_wrapper_paged.run(
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
q,
q_rope,
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
out=o,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
Expand All @@ -375,27 +399,50 @@ def forward_decode(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
decode_wrapper = self.forward_metadata.decode_wrapper
cache_loc = forward_batch.out_cache_loc

if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
if k_rope is not None:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
k_rope,
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)

if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = reshaped_q[:, :, : layer.v_head_dim]
q_rope = reshaped_q[:, :, layer.v_head_dim :]

k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
o = q_nope.new_empty(q_nope.shape)
o = decode_wrapper.run(
reshaped_q[:, :, : layer.v_head_dim],
reshaped_q[:, :, layer.v_head_dim :],
q_nope,
q_rope,
reshaped_k[:, :, : layer.v_head_dim],
reshaped_k[:, :, layer.v_head_dim :],
out=o,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def forward_absorb(

q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

if self.attention_backend == "fa3":
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
attn_output = self.attn_mqa(
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
)
Expand Down