Skip to content

Commit edd27de

Browse files
author
Varun
committed
fp8 quant int-overflow changes
Signed-off-by: Varun <vsundarr@redhat.com>
1 parent 8ac9fec commit edd27de

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
234234
row = g_id // groups_per_row
235235
row_g_id = g_id % groups_per_row
236236

237-
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
238-
y_q_ptr += g_id * group_size
237+
# Ensure offset calculations use int64 to prevent overflow
238+
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
239+
group_size)
240+
y_ptr += y_ptr_offset
241+
242+
y_q_ptr_offset = g_id.to(tl.int64) * group_size
243+
y_q_ptr += y_q_ptr_offset
239244
y_s_ptr += g_id
240245

241246
cols = tl.arange(0, BLOCK) # N <= BLOCK
@@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
282287
row = g_id // groups_per_row
283288
row_g_id = g_id % groups_per_row
284289

285-
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
286-
y_q_ptr += g_id * group_size
290+
# Ensure offset calculations use int64 to prevent overflow
291+
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
292+
group_size)
293+
y_ptr += y_ptr_offset
294+
295+
y_q_ptr_offset = g_id.to(tl.int64) * group_size
296+
y_q_ptr += y_q_ptr_offset
287297

288298
# Convert g_id the flattened block coordinate to 2D so we can index
289299
# into the output y_scales matrix
290300
blocks_per_row = y_num_columns // group_size
291301
scale_col = g_id % blocks_per_row
292302
scale_row = g_id // blocks_per_row
293-
y_s_ptr += scale_col * y_s_col_stride + scale_row
303+
# Ensure offset calculation uses int64 for y_s_ptr
304+
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
305+
tl.int64)
306+
y_s_ptr += y_s_ptr_offset
294307

295308
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
296309
mask = cols < group_size

0 commit comments

Comments
 (0)