Introduce FP8 row-based quantization #194
Merged
Motivation
Low-precision gradient synchronization has emerged from research as a viable strategy to reduce peak bandwidth between ranks and speed up training without loss in model quality. There is a substantial body of research on gradient compression during training (links to quantization with rescaling, product quantization, quantization with error feedback), with quantization based on rescaling being appealing due to its low compute and memory overhead (product quantization requires an expensive nearest-neighbor lookup in the codebook, while quantization with error feedback requires an error accumulator equal in size to the gradients).
Specifically, we would like to leverage low-precision gradient synchronization to enable cross-region training (based on the streaming DiLoCo recipe), where reducing peak cross-region traffic and worker synchronization latency is of utmost importance. Research has shown that gradient synchronization can be done in FP4 precision. Before investing in the FP4 format, we would like to show that a high-performance implementation is possible for the partially supported FP8 format and, on success, extend it with FP4 support.
Problem Definition
Gradient synchronization is done via an all-reduce collective. At the time of writing, the NCCL(x) backends do not support the all-reduce collective in FP8 precision (let alone FP4). A change has been proposed to enable FP8 support in NCCL; however, it has a significant drawback: the reduction is performed in the same precision as the input buffers, which accelerates the loss of precision.
Precision loss during reduction must be kept low to avoid (or at least reduce) model quality degradation. This brings us to the following problem definition: the all-reduce collective must communicate in low precision but reduce in high precision. More specifically, we would like to communicate gradients in FP8 (and later FP4) precision but accumulate in FP32 precision.
Rescaling Quantization
Before proceeding to the algorithm itself, let’s recap how quantization based on rescaling works. For a given tensor T and target precision format F we compute a scaling factor as
SF = F.MAX / max(abs(T))
To quantize, we multiply the original tensor by the scaling factor; to dequantize, we divide the quantized tensor (Tq) by the scaling factor.
Tq = T * SF
Td = Tq / SF
Note that the precision loss due to quantization (|Td - T|) is higher when the quantized values differ significantly in magnitude: the scaling factor must accommodate the largest value, leaving less resolution for the smaller ones. To mitigate this, quantization is performed per row (or per block; here we choose per-row quantization), with the expectation that the scaling factor choice is closer to optimal because values within a row differ less.
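As a rough illustration (not the actual implementation in this PR), a minimal per-row quantize/dequantize sketch in PyTorch, assuming `torch.float8_e4m3fn` support; the function names are hypothetical:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # F.MAX for the e4m3 format


def quantize_per_row(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor row by row: SF = F.MAX / max(abs(row)), Tq = T * SF."""
    row_max = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / row_max                                   # one scaling factor per row
    tq = (t * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return tq, scale


def dequantize_per_row(tq: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize back to FP32: Td = Tq / SF."""
    return tq.to(torch.float32) / scale
```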
Quantized Reduction
In order to reduce quantized tensors, they must first be dequantized into higher precision, reduced in higher precision, and the reduction result quantized back into lower precision. Quantized reduction requires the reducing rank to know the scaling factors of all tensors being reduced. There are two options for making this knowledge available: (1) choosing a well-known scaling factor without communication; (2) communicating dynamically chosen scaling factors.
The first option may leverage the history of scaling factors from previous steps to choose scaling factors for the current step. While it avoids communication, it may be suboptimal because the scaling factor ends up either too conservative (too small, leading to precision loss) or too aggressive (too large, leading to value clamping). To keep the scaling factor choice optimal, we choose to communicate scaling factors on every step.
The second option in turn has two possibilities: (1) communicate scaling factors alone, agree on common scaling factors, and then quantize; (2) communicate scaling factors together with the quantized values. The first approach leads to a suboptimal choice of scaling factors if values differ significantly across ranks and also requires a separate communication round. To reduce precision loss, we communicate scaling factors together with the quantized values.
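A hedged sketch of such a reduction, reusing the hypothetical `quantize_per_row`/`dequantize_per_row` helpers from above; the reducing rank receives one quantized chunk and its per-row scales from each peer:

```python
import torch


def reduce_quantized(chunks: list[torch.Tensor], scales: list[torch.Tensor]):
    """Dequantize each peer's chunk to FP32, accumulate, and requantize the sum."""
    acc = torch.zeros(chunks[0].shape, dtype=torch.float32, device=chunks[0].device)
    for tq, sf in zip(chunks, scales):
        acc += dequantize_per_row(tq, sf)   # accumulate in FP32
    return quantize_per_row(acc)            # requantize the result with fresh per-row scales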
All-Reduce with Quantization
Now we have all the pieces to put together the algorithm for all-reduce with FP8 quantization. The algorithm can be summarized as follows: each rank (1) quantizes its input tensors per row; (2) collects rows from peer ranks using all-to-all so they can be reduced locally; (3) dequantizes, reduces in higher precision, and quantizes the result; (4) collects the partial results from peer ranks using all-gather. A rough sketch of this flow is shown below.
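The sketch below is only an illustration of these four steps under the same assumptions as above (the `quantize_per_row`, `dequantize_per_row`, and `reduce_quantized` helpers are hypothetical, rows are assumed to divide evenly across ranks, and averaging, error handling, and transport details are omitted); it is not the implementation the follow-up PR will introduce:

```python
import torch
import torch.distributed as dist


def allreduce_fp8(grad: torch.Tensor, group=None) -> torch.Tensor:
    """Quantized all-reduce: all-to-all row shards, reduce locally in FP32, all-gather."""
    world_size = dist.get_world_size(group)

    # (1) quantize the FP32 gradient per row
    tq, scale = quantize_per_row(grad)

    # (2) exchange row shards so each rank owns one slice from every peer
    #     (the FP8 payload may need to be viewed as uint8 if the backend rejects float8)
    in_q = list(tq.chunk(world_size, dim=0))
    in_s = list(scale.chunk(world_size, dim=0))
    out_q = [torch.empty_like(c) for c in in_q]
    out_s = [torch.empty_like(c) for c in in_s]
    dist.all_to_all(out_q, in_q, group=group)
    dist.all_to_all(out_s, in_s, group=group)

    # (3) dequantize, reduce in FP32, requantize the partial result
    part_q, part_s = reduce_quantized(out_q, out_s)

    # (4) gather the quantized partial results (and their scales) from all peers
    gathered_q = [torch.empty_like(part_q) for _ in range(world_size)]
    gathered_s = [torch.empty_like(part_s) for _ in range(world_size)]
    dist.all_gather(gathered_q, part_q, group=group)
    dist.all_gather(gathered_s, part_s, group=group)

    # reassemble the full reduced tensor in FP32
    return torch.cat(
        [dequantize_per_row(q, s) for q, s in zip(gathered_q, gathered_s)], dim=0
    )
```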
This algorithm will be introduced in the subsequent PR.