@@ -77,14 +77,15 @@ def _kernel_quantize_mx4(
77
77
MAX_FP32_MANTISSA_BITS : tl .constexpr = 24 # type: ignore[Incompatible variable type]
78
78
IMPLIED_1_BIT : tl .constexpr = 1 << 23 # type: ignore[Incompatible variable type]
79
79
OVERFLOW_THRESHOLD : tl .constexpr = 4 # type: ignore[Incompatible variable type]
80
+ FP32_MIN_NORMAL : tl .constexpr = 2 ** (- 126 ) # type: ignore[Incompatible variable type]
80
81
81
82
# First we need to compute shared exponent.
82
83
for _k in range (0 , tl .cdiv (K , BLOCK_SIZE )):
83
84
# Load a block of values.
84
85
a = tl .load (
85
86
A + pid * stride_am + k_offset * stride_ak ,
86
87
mask = k_offset < K ,
87
- other = - float ( "inf" ) ,
88
+ other = 0 ,
88
89
)
89
90
90
91
# Scaling step
@@ -94,13 +95,15 @@ def _kernel_quantize_mx4(
94
95
a_groups = tl .reshape (a , [BLOCK_SIZE // GROUP_SIZE , GROUP_SIZE ])
95
96
# Compute the shared exponent of each group.
96
97
group_max = tl .max (tl .abs (a_groups ), axis = 1 )
97
- # Convert max to exponent via bit operations.
98
- group_exp = group_max .to (tl .int32 , bitcast = True ) & FP32_EXP_MASK
99
- group_exp = group_exp >> FP32_EXP_OFFSET
98
+ # Prevent infinite values in log.
99
+ group_max = tl .where (group_max == 0 , FP32_MIN_NORMAL , group_max )
100
+ # Convert max to exponent via direct log computation and ceiling
101
+ # rounding to minimize errors.
102
+ group_exp = tl .ceil (tl .log2 (group_max ))
100
103
# Subtract largest exponent in target datatype and remove bias.
101
- group_exp = group_exp - 2 - FP32_EXP_BIAS
102
- # Clamp to valid int8 range.
103
- group_exp = tl .maximum (group_exp , - 127 )
104
+ group_exp = group_exp - 2
105
+ # Make sure exponent is in valid range.
106
+ group_exp = tl .clamp (group_exp , - 127 , 125 )
104
107
105
108
# Next we scale A in preparation for quantization.
106
109
scale = tl .exp2 (group_exp .to (tl .float64 )).to (tl .float32 )
@@ -113,10 +116,9 @@ def _kernel_quantize_mx4(
113
116
114
117
# We're done with group_exp now so we can write it out.
115
118
# We readd fp32_exp_bias for compatibility with cuda dequant.
116
- group_exp = group_exp .to (tl .int8 )
117
119
tl .store (
118
120
shared_exp + pid * stride_exp_m + stride_exp_k * group_offset ,
119
- group_exp + FP32_EXP_BIAS ,
121
+ ( group_exp + FP32_EXP_BIAS ). to ( tl . int8 ) ,
120
122
mask = group_offset < K // GROUP_SIZE ,
121
123
)
122
124
@@ -228,11 +230,12 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
228
230
# We do this by finding the power of two that is closest to
229
231
# the sqrt of the number of elements.
230
232
num_threads = int (2 ** round (math .log2 (math .sqrt (a .numel ()))))
231
- # Make sure that num_threads is a multiple of group_size.
232
- num_threads = (num_threads // group_size ) * group_size
233
- if num_threads == 0 :
234
- num_threads = a .numel () // group_size
235
- a = a .view (num_threads , - 1 )
233
+ # Make sure that the number of elements per row is a multiple of group_size.
234
+ K = a .numel () // num_threads
235
+ K = (K // group_size ) * group_size
236
+ if K == 0 :
237
+ K = group_size
238
+ a = a .view (- 1 , K )
236
239
M , K = a .shape
237
240
# If K is less than group_size, we compute a single group per row.
238
241
if K < group_size :
@@ -382,19 +385,21 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
382
385
# View a as 2D for simplicity.
383
386
orig_shape = a .shape
384
387
# Unravel packed inputs from shared exponents.
385
- a = a .view (- 1 , (group_size // 2 ) + 1 )
388
+ packed_group_size = group_size // 2
389
+ a = a .view (- 1 , packed_group_size + 1 )
386
390
packed_input = a [:, :- 1 ]
387
391
shared_exp = a [:, - 1 :]
388
392
# Find a shape that distributes work evenly over threads.
389
393
# We do this by finding the power of two that is closest to
390
394
# the sqrt of the number of elements.
391
395
num_threads = int (2 ** round (math .log2 (math .sqrt (packed_input .numel ()))))
392
- # Make sure that num_threads is a multiple of group_size.
393
- num_threads = (num_threads // group_size ) * group_size
394
- if num_threads == 0 :
395
- num_threads = packed_input .numel () // group_size
396
- packed_input = packed_input .reshape (num_threads , - 1 )
397
- shared_exp = shared_exp .reshape (num_threads , - 1 )
396
+ # Make sure that the number of elements per row is a multiple of the packed group size.
397
+ K = packed_input .numel () // num_threads
398
+ K = (K // packed_group_size ) * packed_group_size
399
+ if K == 0 :
400
+ K = packed_group_size
401
+ packed_input = packed_input .reshape (- 1 , K )
402
+ shared_exp = shared_exp .reshape (- 1 , K // packed_group_size )
398
403
M , K_2 = packed_input .shape
399
404
400
405
# Use a lookup table to convert
0 commit comments