
Commit 4066433

fix result
1 parent de60f86 commit 4066433

5 files changed: +58 -53 lines changed


csrc/xpu/attention_xpu.cpp

Lines changed: 38 additions & 40 deletions
@@ -39,7 +39,6 @@ struct paged_attention_xpu_v1_impl_ {
       const int num_blocks) {
     constexpr int x = 16 / sizeof(scalar_t);
     const int num_queries_per_kv = num_heads / num_kv_heads;
-
     int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
     int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
     TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
@@ -48,7 +47,7 @@ struct paged_attention_xpu_v1_impl_ {
     sycl::buffer<scalar_sycl_t, 1> out_buf(
         (scalar_sycl_t*)out, num_seqs * num_heads * HEAD_SIZE);
     sycl::buffer<scalar_sycl_t, 1> q_buf(
-        (scalar_sycl_t*)q, num_seqs * num_heads * HEAD_SIZE);
+        (scalar_sycl_t*)q, num_seqs * q_stride);
     sycl::buffer<int, 1> context_lens_buf(context_lens, num_seqs);
     sycl::buffer<int, 1> block_tables_buf(
         block_tables, num_seqs * max_num_blocks_per_seq);
@@ -57,52 +56,49 @@ struct paged_attention_xpu_v1_impl_ {
     sycl::buffer<scalar_sycl_t, 1> v_cache_buf(
         (scalar_sycl_t*)v_cache, num_blocks * kv_block_stride);

-    auto e0 = task_q.memset(out, 0, num_seqs * num_heads * HEAD_SIZE * sizeof(scalar_t));
-
+    auto e0 = task_q.memset(
+        out, 0, num_seqs * num_heads * HEAD_SIZE * sizeof(scalar_t));

     size_t logits_stride = num_heads * max_context_len_padded;
     size_t logits_bytes = num_seqs * logits_stride * sizeof(float);
     float* logits = (float*)sycl::aligned_alloc_device(
         64, logits_bytes, task_q.get_device(), task_q.get_context());
     sycl::event reset_logits = task_q.memset(logits, 0, logits_bytes);
-
+    reset_logits.wait();
     auto e1 = task_q.submit([&](auto& h) {
       sycl::accessor q_acc(q_buf, h, sycl::read_only);
       sycl::accessor k_cache_acc(k_cache_buf, h, sycl::read_only);
       sycl::accessor context_lens_acc(context_lens_buf, h, sycl::read_only);
       sycl::accessor block_tables_acc(block_tables_buf, h, sycl::read_only);
-      h.parallel_for(
-          sycl::range(num_seqs, num_heads, HEAD_SIZE / x),
-          [=](sycl::item<3> item) {
-            size_t seq_idx = item[0];
-            size_t head_idx = item[1];
-            size_t x_idx = item[2];
-            int context_len = context_lens_acc[seq_idx];
-            const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
-
-            const int64_t kv_head_idx = head_idx / num_queries_per_kv;
-            size_t q_base_offset = seq_idx * q_stride + head_idx * HEAD_SIZE;
+      h.parallel_for(sycl::range(num_seqs, num_heads), [=](sycl::item<2> item) {
+        size_t seq_idx = item[0];
+        size_t head_idx = item[1];
+        int context_len = context_lens_acc[seq_idx];
+        const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;

-            for (size_t block_idx = 0; block_idx < block_num; ++block_idx) {
-              const int32_t physical_block_idx = block_tables_acc
-                  [block_idx + max_num_blocks_per_seq * seq_idx];
-              size_t k_base_offset = physical_block_idx * kv_block_stride +
-                  kv_head_idx * kv_head_stride; // dim0,dim1
-              float* __restrict__ head_block_logits = logits +
-                  seq_idx * logits_stride + head_idx * max_context_len_padded +
-                  block_idx * BLOCK_SIZE;
-              for (int token_idx = 0; token_idx < BLOCK_SIZE; ++token_idx) {
-                for (int i = 0; i < x; ++i) {
-                  head_block_logits[token_idx] +=
-                      (float)q_acc[i + x_idx * x + q_base_offset] *
-                      (float)k_cache_acc
-                          [i + token_idx * x + BLOCK_SIZE * x_idx * x +
-                           k_base_offset] *
-                      scale;
-                }
+        for (size_t block_idx = 0; block_idx < block_num; ++block_idx) {
+          const int32_t physical_block_idx =
+              block_tables_acc[block_idx + max_num_blocks_per_seq * seq_idx];
+          const int64_t kv_head_idx = head_idx / num_queries_per_kv;
+          size_t q_base_offset = seq_idx * q_stride + head_idx * HEAD_SIZE;
+          size_t k_base_offset = physical_block_idx * kv_block_stride +
+              kv_head_idx * kv_head_stride; // dim0,dim1
+          float* __restrict__ head_block_logits = logits +
+              seq_idx * logits_stride + head_idx * max_context_len_padded +
+              block_idx * BLOCK_SIZE;
+          for (int x_idx = 0; x_idx < HEAD_SIZE / x; ++x_idx) {
+            for (int token_idx = 0; token_idx < BLOCK_SIZE; ++token_idx) {
+              for (int i = 0; i < x; ++i) {
+                head_block_logits
+                    [token_idx] += (float)q_acc[i + x_idx * x + q_base_offset] *
+                    (float)k_cache_acc[i + token_idx * x +
+                        BLOCK_SIZE * x_idx * x + k_base_offset] *
+                    scale;
              }
            }
-          });
+          }
+        }
+      });
     });
     e1.wait();
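Note on the hunk above: the kernel now parallelizes only over (sequence, head) and walks x_idx in a serial loop, so each work-item owns its head_block_logits entries; in the old three-dimensional range, work-items with different x_idx accumulated into the same logits element concurrently, which would race, and that is presumably the wrong-result bug this commit targets. For reference, a minimal NumPy sketch of what the three serial loops accumulate per block (sizes and layout simplified, not the real blocked KV-cache layout):

import numpy as np

# Illustrative sizes only; x stands in for the packing width 16 / sizeof(scalar_t).
HEAD_SIZE, BLOCK_SIZE, x = 64, 16, 8
scale = 1.0 / np.sqrt(HEAD_SIZE)

q = np.random.rand(HEAD_SIZE).astype(np.float32)                    # one query head
k_block = np.random.rand(BLOCK_SIZE, HEAD_SIZE).astype(np.float32)  # keys of one block

logits = np.zeros(BLOCK_SIZE, dtype=np.float32)
for x_idx in range(HEAD_SIZE // x):          # serial in the new kernel
    for token_idx in range(BLOCK_SIZE):
        for i in range(x):
            logits[token_idx] += q[x_idx * x + i] * k_block[token_idx, x_idx * x + i] * scale

# Equivalent vectorized form: scaled dot product of q with every cached key.
assert np.allclose(logits, (k_block @ q) * scale, atol=1e-4)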

@@ -120,13 +116,15 @@ struct paged_attention_xpu_v1_impl_ {
           max_logit =
               max_logit >= head_logit_ptr[i] ? max_logit : head_logit_ptr[i];
         }
-        float sum = 0;
+        float sum = 0.f;
         for (int i = 0; i < context_len; ++i) {
-          head_logit_ptr[i] = sycl::exp(head_logit_ptr[i] - max_logit);
-          sum += head_logit_ptr[i];
+          float val = sycl::exp<float>(head_logit_ptr[i] - max_logit);
+          head_logit_ptr[i] = val;
+          sum += val;
         }
+        const float inv_sum = 1.f / (sum + 1e-6f);
         for (int i = 0; i < context_len; ++i) {
-          head_logit_ptr[i] /= sum;
+          head_logit_ptr[i] *= inv_sum;
         }
         int remaining_seq_upper = block_num * BLOCK_SIZE;
         for (int i = context_len; i < remaining_seq_upper; ++i) {
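The softmax hunk above keeps the usual max-subtraction for stability and swaps the per-element division for a precomputed reciprocal with a small epsilon in the denominator. A short NumPy sketch of the same normalization (the 1e-6 epsilon is taken from the diff; everything else is illustrative):

import numpy as np

def normalize_logits(head_logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax, mirroring the kernel's steps.
    max_logit = head_logits.max()           # running max computed just before this hunk
    exps = np.exp(head_logits - max_logit)  # shift before exp so nothing overflows
    inv_sum = 1.0 / (exps.sum() + 1e-6)     # one reciprocal, reused for every element
    return exps * inv_sum

probs = normalize_logits(np.array([3.0, 1.0, 0.5], dtype=np.float32))
print(probs, probs.sum())                   # probabilities, summing to ~1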
@@ -164,8 +162,8 @@ struct paged_attention_xpu_v1_impl_ {
              BLOCK_SIZE * head_part_idx * 16;
          size_t out_base_offset = seq_idx * num_heads * HEAD_SIZE +
              head_idx * HEAD_SIZE + head_part_idx * 16;
-         for (int j = 0; j < BLOCK_SIZE; ++j) {
-           for (int i = 0; i < 16; ++i) {
+         for (int i = 0; i < 16; ++i) {
+           for (int j = 0; j < BLOCK_SIZE; ++j) {
             output_acc[i + out_base_offset] +=
                 (scalar_sycl_t)(prob_vec_ptr[j] * (float)v_cache_acc[j + i * BLOCK_SIZE + v_base_offset]);
           }
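The last hunk only swaps the i/j loop order when folding the probabilities into the output; the accumulated sum is the same either way, so this reads as an access-pattern change rather than a numerical fix. A tiny sketch of the equivalence (names hypothetical):

import numpy as np

BLOCK_SIZE = 16
prob = np.random.rand(BLOCK_SIZE).astype(np.float32)    # softmaxed scores for one block
v = np.random.rand(16, BLOCK_SIZE).astype(np.float32)   # 16 output lanes x BLOCK_SIZE tokens

out_new = np.zeros(16, dtype=np.float32)
for i in range(16):                  # new order: output element outermost
    for j in range(BLOCK_SIZE):
        out_new[i] += prob[j] * v[i, j]

out_old = np.zeros(16, dtype=np.float32)
for j in range(BLOCK_SIZE):          # old order: token outermost
    for i in range(16):
        out_old[i] += prob[j] * v[i, j]

assert np.allclose(out_new, out_old)  # same result, different memory-access order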

vllm/model_executor/layers/attention.py

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,8 @@ def forward(
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
+        if query.is_xpu:
+            torch.xpu.synchronize()
         batch_size, seq_len, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
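The added guard drains the XPU queue before the Python side reads the query tensor. A hedged sketch of the same pattern as a reusable helper (maybe_sync is hypothetical, not part of vLLM; torch.xpu assumes Intel's XPU backend is available):

import torch

def maybe_sync(t: torch.Tensor) -> None:
    # Hypothetical helper: block until pending kernels on t's device finish.
    if t.is_cuda:
        torch.cuda.synchronize(t.device)
    elif getattr(t, "is_xpu", False) and hasattr(torch, "xpu"):
        torch.xpu.synchronize()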

vllm/model_executor/layers/sampler.py

Lines changed: 2 additions & 2 deletions
@@ -385,7 +385,7 @@ def _sample(
     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        sample_indices = categorized_sample_indices[sampling_type].to(torch.int64)
+        sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
@@ -402,7 +402,7 @@
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
+            multinomial_samples = _multinomial(probs.cpu()[sample_indices.cpu()],
                                                max_best_of)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
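Two things change above: the explicit .to(torch.int64) cast is dropped, and both probs and sample_indices are pulled to the CPU so the gather and the multinomial draw happen on a single device instead of mixing device and host tensors. A minimal sketch of that same-device rule (values made up, outside vLLM):

import torch

probs = torch.rand(8, 32000)               # per-token probabilities, kept on the host here
sample_indices = torch.tensor([0, 3, 5])   # a tensor of Python ints is int64 by default

selected = probs[sample_indices]           # tensor and index live on the same device
samples = torch.multinomial(selected, num_samples=1)
print(samples.shape)                       # torch.Size([3, 1])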

vllm/model_executor/sampling_metadata.py

Lines changed: 12 additions & 10 deletions
@@ -146,7 +146,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        if device == torch.device('xpu:0'): # FIXME: remove index?
+        if device.type == 'xpu': # FIXME: remove index?
             pin_memory = False
         else:
             pin_memory = not in_wsl()
@@ -215,19 +215,21 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             dtype=torch.long,
             pin_memory=pin_memory,
         )
+        if device.type == 'xpu':
+            torch.xpu.synchronize()
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
         return cls(
-            temperatures=temperatures_t.to(device=device, non_blocking=True),
-            top_ps=top_ps_t.to(device=device, non_blocking=True),
-            top_ks=top_ks_t.to(device=device, non_blocking=True),
-            min_ps=min_ps_t.to(device=device, non_blocking=True),
+            temperatures=temperatures_t.to(device=device, non_blocking=pin_memory),
+            top_ps=top_ps_t.to(device=device, non_blocking=pin_memory),
+            top_ks=top_ks_t.to(device=device, non_blocking=pin_memory),
+            min_ps=min_ps_t.to(device=device, non_blocking=pin_memory),
             presence_penalties=presence_penalties_t.to(device=device,
-                                                       non_blocking=True),
+                                                       non_blocking=pin_memory),
             frequency_penalties=frequency_penalties_t.to(device=device,
-                                                         non_blocking=True),
+                                                         non_blocking=pin_memory),
             repetition_penalties=repetition_penalties_t.to(device=device,
-                                                           non_blocking=True),
-            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
-            output_tokens=output_tensor.to(device=device, non_blocking=True),
+                                                           non_blocking=pin_memory),
+            prompt_tokens=prompt_tensor.to(device=device, non_blocking=pin_memory),
+            output_tokens=output_tensor.to(device=device, non_blocking=pin_memory),
         )
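Tying non_blocking to pin_memory matches the general rule that a host-to-device copy is only safely (and only actually) asynchronous when the source tensor lives in pinned memory, and the XPU path above disables pinning. A hedged sketch of the pattern outside vLLM (to_device is hypothetical):

import torch

def to_device(values, device: torch.device) -> torch.Tensor:
    # Pin the staging tensor only when a CUDA device will consume it.
    pin_memory = device.type == "cuda" and torch.cuda.is_available()
    t = torch.tensor(values, dtype=torch.float, pin_memory=pin_memory)
    # non_blocking only means something when the source is pinned.
    return t.to(device=device, non_blocking=pin_memory)

temps = to_device([0.7, 1.0, 0.9], torch.device("cpu"))  # degrades to a plain copy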

vllm/worker/worker.py

Lines changed: 4 additions & 1 deletion
@@ -113,13 +113,16 @@ def profile_num_available_blocks(
         total_xpu_memory = get_xpu_memory()
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
+        print("peak memory: " + str(peak_memory) + " total memory: " + str(total_xpu_memory) +
+              "cache block size: " + str(cache_block_size))
+
         num_xpu_blocks = int(
             (total_xpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_xpu_blocks = max(num_xpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
-        torch.cuda.empty_cache()
+        torch.xpu.empty_cache()

         return 0, num_cpu_blocks, num_xpu_blocks
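num_xpu_blocks follows the usual cache-sizing arithmetic: take the memory budget (total memory times the utilization fraction), subtract the profiled peak, and divide by the per-block KV-cache footprint. (As committed, the debug print also concatenates str(total_xpu_memory) directly against "cache block size:" with no separating space.) A worked example with made-up numbers:

# Hypothetical values, for the arithmetic only.
total_xpu_memory = 16 * 1024**3        # 16 GiB of device memory
gpu_memory_utilization = 0.9           # fraction vLLM is allowed to use
peak_memory = 6 * 1024**3              # peak usage measured while profiling
cache_block_size = 2 * 1024**2         # bytes needed per KV-cache block

num_xpu_blocks = int(
    (total_xpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
num_xpu_blocks = max(num_xpu_blocks, 0)
print(num_xpu_blocks)                  # 4300 blocks with these numbers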
