|
13 | 13 | import logging |
14 | 14 | from dataclasses import dataclass |
15 | 15 | from enum import Enum, unique |
16 | | -from typing import cast, Dict, List, Optional |
| 16 | +from typing import Any, cast, Dict, List, Optional |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 |
|
@@ -69,6 +69,7 @@ class QCommsConfig: |
69 | 69 | fp8_quantize_dim: Optional[int] = None |
70 | 70 | fp8_quantize_dim_bwd: Optional[int] = None |
71 | 71 | fp8_bwd_uses_143: Optional[bool] = False |
| 72 | + fp8_output_dtype: Optional[SparseType] = None |
72 | 73 | mx4_quantize_dim: Optional[int] = None |
73 | 74 | mx4_quantize_dim_bwd: Optional[int] = None |
74 | 75 | mx4_rounding_mode: Optional[RoundingMode] = None |
@@ -127,37 +128,46 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode |
127 | 128 | elif qcomms_config.forward_precision == CommType.MX4: |
128 | 129 | row_dim = qcomms_config.mx4_quantize_dim |
129 | 130 | rounding_mode = qcomms_config.mx4_rounding_mode |
| 131 | + |
| 132 | + forward_kwargs: Dict[str, Any] = { |
| 133 | + "comm_precision": comm_type_to_sparse_type(qcomms_config.forward_precision), |
| 134 | + "loss_scale": qcomms_config.forward_loss_scale, |
| 135 | + "is_fwd": True, |
| 136 | + "row_dim": row_dim, |
| 137 | + "rounding_mode": rounding_mode, |
| 138 | + } |
| 139 | + # Pass fp8_output_dtype as a kwarg only when it is set, for backward compatibility (D86890315) |
| 140 | + if qcomms_config.fp8_output_dtype is not None: |
| 141 | + forward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype |
| 142 | + |
130 | 143 | codecs.forward = cast( |
131 | 144 | QuantizedCommCodec[QuantizationContext], |
132 | | - FbgemmQuantizedCommCodec( |
133 | | - comm_precision=comm_type_to_sparse_type( |
134 | | - qcomms_config.forward_precision |
135 | | - ), |
136 | | - loss_scale=qcomms_config.forward_loss_scale, |
137 | | - is_fwd=True, |
138 | | - row_dim=row_dim, |
139 | | - rounding_mode=rounding_mode, |
140 | | - ), |
| 145 | + FbgemmQuantizedCommCodec(**forward_kwargs), |
141 | 146 | ) |
142 | 147 | row_dim_bwd = None |
143 | 148 | if qcomms_config.backward_precision == CommType.FP8: |
144 | 149 | row_dim_bwd = qcomms_config.fp8_quantize_dim_bwd |
145 | 150 | elif qcomms_config.backward_precision == CommType.MX4: |
146 | 151 | row_dim_bwd = qcomms_config.mx4_quantize_dim_bwd |
| 152 | + backward_kwargs: Dict[str, Any] = { |
| 153 | + "comm_precision": comm_type_to_sparse_type( |
| 154 | + qcomms_config.backward_precision |
| 155 | + ), |
| 156 | + "loss_scale": qcomms_config.backward_loss_scale, |
| 157 | + "is_fwd": ( |
| 158 | + True if qcomms_config.fp8_bwd_uses_143 else False |
| 159 | + ), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3 |
| 160 | + # if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2 |
| 161 | + "row_dim": row_dim_bwd, |
| 162 | + "rounding_mode": rounding_mode, |
| 163 | + } |
| 164 | + # Pass fp8_output_dtype as a kwarg only when it is set, for backward compatibility (D86890315) |
| 165 | + if qcomms_config.fp8_output_dtype is not None: |
| 166 | + backward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype |
| 167 | + |
147 | 168 | codecs.backward = cast( |
148 | 169 | QuantizedCommCodec[QuantizationContext], |
149 | | - FbgemmQuantizedCommCodec( |
150 | | - comm_precision=comm_type_to_sparse_type( |
151 | | - qcomms_config.backward_precision |
152 | | - ), |
153 | | - loss_scale=qcomms_config.backward_loss_scale, |
154 | | - is_fwd=( |
155 | | - True if qcomms_config.fp8_bwd_uses_143 else False |
156 | | - ), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3 |
157 | | - # if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2 |
158 | | - row_dim=row_dim_bwd, |
159 | | - rounding_mode=rounding_mode, |
160 | | - ), |
| 170 | + FbgemmQuantizedCommCodec(**backward_kwargs), |
161 | 171 | ) |
162 | 172 | return codecs |
163 | 173 |
|
|
0 commit comments