Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions torchrec/distributed/fbgemm_qcomm_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import logging
from dataclasses import dataclass
from enum import Enum, unique
from typing import cast, Dict, List, Optional
from typing import Any, cast, Dict, List, Optional

import torch

Expand Down Expand Up @@ -69,6 +69,7 @@ class QCommsConfig:
fp8_quantize_dim: Optional[int] = None
fp8_quantize_dim_bwd: Optional[int] = None
fp8_bwd_uses_143: Optional[bool] = False
fp8_output_dtype: Optional[SparseType] = None
mx4_quantize_dim: Optional[int] = None
mx4_quantize_dim_bwd: Optional[int] = None
mx4_rounding_mode: Optional[RoundingMode] = None
Expand Down Expand Up @@ -127,37 +128,46 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
elif qcomms_config.forward_precision == CommType.MX4:
row_dim = qcomms_config.mx4_quantize_dim
rounding_mode = qcomms_config.mx4_rounding_mode

forward_kwargs: Dict[str, Any] = {
"comm_precision": comm_type_to_sparse_type(qcomms_config.forward_precision),
"loss_scale": qcomms_config.forward_loss_scale,
"is_fwd": True,
"row_dim": row_dim,
"rounding_mode": rounding_mode,
}
# kwargs approach for bwd compatibility (D86890315)
if qcomms_config.fp8_output_dtype is not None:
forward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype

codecs.forward = cast(
QuantizedCommCodec[QuantizationContext],
FbgemmQuantizedCommCodec(
comm_precision=comm_type_to_sparse_type(
qcomms_config.forward_precision
),
loss_scale=qcomms_config.forward_loss_scale,
is_fwd=True,
row_dim=row_dim,
rounding_mode=rounding_mode,
),
FbgemmQuantizedCommCodec(**forward_kwargs),
)
row_dim_bwd = None
if qcomms_config.backward_precision == CommType.FP8:
row_dim_bwd = qcomms_config.fp8_quantize_dim_bwd
elif qcomms_config.backward_precision == CommType.MX4:
row_dim_bwd = qcomms_config.mx4_quantize_dim_bwd
backward_kwargs: Dict[str, Any] = {
"comm_precision": comm_type_to_sparse_type(
qcomms_config.backward_precision
),
"loss_scale": qcomms_config.backward_loss_scale,
"is_fwd": (
True if qcomms_config.fp8_bwd_uses_143 else False
), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3
# if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
"row_dim": row_dim_bwd,
"rounding_mode": rounding_mode,
}
# kwargs approach for bwd compatibility (D86890315)
if qcomms_config.fp8_output_dtype is not None:
backward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype

codecs.backward = cast(
QuantizedCommCodec[QuantizationContext],
FbgemmQuantizedCommCodec(
comm_precision=comm_type_to_sparse_type(
qcomms_config.backward_precision
),
loss_scale=qcomms_config.backward_loss_scale,
is_fwd=(
True if qcomms_config.fp8_bwd_uses_143 else False
), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3
# if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
row_dim=row_dim_bwd,
rounding_mode=rounding_mode,
),
FbgemmQuantizedCommCodec(**backward_kwargs),
)
return codecs

Expand Down
Loading