Hi community, I'm building a Triton kernel that first loads some discontiguous indexes from one tensor, then loads the actual data at those indexes from another tensor.
For context, I'm trying to implement PagedAttention in Triton.
I'm loading K blocks from the KV cache tensor. Since the K blocks are divided into pages, I need to load offsets into the KV cache page by page.
I also have to take the block offsets into account when computing the K indexes.
For example, assume the block table content is:

| Block Table | Content |
|---|---|
| 0 | 40 |
| 1 | 30 |
With a block size of 8, the expected offsets are:

| Token Index | Offset |
|---|---|
| 0 | 40 |
| 1 | 41 |
| 2 | 42 |
| 3 | 43 |
| 4 | 44 |
| 5 | 45 |
| 6 | 46 |
| 7 | 47 |
| 8 | 30 |
| 9 | 31 |
| 10 | 32 |
| 11 | 33 |
| 12 | 34 |
| 13 | 35 |
| 14 | 36 |
| 15 | 37 |
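To make the mapping concrete, here is a short plain-Python sketch of the offset computation the tables describe. The names `block_table`, `block_size`, and `token_offsets` are my own for illustration, not part of the kernel:

```python
# Reference computation of per-token KV-cache offsets from a block table.
# block_table[p] is the base offset of logical page p; each page holds
# block_size tokens, so a token's offset is its page's base plus its
# position within the page.
def token_offsets(block_table, block_size, num_tokens):
    return [block_table[t // block_size] + t % block_size
            for t in range(num_tokens)]

print(token_offsets([40, 30], 8, 16))
# → [40, 41, 42, 43, 44, 45, 46, 47, 30, 31, 32, 33, 34, 35, 36, 37]
```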
I need to update the index tensor page by page.
I tried `tl.view`, but the result seems to be wrong. To update only part of the index, the only approach I could come up with is to go through a global-memory tensor: `tl.store` with a mask to update part of the index, then `tl.load` it back immediately after all the stores.
Here's my kernel code:

```python
# the key_block_index tensor is set up with length BLOCK_N before kernel launch:
# key_block_index = torch.zeros([BLOCK_N], dtype=torch.int64, device=key_cache.device)

# use global memory to build the key block offsets, one page at a time
for page_idx in range(0, PAGES_PER_BLOCK_N):
    # load this page's block index from the block table
    block_idx = tl.load(block_tables + (start_page_idx + page_idx) * stride_btb,
                        mask=start_page_idx + page_idx < max_block_pages, other=0)
    # expand the scalar block index into per-token offsets within the page
    current_block_idx = tl.full([BLOCK_SIZE], block_idx, tl.int64) * stride_kb \
        + offs_kv_blocks * stride_kbs
    # write this page's slice of the index tensor
    tl.store(key_block_index + page_idx * BLOCK_SIZE + offs_kv_blocks, current_block_idx)

# trigger a memory fence before reading the offsets back
tl.debug_barrier()
key_block_offs = tl.load(key_block_index + offs_page)
k_offs = cur_kv_head * stride_kh + key_head_offs[:, None] + key_block_offs[None, :]
k = tl.load(key_cache + k_offs,
            mask=(start_n + offs_page[None, :]) < cur_batch_seq_len, other=0.0)
tl.debug_barrier()
```
I know this is really slow, but how could I optimize this kernel?
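One direction I've been considering (not sure it's correct) is to drop the page-by-page loop and the global-memory round trip entirely: since each token's page is just `offs // BLOCK_SIZE`, the block table could be gathered in a single vectorized load and the within-page offset recovered with `%`. Here is the index arithmetic modeled in plain NumPy, with NumPy fancy indexing standing in for a masked `tl.load` gather; the unit strides and concrete sizes are illustrative assumptions, not the kernel's real values:

```python
import numpy as np

# Vectorized model of the index computation: each token's page is
# offs // BLOCK_SIZE, so the block table can be gathered in one shot
# instead of being walked page by page through global memory.
BLOCK_N, BLOCK_SIZE = 16, 8
block_table = np.array([40, 30])   # per-page entries from the block table
stride_kb, stride_kbs = 1, 1       # unit strides, for illustration only

offs = np.arange(BLOCK_N)          # token offsets within this tile
page = offs // BLOCK_SIZE          # which page each token lives on
within = offs % BLOCK_SIZE         # position inside that page
key_block_offs = block_table[page] * stride_kb + within * stride_kbs
print(key_block_offs.tolist())
# → [40, 41, 42, 43, 44, 45, 46, 47, 30, 31, 32, 33, 34, 35, 36, 37]
```

In Triton terms this would be something like gathering `tl.load(block_tables + (start_page_idx + offs_page // BLOCK_SIZE) * stride_btb, mask=..., other=0)` and combining it with `offs_page % BLOCK_SIZE`, so no `tl.store`/`tl.load` round trip or `tl.debug_barrier()` is needed; I haven't verified the masking details.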