Conversation

@MasterJH5574
Collaborator

This PR adds q/k position information to the batch prefill/decode kernels. More specifically, the kernels now accept two additional arrays:

  • `q_rope_position` with shape `(total_q_len,)`, giving the in-sequence position of each token in the input q.
  • `k_rope_pos_offset` with shape `(num_sequence,)`, giving the start position of each sequence in k.

These two arrays enable on-the-fly RoPE computation in multi-level cases.

Tests `test_batch_prefill` and `test_batch_decode` pass. Performance is not validated yet. Per discussion with Zihao, this change is unlikely to incur a significant perf regression.
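To illustrate how the two arrays are meant to be consumed, here is a minimal NumPy sketch of on-the-fly RoPE with explicit positions. The helper name `apply_rope` and the sample values are hypothetical, not the kernel's actual code; the point is only that each query token gets its angle from `q_rope_position`, while key `j` of sequence `s` would use position `k_rope_pos_offset[s] + j`.

```python
import numpy as np

def apply_rope(x, positions, rope_theta=1e4):
    # Hypothetical illustration, not the FlashInfer kernel code.
    # x: (n, head_dim) with head_dim even; positions: (n,) absolute positions.
    # Each pair (x[2i], x[2i+1]) is rotated by positions / theta^(2i/d).
    d = x.shape[-1]
    inv_freq = 1.0 / (rope_theta ** (np.arange(0, d, 2) / d))  # (d/2,)
    ang = positions[:, None] * inv_freq[None, :]               # (n, d/2)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Two ragged sequences packed into one batch: query lengths 3 and 2.
# q_rope_position gives the absolute position of every packed query token;
# k_rope_pos_offset gives where each sequence's keys start in the cache.
q = np.random.randn(5, 8)
q_rope_position = np.array([5, 6, 7, 0, 1])  # per-token query positions
k_rope_pos_offset = np.array([0, 0])         # per-sequence key start offsets
q_rot = apply_rope(q, q_rope_position)
```

Passing positions explicitly is what makes the multi-level case work: the kernel no longer has to assume queries start at the end of the cached keys.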

@yzh119
Collaborator

yzh119 commented Jan 21, 2024

I'll merge this into the mainline after #75 gets merged.

@MasterJH5574 force-pushed the qk-rope-info branch 2 times, most recently from 5b189f5 to 47686ef on January 29, 2024 18:49
@yzh119
Collaborator

yzh119 commented Jan 31, 2024

Sorry about the new conflicts, I'll take care of them tomorrow.

@yzh119 left a comment


LGTM, thank you @MasterJH5574 !

@yzh119 yzh119 merged commit a389ed4 into flashinfer-ai:main Feb 1, 2024
yzh119 added a commit that referenced this pull request Feb 16, 2024
This PR fixes #113, which was caused by #69 changing the
`BatchPrefillWithPagedKVCacheWrapperDispatched` signature without
`flashinfer_decl.h` being updated accordingly.

Also fixes some tiny format issues in #111.
diptorupd referenced this pull request in ROCm/flashinfer Sep 29, 2025
Adds a common wrapper function to mma_ops.hpp for hgemm kernels that works for both CUDA and HIP, replacing `mma_sync_m16n16k16_row_col_f16f16f32`.
