@@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row

-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
     y_s_ptr += g_id

     cols = tl.arange(0, BLOCK)  # N <= BLOCK
@@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row

-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset

     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
     blocks_per_row = y_num_columns // group_size
     scale_col = g_id % blocks_per_row
     scale_row = g_id // blocks_per_row
-    y_s_ptr += scale_col * y_s_col_stride + scale_row
+    # Ensure offset calculation uses int64 for y_s_ptr
+    y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
+        tl.int64)
+    y_s_ptr += y_s_ptr_offset

     cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
     mask = cols < group_size
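
For context on why the casts matter: Triton's tl.program_id() and the index arithmetic derived from it are 32-bit by default, so on a sufficiently large activation tensor the element offset row * y_row_stride can exceed 2**31 - 1 and wrap to a negative value, corrupting the pointer. Below is a minimal plain-Python sketch of that wraparound; the tensor shape is a hypothetical example, not taken from the kernel.

    # Hypothetical large activation: ~300k rows of 7168 values each.
    row, y_row_stride = 299_999, 7_168
    offset = row * y_row_stride            # 2_150_392_832, just past 2**31 - 1

    # What a signed 32-bit multiply would produce (two's-complement wrap):
    wrapped = (offset + 2**31) % 2**32 - 2**31
    print(offset, wrapped)                 # 2150392832 -2144574464

    # Promoting an operand to 64 bits first, as the diff does with
    # .to(tl.int64), keeps the full value and hence a valid pointer offset.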