
Commit c9dddb2

jwfromm authored and facebook-github-bot committed
Switch to Triton FP8 Quantization in EMU1.6
Summary: For some reason, the cuda `quantize_fp8_per_row` kernel is very slow in EMU. Switching it to the functionally equivalent triton kernel yields excellent speedups from FP8. For eager mode, I'm seeing a 20% e2e speedup and still getting proper outputs.

Eager:
BF16: 19702.10ms
FP8 Triton Quant: 16466.97ms

Compiled:
FP8 Native Quant: 14605.18ms
FP8 Triton Quant: 16043.92ms
BF16: 18030.98ms

We see that quantizing in native pytorch helps quite a bit when torch.compile is used. I added the option to choose which quantization function is used, defaulting to triton when torch.compile is off and native torch when torch.compile is on. This gives us the best performance in either case.

Reviewed By: jiawenliu64

Differential Revision: D58167756
1 parent 003112c commit c9dddb2
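
The dispatch described in the summary (triton quantization kernel in eager mode, native torch quantization under torch.compile) could look roughly like the sketch below. This is an illustration only, not the code in this commit: `quantize_fp8_row_torch`, `choose_quantize_fn`, and the `use_compile` flag are hypothetical names, and the import path for the triton `quantize_fp8_row` is assumed from this repo's file layout.

import torch

def quantize_fp8_row_torch(a: torch.Tensor):
    # Hypothetical native-PyTorch per-row FP8 quantization; pure tensor ops, so it
    # fuses well under torch.compile.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per row so the largest |value| in each row maps to the FP8 max.
    row_max = a.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    scale = fp8_max / row_max
    a_fp8 = (a.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Return the inverse scale so callers can undo quantization:
    # a ~= a_fp8.to(torch.bfloat16) * a_scale[:, None]
    return a_fp8, scale.reciprocal().squeeze(-1)

def choose_quantize_fn(use_compile: bool):
    # Default to native torch when torch.compile is on and to the triton kernel when
    # it is off, matching the benchmark numbers above.
    if use_compile:
        return quantize_fp8_row_torch
    # Import path assumed from this repo's layout.
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
    return quantize_fp8_row

A caller would pick the function once, e.g. quant_fn = choose_quantize_fn(use_compile=model_is_compiled), and then call quant_fn(x) wherever per-row FP8 quantization is needed.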

2 files changed: +110 −66 lines changed

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ def _test_quantize_fp8_row(
         )

         # Undo scaling.
-        a_torch = a_fp8.base.to(torch.bfloat16)
+        a_torch = a_fp8.to(torch.bfloat16)
         a_torch *= a_scale[:, None]

         self.assertTrue(

@@ -110,7 +110,7 @@ def _test_quantize_fp8_block(

         a_fp8, a_scale = quantize_fp8_block(a, BLOCK_M, BLOCK_K, scale_ub=scale_ub)

-        a_torch = a_fp8.base.to(torch.bfloat16)
+        a_torch = a_fp8.to(torch.bfloat16)

         # Undo scaling.
         for i in range(0, M, BLOCK_M):
