Needs help on performance optimization #2522

Open
void-main opened this issue Oct 20, 2023 · 3 comments


@void-main
Contributor

Hi community, I'm building a Triton kernel that first loads some non-contiguous indices from one tensor and then uses those indices to load the actual data from another tensor.
If it helps, I'm trying to implement PagedAttention with Triton.


I'm loading K blocks from the KV cache tensor. Because the K blocks are divided into pages, I need to load the offsets into the KV cache page by page.
I also have to take each block's offset into account when computing the K indices.

For example, assume the block table contains:

| Block table index | Content |
|---|---|
| 0 | 40 |
| 1 | 30 |

Assume also that the block size is 8. Then the expected offsets are:

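| Token index | Offset |
|---|---|
| 0 | 40 |
| 1 | 41 |
| 2 | 42 |
| 3 | 43 |
| 4 | 44 |
| 5 | 45 |
| 6 | 46 |
| 7 | 47 |
| 8 | 30 |
| 9 | 31 |
| 10 | 32 |
| 11 | 33 |
| 12 | 34 |
| 13 | 35 |
| 14 | 36 |
| 15 | 37 |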
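The mapping in the example can be written as offset = block_table[t // BLOCK_SIZE] + t % BLOCK_SIZE for token index t. Below is a minimal plain-Python sketch of that arithmetic. The function name `expected_offsets` is hypothetical, and like the example it treats each block-table entry as the starting offset of its page (no strides).

```python
def expected_offsets(block_table, block_size, num_tokens):
    """Map each logical token index to its physical offset in the KV cache.

    Assumes each block_table entry is the starting offset of that page,
    as in the example table (no stride multiplication).
    """
    offsets = []
    for t in range(num_tokens):
        page = t // block_size      # which page the token falls in
        in_page = t % block_size    # position of the token inside that page
        offsets.append(block_table[page] + in_page)
    return offsets


print(expected_offsets([40, 30], block_size=8, num_tokens=16))
# [40, 41, 42, 43, 44, 45, 46, 47, 30, 31, 32, 33, 34, 35, 36, 37]
```

Since `page` and `in_page` depend only on `t`, each token's offset can be computed independently of the others, without a page-by-page loop.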

I need to update the index tensor page by page.

I tried `tl.view`, but the result seems to be wrong. To update only part of the index at a time, the only approach I could come up with is to use a tensor in global memory: `tl.store` (with a mask) to write one page's part of the index, then `tl.load` the whole index immediately after all the stores.

Here's my kernel code:

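```python
# key_block_index is a BLOCK_N-element int64 scratch buffer allocated before kernel launch:
# key_block_index = torch.zeros([BLOCK_N], dtype=torch.int64, device=key_cache.device)
# tl.debug_barrier() only synchronizes threads within one program instance,
# so each program instance needs its own slice of this buffer.
# Build the K block offsets in global memory, one page at a time.
for page_idx in range(0, PAGES_PER_BLOCK_N):
    # Physical block number of this page (0 if past the end of the block table).
    block_idx = tl.load(block_tables + (start_page_idx + page_idx) * stride_btb,
                        mask=start_page_idx + page_idx < max_block_pages, other=0)
    # Offsets of the BLOCK_SIZE token slots stored in that block.
    current_block_idx = tl.full([BLOCK_SIZE], block_idx, tl.int64) * stride_kb + offs_kv_blocks * stride_kbs
    tl.store(key_block_index + page_idx * BLOCK_SIZE + offs_kv_blocks, current_block_idx)
# Barrier so all the stores above are visible before the whole index is reloaded.
tl.debug_barrier()
key_block_offs = tl.load(key_block_index + offs_page)

k_offs = cur_kv_head * stride_kh + key_head_offs[:, None] + key_block_offs[None, :]
k = tl.load(key_cache + k_offs, mask=(start_n + offs_page[None, :]) < cur_batch_seq_len, other=0.0)
# Barrier before the next iteration overwrites the scratch buffer.
tl.debug_barrier()
```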

I know this is really slow, but how could I optimize this kernel?
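For reference, the offset arithmetic in the loop above can also be written per token rather than per page. The sketch below is a plain-Python emulation, not Triton code. The function name `kv_token_offsets` is hypothetical, `block_tables` is a flat list indexed like the kernel's pointer arithmetic, and `start_n` is assumed to be a multiple of the block size. In a Triton kernel, the loop body would correspond to tensor ops on `start_n + tl.arange(0, BLOCK_N)` using `//` and `%`, plus one masked gather `tl.load(block_tables + page * stride_btb, mask=page < max_block_pages, other=0)`. Restructuring the kernel this way is an untested assumption.

```python
def kv_token_offsets(block_tables, max_block_pages, start_n, block_n,
                     block_size, stride_btb, stride_kb, stride_kbs):
    """Plain-Python emulation of per-token K offsets for one BLOCK_N tile.

    Uses the same formula as the kernel, block_idx * stride_kb + slot * stride_kbs,
    but derives each token's page directly instead of looping over pages.
    """
    offsets = []
    for i in range(block_n):                  # emulates tl.arange(0, BLOCK_N)
        token = start_n + i
        page = token // block_size            # logical page of this token
        in_page = token % block_size          # slot inside that page
        # emulates a masked gather: tl.load(..., mask=page < max_block_pages, other=0)
        block_idx = block_tables[page * stride_btb] if page < max_block_pages else 0
        offsets.append(block_idx * stride_kb + in_page * stride_kbs)
    return offsets


# Block numbers 5 and 2 with 8 slots per block (stride_kb = 8, stride_kbs = 1).
print(kv_token_offsets([5, 2], 2, 0, 16, 8, 1, 8, 1))
# [40, 41, 42, 43, 44, 45, 46, 47, 16, 17, 18, 19, 20, 21, 22, 23]
```

The offsets stay in registers in this formulation, so it needs neither the global-memory scratch buffer nor the `tl.debug_barrier()` calls.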

@Jokeren
Contributor

Jokeren commented Oct 20, 2023

> I know this is really slow

How slow is it compared to the original implementation?

@void-main
Contributor Author

> > I know this is really slow
>
> How slow is it compared to the original implementation?

Hi @Jokeren, when I commented out this mask, performance improved by roughly 30% or more.

So this part of the code seems to be the bottleneck.

@MeJerry215

@void-main, any progress on this?
