Skip to content

Commit f1f413b

Browse files
Armand Sauzay authored and facebook-github-bot committed
Enable specifying output dtype for fp8 quantized communication
Summary:
X-link: pytorch/FBGEMM#5154
X-link: facebookresearch/FBGEMM#2154
Adds an `fp8_output_dtype` parameter to the qcomms config, allowing fp8 to dequantize into different float formats instead of only FP32.
Reviewed By: spcyppt
Differential Revision: D86890315
1 parent 32e5431 commit f1f413b

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchrec/distributed/fbgemm_qcomm_codec.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class QCommsConfig:
6969
fp8_quantize_dim: Optional[int] = None
7070
fp8_quantize_dim_bwd: Optional[int] = None
7171
fp8_bwd_uses_143: Optional[bool] = False
72+
fp8_output_dtype: Optional[SparseType] = None
7273
mx4_quantize_dim: Optional[int] = None
7374
mx4_quantize_dim_bwd: Optional[int] = None
7475
mx4_rounding_mode: Optional[RoundingMode] = None
@@ -137,6 +138,7 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
137138
is_fwd=True,
138139
row_dim=row_dim,
139140
rounding_mode=rounding_mode,
141+
fp8_output_dtype=qcomms_config.fp8_output_dtype,
140142
),
141143
)
142144
row_dim_bwd = None
@@ -157,6 +159,7 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
157159
# if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
158160
row_dim=row_dim_bwd,
159161
rounding_mode=rounding_mode,
162+
fp8_output_dtype=qcomms_config.fp8_output_dtype,
160163
),
161164
)
162165
return codecs

0 commit comments

Comments
 (0)