
Introduce FP8 row-based quantization #194


Merged
dzmitry-huba merged 3 commits from huba/fp8_allreduce into main on May 15, 2025

Conversation

dzmitry-huba (Contributor)

Motivation

Low-precision gradient synchronization has emerged from research as a viable strategy to reduce peak bandwidth between ranks and speed up training without loss in model quality. There is a substantial research literature on gradient compression during training (quantization with rescaling, product quantization, quantization with error feedback), with quantization based on rescaling being appealing due to its low compute and memory overhead (product quantization requires an expensive nearest-neighbor lookup in the codebook; quantization with error feedback requires an error accumulator equal in size to the gradients).

Specifically, we would like to leverage gradient synchronization in low precision to enable cross-region training (based on the streaming DiLoCo recipe), where reducing peak cross-region traffic and worker synchronization latency is of utmost importance. Research has shown that gradient synchronization can be done in FP4 precision. Before investing in the FP4 format, we would like to prove that a high-performance implementation is possible for the partially supported FP8 format and, on success, extend it with FP4 support.

Problem Definition

Gradient synchronization is done via an all-reduce collective. At the time of writing, the NCCL(x) backends do not support the all-reduce collective in FP8 precision (not to mention FP4). There has been a change proposed to enable FP8 support in NCCL; however, it has a significant drawback: it performs the reduction in the same precision as the input buffers, leading to accelerated loss of precision.

Loss of precision during reduction must be kept low to avoid (or at least reduce) model quality degradation. This brings us to the following problem definition: the all-reduce collective must communicate in low precision but reduce in high precision. More specifically, we would like to communicate gradients in FP8 (or later FP4) precision but accumulate in FP32 precision.

Rescaling Quantization

Before proceeding to the algorithm itself, let’s recap how quantization based on rescaling works. For a given tensor T and target precision format F we compute a scaling factor as

SF = F.MAX / max(abs(T))

To quantize, we multiply the original tensor by the scaling factor; to dequantize, we divide the quantized tensor (Tq) by the scaling factor.

Tq = T * SF
Td = Tq / SF
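
As a rough illustration (assuming PyTorch with a native FP8 dtype such as torch.float8_e4m3fn; the helper names are ours, not from this PR), scalar rescaling quantization could look like:

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max  # F.MAX for this format

def quantize(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # SF = F.MAX / max(abs(T)); guard against an all-zero tensor
    sf = FP8_MAX / t.abs().max().clamp(min=1e-12)
    # Tq = T * SF, clamped to the representable range before the cast
    tq = (t * sf).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return tq, sf

def dequantize(tq: torch.Tensor, sf: torch.Tensor) -> torch.Tensor:
    # Td = Tq / SF, computed in higher precision
    return tq.to(torch.float32) / sf

t = torch.randn(16, 128)
tq, sf = quantize(t)
td = dequantize(tq, sf)
print((td - t).abs().max())  # quantization error
```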

Note that precision loss due to quantization (|Td - T|) is higher when values within the tensor differ significantly in absolute terms: the largest values force a smaller scaling factor to fit within F.MAX, which represents the smaller values more coarsely. To mitigate this precision loss, quantization is performed per row (or per block; here we choose per-row quantization), expecting the scaling factor choice to be closer to optimal because values within a row differ less.
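
The per-row variant is the same idea with one scaling factor per row, so a single large value in one row does not degrade the precision of the others (again just a sketch under the same assumptions):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def quantize_per_row(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scaling factor per row, shape (rows, 1)
    row_amax = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    sf = FP8_MAX / row_amax
    tq = (t * sf).clamp(-FP8_MAX, FP8_MAX).to(FP8)
    return tq, sf

def dequantize_per_row(tq: torch.Tensor, sf: torch.Tensor) -> torch.Tensor:
    return tq.to(torch.float32) / sf
```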

Quantized Reduction

In order to reduce quantized tensors, they must first be de-quantized into higher precision, reduced in higher precision, and the result then quantized back into lower precision. Quantized reduction requires the reducing rank to know the scaling factors of all tensors being reduced. There are two options for obtaining this knowledge: (1) choosing a well-known scaling factor without communication; (2) communicating dynamically chosen scaling factors.
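
Concretely, the de-quantize, reduce-in-FP32, re-quantize sequence might look like the following sketch (reusing the hypothetical per-row helpers above); how the reducing rank learns the scaling factors is discussed next:

```python
import torch

def reduce_quantized(chunks: list[torch.Tensor], scales: list[torch.Tensor]):
    # chunks[i] is the FP8 tensor received from rank i, scales[i] its per-row
    # scaling factors (illustrative signature, not the PR's API)
    acc = torch.zeros_like(chunks[0], dtype=torch.float32)
    for tq, sf in zip(chunks, scales):
        acc += tq.to(torch.float32) / sf  # de-quantize and accumulate in FP32
    # quantize the high-precision result back to FP8 for further communication
    return quantize_per_row(acc)
```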

The first option may leverage the history of scaling factors chosen in previous steps to pick scaling factors for the current step. While it avoids communication, it may be suboptimal: the scaling factor can be too conservative (assuming too wide a value range, which wastes precision) or too aggressive (assuming too narrow a range, which clamps values). To keep the scaling factor choice optimal, we choose to communicate scaling factors on every step.

The second option in turn has two possibilities: (1) communicate scaling factors alone, agree on common scaling factors, and then quantize; (2) communicate scaling factors together with the quantized values. The first approach leads to a suboptimal scaling factor choice if values differ significantly across ranks and also requires a separate communication step. To reduce precision loss, we communicate scaling factors together with the quantized values.
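
One possible way to ship the scaling factors together with the quantized values (purely illustrative; the actual wire format is not specified here) is to prepend each row's FP32 scale, viewed as bytes, to that row's FP8 payload:

```python
import torch

def pack(tq: torch.Tensor, sf: torch.Tensor) -> torch.Tensor:
    # 4 bytes of FP32 scale per row followed by the FP8 row bytes
    scale_bytes = sf.to(torch.float32).contiguous().view(torch.uint8)  # (rows, 4)
    row_bytes = tq.view(torch.uint8)                                   # (rows, cols)
    return torch.cat([scale_bytes, row_bytes], dim=1)                  # (rows, 4 + cols)

def unpack(buf: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    sf = buf[:, :4].contiguous().view(torch.float32)                   # (rows, 1)
    tq = buf[:, 4:].contiguous().view(torch.float8_e4m3fn)             # (rows, cols)
    return tq, sf
```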

All-Reduce with Quantization

Now that we have all the pieces, we can put together the algorithm for all-reduce with FP8 quantization. It can be summarized as follows: each rank (1) quantizes its input tensors per row; (2) exchanges rows with peer ranks using all-to-all so that each rank can reduce its share locally; (3) de-quantizes, reduces in higher precision, and quantizes the result; (4) collects the partial results from peer ranks using all-gather.

This algorithm will be introduced in the subsequent PR.
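
Since the actual implementation arrives in that follow-up PR, the following is only a hypothetical sketch of how the four steps might be wired together with torch.distributed collectives (all_to_all_single plus all_gather_into_tensor), reusing the illustrative quantize_per_row/pack/unpack helpers above. It assumes the row count is divisible by the world size and is not the code being introduced here:

```python
import torch
import torch.distributed as dist

def fp8_allreduce_sketch(grad: torch.Tensor) -> torch.Tensor:
    world = dist.get_world_size()
    rows, cols = grad.shape  # assumes rows % world == 0

    # (1) quantize per row and attach the per-row scales to the payload
    buf = pack(*quantize_per_row(grad))

    # (2) all-to-all: rank r receives every peer's r-th block of rows
    recv = torch.empty_like(buf)
    dist.all_to_all_single(recv, buf)

    # (3) de-quantize each peer's block, reduce in FP32, re-quantize the result
    acc = torch.zeros(rows // world, cols, dtype=torch.float32, device=grad.device)
    for shard in recv.chunk(world, dim=0):
        stq, ssf = unpack(shard)
        acc += stq.to(torch.float32) / ssf
    reduced = pack(*quantize_per_row(acc))

    # (4) all-gather the reduced blocks so every rank holds the full result
    out = torch.empty_like(buf)
    dist.all_gather_into_tensor(out, reduced)
    otq, osf = unpack(out)
    return otq.to(torch.float32) / osf
```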

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 15, 2025
@dzmitry-huba requested review from d4l3k and H-Huang on May 15, 2025 21:36
@dzmitry-huba marked this pull request as ready for review on May 15, 2025 21:36
@H-Huang (Member) left a comment

Looks good! Looking forward to the allreduce to be added

@dzmitry-huba merged commit b84c5a6 into main on May 15, 2025
7 of 8 checks passed
@d4l3k deleted the huba/fp8_allreduce branch on May 15, 2025 23:13