@krishnaraj36 (Contributor) commented on Oct 22, 2024

Improvements:

- Added a transpose of K for better vectorization during the matmul (see the sketch after the table below).
- Improved the load schedule.
- A bit more than 2x improvement in most cases.

Llama-2 7B observation:

| kernel                  | baseline | optimized |
| ----------------------- | -------- | --------- |
| batch_prefill_ragged_kv | 15 ms    | 7.1 ms    |
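A minimal numpy sketch of why the K transpose helps, assuming a row-major `[kv_len, head_dim]` K layout; this is illustrative only, the real change lives in the TIR attention schedule:

```python
# Illustrative sketch only; names and shapes are assumptions, not the kernel's.
import numpy as np

kv_len, head_dim, vec = 64, 128, 4
K = np.random.rand(kv_len, head_dim).astype(np.float32)
q = np.random.rand(head_dim).astype(np.float32)

KT = np.ascontiguousarray(K.T)  # one-time transpose to [head_dim, kv_len]

# Score computation s[j] = sum_d q[d] * K[j, d]. With KT, fixing d and
# sweeping j reads contiguous memory, so a GPU lane can fetch `vec`
# neighboring kv positions with one vload4-style access; with K it would
# be a stride-head_dim gather.
s = np.zeros(kv_len, dtype=np.float32)
for d in range(head_dim):
    for j0 in range(0, kv_len, vec):
        s[j0:j0 + vec] += q[d] * KT[d, j0:j0 + vec]  # contiguous slice

assert np.allclose(s, K @ q)
```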

This PR fixes the issue raised in PR #17446. The correctness issue was caused by incorrect code generation during the unroll phase, so we removed the explicit unroll and observed little to no performance degradation.
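A minimal sketch of the schedule change, assuming a typical TIR schedule for the store block; the block name and split factor are illustrative, not the actual kernel's:

```python
from tvm import tir

def schedule_o_store(sch: tir.Schedule) -> None:
    blk = sch.get_block("O_store")            # assumed block name
    x = sch.get_loops(blk)[-1]
    xo, xi = sch.split(x, factors=[None, 8])  # assumed split factor
    # sch.unroll(xi)  # removed by this PR: the explicit unroll triggered
    #                 # the incorrect pointer-offset codegen shown below
```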

We generated the OpenCL kernels by extracting the built modules, setting num_qo_heads=28 in
https://github.qualcomm.com/gpgpu/apache-tvm/blob/85e15d494d5a42360859941cbc972c4f175c3b94/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py#L36
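A rough sketch of how the generated OpenCL source can be dumped, not the exact commands used here; `mod` is assumed to be the IRModule extracted from the test above:

```python
import tvm

lib = tvm.build(mod, target=tvm.target.Target("opencl"))
# Device code lives in an imported module; print the generated OpenCL C.
print(lib.imported_modules[0].get_source())
```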
Original PR Codegen

```c
int cur_L_3 = ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) / 7) + (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 31)) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]);
if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[3] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + (((((cur_L_3 * 3584) + ((convert_int(get_group_id(1))) * 896)) + ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) + (7 & (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 31))) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4));
}
int cur_L_4 = ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) / 7) - -306783377) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]);
if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[4] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + ((((cur_L_4 * 3584) + ((convert_int(get_group_id(1))) * 896)) + (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) % 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)));
}
```

In the O_store block we noticed that large, incorrect pointer offsets were generated during later stages of the unroll. This shows up indirectly as zero elements in the output and as numerical instability.

Fusing the unroll loops so that they unroll together does not resolve this.

Oddly enough, the initial test case doesn't seem to trigger the issue and works as intended.

```c
int cur_L_3 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 1) >> 2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + (convert_int(get_local_id(1))));
if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[3] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + (((((cur_L_3 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + (((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 1) & 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4));
}
int cur_L_4 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 2) >> 2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + (convert_int(get_local_id(1))));
if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[4] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + ((((cur_L_4 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + (((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 2) & 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)));
}
```

@krishnaraj36 (Contributor, Author) commented:

@MasterJH5574 @tqchen
We have fixed the issue raised in PR #17466.
Can you please take a look at this PR?

@MasterJH5574 (Contributor) left a review comment:

Thank you @krishnaraj36 so much for the fix!

@MasterJH5574 (Contributor) commented:

I have also observed the "large and incorrect" pointer offsets before, but I didn't have time to nail down the issue. As I recall, they are generated by a floordiv simplification in src/tir/transforms/lower_intrin.cc.
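For context, a small illustration (my reading, not TVM code) of the floordiv-to-truncdiv legalization that produces the patterns above: C's `/` truncates toward zero, so the lowering emits a correction term for possibly-negative numerators, which is exactly the `((x % 7) >> 31)` term in the first kernel. The suspicious constants in the bad kernel (2147483637 is INT32_MAX - 10; 306783377 is roughly INT32_MAX / 7) appear to come from a different rewrite that shifts the numerator by a large multiple of 7.

```python
# Illustrative only: Python's // is floor division; emulate C truncdiv and
# show the correction term that lower_intrin-style legalization emits.

def truncdiv(a: int, b: int) -> int:
    # C semantics: round toward zero.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def truncmod(a: int, b: int) -> int:
    return a - truncdiv(a, b) * b

for x in (-8, -1, 0, 6, 13):
    # On int32 hardware, (truncmod(x, 7) >> 31) is -1 iff the remainder
    # is negative; model that with a comparison here.
    correction = -1 if truncmod(x, 7) < 0 else 0
    assert truncdiv(x, 7) + correction == x // 7  # matches true floordiv
```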

@krishnaraj36 (Contributor, Author) commented:

> Thank you @krishnaraj36 so much for the fix!

@MasterJH5574
There is only one change (removing `sch.unroll(xi)`) relative to the previous commit, which was reverted.

@srkreddy1238 merged commit e3e27f5 into apache:main on Oct 28, 2024.

ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request on Aug 10, 2025.