@@ -140,22 +140,18 @@ title: vLLM Paged Attention
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```

- <figure markdown="span">
- ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
- <figcaption>
- </figcaption>
- </figure>
+ <figure markdown="span">
+ ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
+ </figure>

Each thread defines its own `q_ptr`, which points to the assigned
query token data in global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains a
total of 128 elements, divided into 128 / 4 = 32 vecs.

- <figure markdown="span">
- ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
- <figcaption>
- </figcaption>
- </figure>
+ <figure markdown="span">
+ ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
+ </figure>

```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
```
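
To make the partitioning above concrete, here is a small standalone sketch of the index arithmetic (not the vLLM kernel itself; the variable names mirror the doc, but the loop is ours): with `VEC_SIZE` 4, `HEAD_SIZE` 128, and 2 threads per thread group, each thread handles 16 of the 32 vecs, taking every other vec so that at each step the group touches adjacent memory.

```cpp
#include <cstdio>

// Example values from the text above; the real kernel derives these from
// template parameters.
constexpr int VEC_SIZE = 4;           // elements per vectorized access
constexpr int HEAD_SIZE = 128;        // elements of one query head
constexpr int THREAD_GROUP_SIZE = 2;  // threads cooperating on one token
constexpr int NUM_VECS_PER_THREAD =
    HEAD_SIZE / VEC_SIZE / THREAD_GROUP_SIZE;  // 32 vecs / 2 threads = 16

int main() {
  // Each thread takes every THREAD_GROUP_SIZE-th vec, so at every step the
  // two threads of a group read neighboring vecs.
  for (int t = 0; t < THREAD_GROUP_SIZE; ++t) {
    printf("thread %d reads vecs:", t);
    for (int i = 0; i < NUM_VECS_PER_THREAD; ++i) {
      const int vec_idx = t + i * THREAD_GROUP_SIZE;
      printf(" %2d", vec_idx);  // element offset is vec_idx * VEC_SIZE
    }
    printf("\n");
  }
  return 0;
}
```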
@@ -192,11 +188,9 @@ title: vLLM Paged Attention
points to key token data based on `k_cache` at the assigned block,
assigned head, and assigned token.

- <figure markdown="span">
- ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
- <figcaption>
- </figcaption>
- </figure>
+ <figure markdown="span">
+ ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
+ </figure>

The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
@@ -209,11 +203,9 @@ title: vLLM Paged Attention
elements for one token) that will be processed by 2 threads (one
thread group) separately.

- <figure markdown="span">
- ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
- <figcaption>
- </figcaption>
- </figure>
+ <figure markdown="span">
+ ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
+ </figure>

```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD];
```
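
As a rough illustration of the key layout these diagrams describe, here is a hedged standalone sketch of the offset arithmetic. It assumes the key cache is laid out as `[num_blocks, num_heads, HEAD_SIZE / x, BLOCK_SIZE, x]`; the helper name and example dimensions are invented for this sketch.

```cpp
#include <cstdio>

// Assumed example values from the text; x is the number of key elements
// packed contiguously per token.
constexpr long BLOCK_SIZE = 16;  // tokens per block
constexpr long HEAD_SIZE = 128;  // elements per head
constexpr long X = 8;            // packed elements per chunk

// Offset (in elements) of the chunk holding elements
// [chunk * X, (chunk + 1) * X) of the key of `token` in `block` for `head`,
// under the assumed [num_blocks, num_heads, HEAD_SIZE/X, BLOCK_SIZE, X]
// layout. One token's key is therefore not contiguous: its chunks sit
// BLOCK_SIZE * X elements apart.
long key_chunk_offset(long block, long num_heads, long head, long chunk,
                      long token) {
  return (((block * num_heads + head) * (HEAD_SIZE / X) + chunk) * BLOCK_SIZE +
          token) * X;
}

int main() {
  // Example: key of token 3 in block 5, head 2, with 32 heads in the cache.
  for (long chunk = 0; chunk < 3; ++chunk)
    printf("chunk %ld starts at element %ld\n", chunk,
           key_chunk_offset(5, 32, 2, chunk, 3));
  return 0;
}
```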
@@ -372,20 +364,14 @@ title: vLLM Paged Attention

<figure markdown="span">
![](../../assets/kernel/value.png){ align="center" alt="value" width="70%" }
- <figcaption>
- </figcaption>
</figure>

<figure markdown="span">
![](../../assets/kernel/logits_vec.png){ align="center" alt="logits_vec" width="50%" }
- <figcaption>
- </figcaption>
</figure>

<figure markdown="span">
![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" }
- <figcaption>
- </figcaption>
</figure>

Now we need to retrieve the value data and perform dot multiplication
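
A minimal sketch of that dot multiplication, under the same kind of vectorized layout (the vec width and variable names here are illustrative, not the kernel's): each thread multiplies a vec of softmax logits with the matching vec of value elements and accumulates one partial output.

```cpp
#include <cstdio>

constexpr int V_VEC_SIZE = 8;  // illustrative value-vec width

int main() {
  // logits_vec: softmax weights for V_VEC_SIZE tokens; v_vec: the value
  // elements of those same tokens at one position of the head dimension.
  float logits_vec[V_VEC_SIZE], v_vec[V_VEC_SIZE];
  for (int i = 0; i < V_VEC_SIZE; ++i) {
    logits_vec[i] = 1.0f / V_VEC_SIZE;  // dummy, already-normalized weights
    v_vec[i] = float(i);
  }
  // One partial result for one output element; in the kernel such partials
  // are then reduced across the threads that share the output row.
  float acc = 0.0f;
  for (int i = 0; i < V_VEC_SIZE; ++i) acc += logits_vec[i] * v_vec[i];
  printf("partial output = %f\n", acc);  // 3.5 for this dummy data
  return 0;
}
```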