
Commit 17248b7

Armand Sauzay authored and facebook-github-bot committed
Enable specifying output dtype for fp8 quantized communication (#3568)
Summary:
X-link: pytorch/FBGEMM#5154
X-link: facebookresearch/FBGEMM#2154

Adds an fp8_output_dtype parameter to the qcomms config, allowing fp8-quantized communication to be dequantized to float formats other than FP32 only.

Reviewed By: spcyppt

Differential Revision: D86890315
1 parent 217889e commit 17248b7
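
For context (not part of the commit message), a minimal usage sketch of the new field. It assumes the TorchRec qcomms API shown in the diff below (QCommsConfig, CommType, get_qcomm_codecs) and SparseType from fbgemm_gpu; the FP8/FP16 precisions and the BF16 output dtype are only illustrative choices:

from fbgemm_gpu.split_embedding_configs import SparseType

from torchrec.distributed.fbgemm_qcomm_codec import (
    CommType,
    get_qcomm_codecs,
    QCommsConfig,
)

# Quantize forward comms to fp8 and ask the codec to dequantize to BF16
# instead of the previous FP32-only behavior.
qcomms_config = QCommsConfig(
    forward_precision=CommType.FP8,
    backward_precision=CommType.FP16,
    fp8_output_dtype=SparseType.BF16,  # new field introduced by this commit
)
codecs = get_qcomm_codecs(qcomms_config)

Whether the fp8 codec actually honors fp8_output_dtype also depends on the FBGEMM version in use (see the cross-linked FBGEMM change above).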

File tree

1 file changed: +32 −22 lines


torchrec/distributed/fbgemm_qcomm_codec.py

Lines changed: 32 additions & 22 deletions
@@ -13,7 +13,7 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum, unique
-from typing import cast, Dict, List, Optional
+from typing import Any, cast, Dict, List, Optional
 
 import torch
 
@@ -69,6 +69,7 @@ class QCommsConfig:
     fp8_quantize_dim: Optional[int] = None
     fp8_quantize_dim_bwd: Optional[int] = None
     fp8_bwd_uses_143: Optional[bool] = False
+    fp8_output_dtype: Optional[SparseType] = None
     mx4_quantize_dim: Optional[int] = None
     mx4_quantize_dim_bwd: Optional[int] = None
     mx4_rounding_mode: Optional[RoundingMode] = None
@@ -127,37 +128,46 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
     elif qcomms_config.forward_precision == CommType.MX4:
         row_dim = qcomms_config.mx4_quantize_dim
         rounding_mode = qcomms_config.mx4_rounding_mode
+
+    forward_kwargs: Dict[str, Any] = {
+        "comm_precision": comm_type_to_sparse_type(qcomms_config.forward_precision),
+        "loss_scale": qcomms_config.forward_loss_scale,
+        "is_fwd": True,
+        "row_dim": row_dim,
+        "rounding_mode": rounding_mode,
+    }
+    # kwargs approach for bwd compatibility (D86890315)
+    if qcomms_config.fp8_output_dtype is not None:
+        forward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype
+
     codecs.forward = cast(
         QuantizedCommCodec[QuantizationContext],
-        FbgemmQuantizedCommCodec(
-            comm_precision=comm_type_to_sparse_type(
-                qcomms_config.forward_precision
-            ),
-            loss_scale=qcomms_config.forward_loss_scale,
-            is_fwd=True,
-            row_dim=row_dim,
-            rounding_mode=rounding_mode,
-        ),
+        FbgemmQuantizedCommCodec(**forward_kwargs),
     )
     row_dim_bwd = None
     if qcomms_config.backward_precision == CommType.FP8:
         row_dim_bwd = qcomms_config.fp8_quantize_dim_bwd
     elif qcomms_config.backward_precision == CommType.MX4:
         row_dim_bwd = qcomms_config.mx4_quantize_dim_bwd
+    backward_kwargs: Dict[str, Any] = {
+        "comm_precision": comm_type_to_sparse_type(
+            qcomms_config.backward_precision
+        ),
+        "loss_scale": qcomms_config.backward_loss_scale,
+        "is_fwd": (
+            True if qcomms_config.fp8_bwd_uses_143 else False
+        ),  # if fp8_bwd_uses_143 is True, bwd will use 1-4-3
+        # if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
+        "row_dim": row_dim_bwd,
+        "rounding_mode": rounding_mode,
+    }
+    # kwargs approach for bwd compatibility (D86890315)
+    if qcomms_config.fp8_output_dtype is not None:
+        backward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype
+
     codecs.backward = cast(
         QuantizedCommCodec[QuantizationContext],
-        FbgemmQuantizedCommCodec(
-            comm_precision=comm_type_to_sparse_type(
-                qcomms_config.backward_precision
-            ),
-            loss_scale=qcomms_config.backward_loss_scale,
-            is_fwd=(
-                True if qcomms_config.fp8_bwd_uses_143 else False
-            ),  # if fp8_bwd_uses_143 is True, bwd will use 1-4-3
-            # if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
-            row_dim=row_dim_bwd,
-            rounding_mode=rounding_mode,
-        ),
+        FbgemmQuantizedCommCodec(**backward_kwargs),
     )
     return codecs
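
The kwargs-based construction above matches the in-diff comment "kwargs approach for bwd compatibility": fp8_output_dtype is forwarded only when it is set, so FbgemmQuantizedCommCodec signatures that do not yet accept the argument keep working. A minimal standalone sketch of that pattern, using a hypothetical Codec class rather than TorchRec/FBGEMM code:

from typing import Any, Dict, Optional


class Codec:
    """Hypothetical stand-in for an older codec that lacks fp8_output_dtype."""

    def __init__(self, comm_precision: str, loss_scale: Optional[float] = None) -> None:
        self.comm_precision = comm_precision
        self.loss_scale = loss_scale


def build_codec(comm_precision: str, fp8_output_dtype: Optional[str] = None) -> Codec:
    # Collect only the arguments every codec version understands ...
    kwargs: Dict[str, Any] = {"comm_precision": comm_precision}
    # ... and forward the new argument only when the caller actually set it,
    # so the older Codec.__init__ above is never handed an unknown kwarg.
    if fp8_output_dtype is not None:
        kwargs["fp8_output_dtype"] = fp8_output_dtype
    return Codec(**kwargs)


codec = build_codec("FP8")  # works against the older signature unchanged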
