
Commit ad30615

jwfromm authored and facebook-github-bot committed
Use better exponent rounding in Triton MX4 quantize kernel
Summary: As noted in [this doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr), using a ceiling round for the scale calculation does a better job of not truncating some mantissa bits. This diff switches Triton's floor rounding to ceil rounding. Note that mx4_test currently doesn't pass, as the CUDA kernel now behaves differently from Triton. Once we rebase this diff onto a similar change to the CUDA kernel, we should see exactly matching outputs again.

Reviewed By: jianyuh

Differential Revision: D59527463
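To see why the rounding mode matters, consider a single group: with floor rounding, a group maximum that sits just below a power of two can scale to a value larger than the biggest fp4 (e2m1) magnitude, so low mantissa bits are lost to clamping, whereas ceiling rounding keeps the scaled maximum in range. Below is a minimal Python sketch of just the scale selection, not the kernel itself; the fp4 limit of 6.0 and the exponent offset of 2 are the e2m1 conventions assumed here.

import math

FP4_MAX = 6.0  # largest magnitude representable in e2m1 (assumed target format)

def shared_scale(group_max: float, use_ceil: bool) -> float:
    # Round log2 of the group max, then subtract the largest exponent (2)
    # of the target datatype, mirroring the kernel's group_exp computation.
    exp = math.ceil(math.log2(group_max)) if use_ceil else math.floor(math.log2(group_max))
    return 2.0 ** (exp - 2)

group_max = 7.0
for use_ceil in (False, True):
    scale = shared_scale(group_max, use_ceil)
    scaled_max = group_max / scale
    # floor: scale = 1.0 and scaled_max = 7.0 > 6.0, so clamping discards
    # mantissa information; ceil: scale = 2.0 and scaled_max = 3.5 fits.
    print("ceil" if use_ceil else "floor", scale, scaled_max, scaled_max <= FP4_MAX)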
1 parent b903979 commit ad30615

2 files changed: +28 -23 lines


fbgemm_gpu/fbgemm_gpu/triton/quantize.py

Lines changed: 26 additions & 21 deletions
@@ -77,14 +77,15 @@ def _kernel_quantize_mx4(
     MAX_FP32_MANTISSA_BITS: tl.constexpr = 24  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 23  # type: ignore[Incompatible variable type]
     OVERFLOW_THRESHOLD: tl.constexpr = 4  # type: ignore[Incompatible variable type]
+    FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]

     # First we need to compute shared exponent.
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         # Load a block of values.
         a = tl.load(
             A + pid * stride_am + k_offset * stride_ak,
             mask=k_offset < K,
-            other=-float("inf"),
+            other=0,
         )

         # Scaling step
@@ -94,13 +95,15 @@ def _kernel_quantize_mx4(
         a_groups = tl.reshape(a, [BLOCK_SIZE // GROUP_SIZE, GROUP_SIZE])
         # Compute the shared exponent of each group.
         group_max = tl.max(tl.abs(a_groups), axis=1)
-        # Convert max to exponent via bit operations.
-        group_exp = group_max.to(tl.int32, bitcast=True) & FP32_EXP_MASK
-        group_exp = group_exp >> FP32_EXP_OFFSET
+        # Prevent infinite values in log.
+        group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
+        # Convert max to exponent via direct log computation and ceiling
+        # rounding to minimize errors.
+        group_exp = tl.ceil(tl.log2(group_max))
         # Subtract largest exponent in target datatype and remove bias.
-        group_exp = group_exp - 2 - FP32_EXP_BIAS
-        # Clamp to valid int8 range.
-        group_exp = tl.maximum(group_exp, -127)
+        group_exp = group_exp - 2
+        # Make sure exponent is in valid range.
+        group_exp = tl.clamp(group_exp, -127, 125)

         # Next we scale A in preparation for quantization.
         scale = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
@@ -113,10 +116,9 @@

         # We're done with group_exp now so we can write it out.
         # We readd fp32_exp_bias for compatibility with cuda dequant.
-        group_exp = group_exp.to(tl.int8)
         tl.store(
             shared_exp + pid * stride_exp_m + stride_exp_k * group_offset,
-            group_exp + FP32_EXP_BIAS,
+            (group_exp + FP32_EXP_BIAS).to(tl.int8),
             mask=group_offset < K // GROUP_SIZE,
         )

@@ -228,11 +230,12 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
     # We do this by finding the power of two that is closest to
     # the sqrt of the number of elements.
     num_threads = int(2 ** round(math.log2(math.sqrt(a.numel()))))
-    # Make sure that num_threads is a multiple of group_size.
-    num_threads = (num_threads // group_size) * group_size
-    if num_threads == 0:
-        num_threads = a.numel() // group_size
-    a = a.view(num_threads, -1)
+    # Make sure that the number of elements per row is a multiple of group_size.
+    K = a.numel() // num_threads
+    K = (K // group_size) * group_size
+    if K == 0:
+        K = group_size
+    a = a.view(-1, K)
     M, K = a.shape
     # If K is less than group_size, we compute a single group per row.
     if K < group_size:
@@ -382,19 +385,21 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
     # View a as 2D for simplicity.
     orig_shape = a.shape
     # Unravel packed inputs from shared exponents.
-    a = a.view(-1, (group_size // 2) + 1)
+    packed_group_size = group_size // 2
+    a = a.view(-1, packed_group_size + 1)
     packed_input = a[:, :-1]
     shared_exp = a[:, -1:]
     # Find a shape that distributes work evenly over threads.
     # We do this by finding the power of two that is closest to
     # the sqrt of the number of elements.
     num_threads = int(2 ** round(math.log2(math.sqrt(packed_input.numel()))))
-    # Make sure that num_threads is a multiple of group_size.
-    num_threads = (num_threads // group_size) * group_size
-    if num_threads == 0:
-        num_threads = packed_input.numel() // group_size
-    packed_input = packed_input.reshape(num_threads, -1)
-    shared_exp = shared_exp.reshape(num_threads, -1)
+    # Make sure that the number of elements per row is a multiple of packed group_size.
+    K = packed_input.numel() // num_threads
+    K = (K // packed_group_size) * packed_group_size
+    if K == 0:
+        K = packed_group_size
+    packed_input = packed_input.reshape(-1, K)
+    shared_exp = shared_exp.reshape(-1, K // packed_group_size)
     M, K_2 = packed_input.shape

     # Use a lookup table to convert
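The reshape logic above picks a per-row length K that holds whole groups, rather than forcing the thread count to be a multiple of group_size; the dequantize path applies the same idea with packed_group_size. A rough sketch of that arithmetic with example sizes, mirroring the diff for illustration only:

import math

def choose_row_length(numel: int, group_size: int = 32) -> int:
    # Aim for roughly sqrt(numel) rows of work, then force the per-row
    # length K down to a multiple of group_size so every row splits into
    # whole group_size-element groups (mirrors the updated view(-1, K)).
    num_threads = int(2 ** round(math.log2(math.sqrt(numel))))
    K = numel // num_threads
    K = (K // group_size) * group_size
    if K == 0:
        K = group_size
    return K

# Example: 9600 elements. The previous code viewed the tensor as 128 rows of
# 75 elements (75 is not a multiple of 32); the new logic picks K = 64, i.e.
# two complete 32-element groups per row.
print(choose_row_length(9600))  # -> 64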

fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py

Lines changed: 2 additions & 2 deletions
@@ -46,11 +46,11 @@ def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
     # Convert max into an intger exponent.
     # Note this can be more efficient by just shifting and masking exp bits.
     # We can even use those directly.
-    shared_exp = torch.floor(torch.log2(shared_exp))
+    shared_exp = torch.ceil(torch.log2(shared_exp))
     # Offset exponent by largest exponent in target datatype.
     shared_exp = shared_exp - 2
     # Restrict to range expressible as int8.
-    shared_exp = torch.clamp(shared_exp, min=-127, max=127)
+    shared_exp = torch.clamp(shared_exp, min=-127, max=125)
     # Convert exponent to scale and apply to input.
     # Need to do this calculation on cpu for accuracy.
     _shared_exp = shared_exp.cpu()
