@@ -42,7 +42,7 @@ def _fwd_context_paged_attention_kernel(
     sm_scale,
     KV_GROUPS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -66,38 +66,38 @@ def _fwd_context_paged_attention_kernel(
     for i in range(0, cur_seq_idx):
        prev_seq_len_sum += tl.load(context_lengths + i)

-    q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
-    kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
+    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
+    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
     Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=Q + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_qt, stride_qd),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
     K_block_ptr = tl.make_block_ptr(
-        base=K + kv_offset,
-        shape=(BLOCK_DMODEL, cur_seq_len),
+        base=K + offset_kv,
+        shape=(HEAD_DIM, cur_seq_len),
         strides=(stride_kd, stride_kt),
         offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        block_shape=(HEAD_DIM, BLOCK_N),
         order=(0, 1),
     )
     V_block_ptr = tl.make_block_ptr(
-        base=V + kv_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=V + offset_kv,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_vt, stride_vd),
         offsets=(0, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, HEAD_DIM),
         order=(1, 0),
     )
     O_block_ptr = tl.make_block_ptr(
-        base=O + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=O + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_ot, stride_od),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
@@ -108,13 +108,13 @@ def _fwd_context_paged_attention_kernel(
     # as we have BLOCK_M the same size as the block size.
     cur_block_table_idx = block_start_m
     cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
-    kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
+    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh

     offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offsets_n = tl.arange(0, BLOCK_N)
     m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

     if block_start_m * BLOCK_M >= cur_seq_len:
         return
@@ -152,43 +152,41 @@ def _fwd_context_paged_attention_kernel(

     if cur_head_idx % KV_GROUPS == 0:
         # Copy k to corresponding cache block
-        kd_offsets = tl.arange(0, BLOCK_DMODEL)
-        kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt
-        k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0)
-        kcached_offsets = tl.arange(0, BLOCK_DMODEL)
-        kcachebs_offsets = tl.arange(0, BLOCK_SIZE)
-        kcache_offsets = (
+        offsets_dmodel = tl.arange(0, HEAD_DIM)
+        offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
+        k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+        offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
+        offsets_kcache = (
             KCache
-            + kvcache_offset
-            + kcached_offsets[:, None] * stride_cached
-            + kcachebs_offsets[None, :] * stride_cachebs
+            + offset_kvcache
+            + offsets_dmodel[:, None] * stride_cached
+            + offsets_kcachebs[None, :] * stride_cachebs
         )
-        tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
         # Copy v to corresponding cache block
-        vd_offsets = kd_offsets
-        vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-        v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd
-        v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0)
-        vcached_offsets = kcached_offsets
-        vcachebs_offsets = kcachebs_offsets
-        vcache_offsets = (
+        offsets_vd = offsets_dmodel
+        offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
+        offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
+        v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+        offsets_vcachebs = offsets_kcachebs  # same block-size range as K; aliased to make this explicit
+        offsets_vcache = (
             VCache
-            + kvcache_offset
-            + vcachebs_offsets[:, None] * stride_cachebs
-            + vcached_offsets[None, :] * stride_cached
+            + offset_kvcache
+            + offsets_vcachebs[:, None] * stride_cachebs
+            + offsets_dmodel[None, :] * stride_cached
         )
-        tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)

     return


 def context_attention_unpadded(
-    q: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
+    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
+    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
@@ -254,7 +252,7 @@ def context_attention_unpadded(
         sm_scale,
         num_kv_group,
         block_size,
-        BLOCK_DMODEL=Lk,
+        HEAD_DIM=Lk,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
     )
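
For reference, a minimal calling sketch for the renamed wrapper follows. It assumes the tensor layouts documented in the signature above; the concrete sizes, dtype, device, and import path are illustrative assumptions, not part of this commit.

import torch

from context_attn_unpad import context_attention_unpadded  # hypothetical import path

num_seqs, num_heads, num_kv_heads, head_dim = 2, 8, 2, 64  # assumed sizes
block_size, max_blocks_per_seq = 16, 4                     # assumed paging config

context_lengths = torch.tensor([13, 7], dtype=torch.int32, device="cuda")
num_tokens = int(context_lengths.sum())  # 20 unpadded tokens in total

q = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)

# Paged KV cache: [num_blocks, num_kv_heads, head_dim, block_size]
num_blocks = num_seqs * max_blocks_per_seq
k_cache = torch.zeros(
    num_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16, device="cuda"
)
v_cache = torch.zeros_like(k_cache)

# Each sequence owns a disjoint, pre-allocated run of cache blocks.
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").view(
    num_seqs, max_blocks_per_seq
)

out = context_attention_unpadded(
    q, k, v, k_cache, v_cache, context_lengths, block_tables, block_size
)
# out: [num_tokens, num_heads, head_dim]; k_cache/v_cache now hold the prefill KV.

Per the kernel's own comment, BLOCK_M is assumed equal to the cache block size (one query tile per cache block), so block_size here should match a tile size the kernel supports.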