@@ -39,7 +39,6 @@ struct paged_attention_xpu_v1_impl_ {
3939 const int num_blocks) {
4040 constexpr int x = 16 / sizeof (scalar_t );
4141 const int num_queries_per_kv = num_heads / num_kv_heads;
42-
4342 int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
4443 int max_context_len_padded = (max_context_len + 15 ) & 0xFFFFFFF0 ;
4544 TORCH_CHECK ((max_context_len_padded * sizeof (float )) % 64 == 0 );
@@ -48,7 +47,7 @@ struct paged_attention_xpu_v1_impl_ {
4847 sycl::buffer<scalar_sycl_t , 1 > out_buf (
4948 (scalar_sycl_t *)out, num_seqs * num_heads * HEAD_SIZE);
5049 sycl::buffer<scalar_sycl_t , 1 > q_buf (
51- (scalar_sycl_t *)q, num_seqs * num_heads * HEAD_SIZE );
50+ (scalar_sycl_t *)q, num_seqs * q_stride );
5251 sycl::buffer<int , 1 > context_lens_buf (context_lens, num_seqs);
5352 sycl::buffer<int , 1 > block_tables_buf (
5453 block_tables, num_seqs * max_num_blocks_per_seq);
@@ -57,52 +56,49 @@ struct paged_attention_xpu_v1_impl_ {
5756 sycl::buffer<scalar_sycl_t , 1 > v_cache_buf (
5857 (scalar_sycl_t *)v_cache, num_blocks * kv_block_stride);
5958
60- auto e0 = task_q.memset (out, 0 , num_seqs * num_heads * HEAD_SIZE * sizeof ( scalar_t ));
61-
59+ auto e0 = task_q.memset (
60+ out, 0 , num_seqs * num_heads * HEAD_SIZE * sizeof ( scalar_t ));
6261
6362 size_t logits_stride = num_heads * max_context_len_padded;
6463 size_t logits_bytes = num_seqs * logits_stride * sizeof (float );
6564 float * logits = (float *)sycl::aligned_alloc_device (
6665 64 , logits_bytes, task_q.get_device (), task_q.get_context ());
6766 sycl::event reset_logits = task_q.memset (logits, 0 , logits_bytes);
68-
67+ reset_logits. wait ();
6968 auto e1 = task_q.submit ([&](auto & h) {
7069 sycl::accessor q_acc (q_buf, h, sycl::read_only);
7170 sycl::accessor k_cache_acc (k_cache_buf, h, sycl::read_only);
7271 sycl::accessor context_lens_acc (context_lens_buf, h, sycl::read_only);
7372 sycl::accessor block_tables_acc (block_tables_buf, h, sycl::read_only);
74- h.parallel_for (
75- sycl::range (num_seqs, num_heads, HEAD_SIZE / x),
76- [=](sycl::item<3 > item) {
77- size_t seq_idx = item[0 ];
78- size_t head_idx = item[1 ];
79- size_t x_idx = item[2 ];
80- int context_len = context_lens_acc[seq_idx];
81- const int block_num = (context_len + BLOCK_SIZE - 1 ) / BLOCK_SIZE;
82-
83- const int64_t kv_head_idx = head_idx / num_queries_per_kv;
84- size_t q_base_offset = seq_idx * q_stride + head_idx * HEAD_SIZE;
73+ h.parallel_for (sycl::range (num_seqs, num_heads), [=](sycl::item<2 > item) {
74+ size_t seq_idx = item[0 ];
75+ size_t head_idx = item[1 ];
76+ int context_len = context_lens_acc[seq_idx];
77+ const int block_num = (context_len + BLOCK_SIZE - 1 ) / BLOCK_SIZE;
8578
86- for (size_t block_idx = 0 ; block_idx < block_num; ++block_idx) {
87- const int32_t physical_block_idx = block_tables_acc
88- [block_idx + max_num_blocks_per_seq * seq_idx];
89- size_t k_base_offset = physical_block_idx * kv_block_stride +
90- kv_head_idx * kv_head_stride; // dim0,dim1
91- float * __restrict__ head_block_logits = logits +
92- seq_idx * logits_stride + head_idx * max_context_len_padded +
93- block_idx * BLOCK_SIZE;
94- for (int token_idx = 0 ; token_idx < BLOCK_SIZE; ++token_idx) {
95- for (int i = 0 ; i < x; ++i) {
96- head_block_logits[token_idx] +=
97- (float )q_acc[i + x_idx * x + q_base_offset] *
98- (float )k_cache_acc
99- [i + token_idx * x + BLOCK_SIZE * x_idx * x +
100- k_base_offset] *
101- scale;
102- }
79+ for (size_t block_idx = 0 ; block_idx < block_num; ++block_idx) {
80+ const int32_t physical_block_idx =
81+ block_tables_acc[block_idx + max_num_blocks_per_seq * seq_idx];
82+ const int64_t kv_head_idx = head_idx / num_queries_per_kv;
83+ size_t q_base_offset = seq_idx * q_stride + head_idx * HEAD_SIZE;
84+ size_t k_base_offset = physical_block_idx * kv_block_stride +
85+ kv_head_idx * kv_head_stride; // dim0,dim1
86+ float * __restrict__ head_block_logits = logits +
87+ seq_idx * logits_stride + head_idx * max_context_len_padded +
88+ block_idx * BLOCK_SIZE;
89+ for (int x_idx = 0 ; x_idx < HEAD_SIZE / x; ++x_idx) {
90+ for (int token_idx = 0 ; token_idx < BLOCK_SIZE; ++token_idx) {
91+ for (int i = 0 ; i < x; ++i) {
92+ head_block_logits
93+ [token_idx] += (float )q_acc[i + x_idx * x + q_base_offset] *
94+ (float )k_cache_acc[i + token_idx * x +
95+ BLOCK_SIZE * x_idx * x + k_base_offset] *
96+ scale;
10397 }
10498 }
105- });
99+ }
100+ }
101+ });
106102 });
107103 e1 .wait ();
108104
@@ -120,13 +116,15 @@ struct paged_attention_xpu_v1_impl_ {
120116 max_logit =
121117 max_logit >= head_logit_ptr[i] ? max_logit : head_logit_ptr[i];
122118 }
123- float sum = 0 ;
119+ float sum = 0 . f ;
124120 for (int i = 0 ; i < context_len; ++i) {
125- head_logit_ptr[i] = sycl::exp (head_logit_ptr[i] - max_logit);
126- sum += head_logit_ptr[i];
121+ float val = sycl::exp<float >(head_logit_ptr[i] - max_logit);
122+ head_logit_ptr[i] = val;
123+ sum += val;
127124 }
125+ const float inv_sum = 1 .f / (sum + 1e-6f );
128126 for (int i = 0 ; i < context_len; ++i) {
129- head_logit_ptr[i] /= sum ;
127+ head_logit_ptr[i] *= inv_sum ;
130128 }
131129 int remaining_seq_upper = block_num * BLOCK_SIZE;
132130 for (int i = context_len; i < remaining_seq_upper; ++i) {
@@ -164,8 +162,8 @@ struct paged_attention_xpu_v1_impl_ {
164162 BLOCK_SIZE * head_part_idx * 16 ;
165163 size_t out_base_offset = seq_idx * num_heads * HEAD_SIZE +
166164 head_idx * HEAD_SIZE + head_part_idx * 16 ;
167- for (int j = 0 ; j < BLOCK_SIZE ; ++j ) {
168- for (int i = 0 ; i < 16 ; ++i ) {
165+ for (int i = 0 ; i < 16 ; ++i ) {
166+ for (int j = 0 ; j < BLOCK_SIZE ; ++j ) {
169167 output_acc[i + out_base_offset] +=
170168 (scalar_sycl_t )(prob_vec_ptr[j] * (float )v_cache_acc[j + i * BLOCK_SIZE + v_base_offset]);
171169 }
0 commit comments