fbgemm_gpu/fbgemm_gpu/quantize_comm.py (17 changes: 15 additions & 2 deletions)
@@ -123,6 +123,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
+    fp8_output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -137,8 +138,14 @@ def _dequantize_tensor(
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+            # use provided fp8_output_dtype or default to FP32 (0)
+            output_dtype_int = (
+                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
+            )
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
-                quantized_tensor_2d, is_fwd
+                quantized_tensor_2d,
+                is_fwd,
+                output_dtype_int,
             )
             return dequant_tensor.view(-1)
         else:
@@ -168,6 +175,7 @@ def __init__(
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
         rounding_mode: Optional[RoundingMode] = None,
+        fp8_output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -185,6 +193,7 @@ def __init__(
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
         self._rounding_mode: Optional[RoundingMode] = rounding_mode
+        self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
             self._rounding_mode = (
@@ -216,7 +225,11 @@ def decode(
             f"## decoder {self._comm_precision} {self._loss_scale} ##"
         ):
             dequantized_tensor = _dequantize_tensor(
-                input_tensor, self._comm_precision, ctx, self._is_fwd
+                input_tensor,
+                self._comm_precision,
+                ctx,
+                self._is_fwd,
+                fp8_output_dtype=self._fp8_output_dtype,
             )
         return dequantized_tensor

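For context, a minimal usage sketch of the new knob end to end. This is not part of the diff: it assumes the class being modified is QuantizedCommCodec with the usual encode/decode/create_context codec interface, that SparseType is importable from fbgemm_gpu.split_embedding_configs, and that FP8 is configured rowwise; any name not visible in the hunks above is an assumption, and the FP8 rowwise ops generally need a GPU build of fbgemm_gpu.

# Usage sketch (assumptions labeled; not part of this PR).
import torch

from fbgemm_gpu.quantize_comm import QuantizedCommCodec  # assumed class name
from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path

# Rowwise FP8 over the wire, dequantized into BF16 instead of the FP32 default.
codec = QuantizedCommCodec(
    comm_precision=SparseType.FP8,
    row_dim=128,  # hypothetical rowwise chunk size
    fp8_output_dtype=SparseType.BF16,  # new argument added by this diff
)

ctx = codec.create_context()  # assumed helper from the codec interface
grad = torch.randn(4 * 128)
encoded = codec.encode(grad, ctx)
decoded = codec.decode(encoded, ctx)  # decode() now forwards fp8_output_dtype

When fp8_output_dtype is left unset, _dequantize_tensor falls back to 0 (FP32), so the default behavior matches the previous hard-coded FP8RowwiseQuantizedToFloat call.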