This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[RFC] Float8 Inference #314

Closed
@drisspg

Description


RFC: Float8 Inference

  • status: draft

Objective

We want to provide an easy mechanism for using FP8 in inference that reduces memory usage and improves performance on hardware with native FP8 support. The API should require minimal model rewrites. It should also be configurable, offering multiple levels of scaling granularity, each with its own accuracy/performance trade-off. The solution should be composable with other inference components in the PyTorch ecosystem:

  • Export
  • AOTI
  • Dynamic Shapes

This solution targets server-side GPU inference. Edge and CPU inference are not currently in scope.

Background

Float8 inference can reduce memory usage and improve computational efficiency. By using FP8 instead of higher-precision formats, we can achieve significant speedups and memory savings with minimal loss in accuracy. The memory savings are specific to float8 inference rather than float8 training: in training, a high-precision copy of the weights must be kept for weight updates, whereas in inference the weights are static and can be stored directly in FP8.
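As a rough illustration of the memory argument, the arithmetic below estimates the storage for a single Linear weight. It assumes 2 bytes per element for BF16, 1 byte per element for FP8, and one fp32 scale per tensor (TensorWise) or per output row (AxisWise). The 4096 x 4096 layer size and the helper `linear_weight_bytes` are hypothetical.

```python
def linear_weight_bytes(in_features: int, out_features: int,
                        bytes_per_element: int, num_scales: int = 0,
                        scale_bytes: int = 4) -> int:
    """Storage for a Linear weight plus any fp32 scales (hypothetical helper)."""
    return in_features * out_features * bytes_per_element + num_scales * scale_bytes


in_features, out_features = 4096, 4096  # hypothetical layer size

bf16_bytes = linear_weight_bytes(in_features, out_features, bytes_per_element=2)
# TensorWise: a single scale for the whole weight.
fp8_tensorwise_bytes = linear_weight_bytes(in_features, out_features, 1, num_scales=1)
# AxisWise: one scale per output row of the weight.
fp8_axiswise_bytes = linear_weight_bytes(in_features, out_features, 1, num_scales=out_features)

print(f"BF16: {bf16_bytes / 2**20:.2f} MiB")
print(f"FP8 TensorWise: {fp8_tensorwise_bytes / 2**20:.2f} MiB")
print(f"FP8 AxisWise: {fp8_axiswise_bytes / 2**20:.2f} MiB")
```

At this size the scale overhead is negligible, so the weight footprint is roughly halved in both FP8 variants.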

Proposal

Float8InferenceLinear Module

We propose a new Float8InferenceLinear module that extends nn.Linear with Float8 quantization capabilities:

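```python
from __future__ import annotations  # lets the annotations name types defined elsewhere
from typing import Optional

import torch

# QuantConfig and ScaledMMConfig are defined elsewhere in float8_experimental;
# ScalingGranularity is part of the proposed extension described below.
class Float8InferenceLinear(torch.nn.Linear):
    def __init__(
        self,
        quant_config: QuantConfig,
        forward_config: ScaledMMConfig,
        scaling_granularity: Optional[ScalingGranularity],
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        # ... implementation ...
        ...
```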

This module quantizes weights and activations according to the provided configuration; it landed in #287. It is designed to replace a pre-trained nn.Linear module in an existing model and statically convert its weight to FP8, using the E4M3 format by default.
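The static weight conversion can be illustrated without PyTorch by simulating the E4M3 cast in plain Python. This is a sketch, not the module's code: it assumes a `scale = E4M3_MAX / amax` convention (FP8 value = weight * scale), ignores NaN/inf handling, and the helpers `round_to_e4m3`, `quantize_tensorwise`, and `dequantize` are hypothetical.

```python
import math

E4M3_MAX = 448.0  # largest finite value of float8 E4M3 (the "fn" variant)


def round_to_e4m3(x: float) -> float:
    """Simulate a saturating, round-to-nearest-even cast to float8 E4M3."""
    if x == 0.0:
        return 0.0
    mag = abs(x)
    _, exp = math.frexp(mag)      # mag = m * 2**exp with 0.5 <= m < 1
    exp = max(exp - 1, -6)        # unbiased exponent; -6 is the subnormal floor
    step = 2.0 ** (exp - 3)       # 3 mantissa bits -> 8 steps per power of two
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)


def quantize_tensorwise(weight):
    """Return (fp8 values, scale) with one scale for the whole weight.

    Assumed convention: scale = E4M3_MAX / amax, fp8 = cast(w * scale),
    and the weight is recovered as fp8 / scale.
    """
    amax = max(abs(v) for row in weight for v in row)
    scale = E4M3_MAX / max(amax, 1e-12)  # guard against an all-zero weight
    return [[round_to_e4m3(v * scale) for v in row] for row in weight], scale


def dequantize(fp8_weight, scale):
    return [[v / scale for v in row] for row in fp8_weight]


weight = [[0.5, -1.0], [0.3, 2.0]]
fp8_weight, scale = quantize_tensorwise(weight)
print(scale)                          # amax is 2.0, so the scale is 448 / 2
print(dequantize(fp8_weight, scale))  # 0.3 comes back as ~0.286
```

Values that land exactly on the E4M3 grid round-trip unchanged, while 0.3 picks up a few percent of error, which is the quantization error that finer scaling granularities aim to reduce.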

Float8InferenceLinear is configured through the QuantConfig class, which encapsulates the quantization settings:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch


class ActivationCasting(Enum):
    """Types of quantization to perform on the activations

    WEIGHT_ONLY: Only quantize the weight, no activation casting, weight will be dequantized in the forward pass
    STATIC: Activation is quantized in the forward pass with a static scale provided at model initialization
    DYNAMIC: Activation is quantized during forward pass with a dynamic scale calculated from the input activation
    """

    # TODO: A better name would be NONE, we should unify this with torchao
    WEIGHT_ONLY = auto()
    DYNAMIC = auto()
    STATIC = auto()


# Defined after ActivationCasting: the field annotation is evaluated when the class is created.
@dataclass(frozen=True)
class QuantConfig:
    activation_casting: ActivationCasting
    static_quantization_scale: Optional[torch.Tensor] = None
```

The main configuration options are captured in the ActivationCasting enum:

  • dynamic casting: Casts the input activation of nn.Linear to a scaled FP8 type during the forward pass and runs _scaled_mm.
  • static quantization: Uses a pre-calculated activation scale to cast the activation to E4M3. Currently, we don't specify how this scale should be calculated.
  • weight-only: Stores the weight in scaled FP8, but during inference dequantizes it and runs the matmul in the activation's precision.
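The three modes differ only in where the activation scale comes from and whether the matmul runs on FP8 operands. The framework-free sketch below assumes TensorWise scales, a simulated E4M3 cast, and the `scale = E4M3_MAX / amax` convention; the mode strings stand in for the `ActivationCasting` members, and all helper names are hypothetical.

```python
import math

E4M3_MAX = 448.0  # largest finite float8 E4M3 value


def round_to_e4m3(x: float) -> float:
    """Simulated saturating, round-to-nearest-even cast to float8 E4M3."""
    if x == 0.0:
        return 0.0
    _, exp = math.frexp(abs(x))
    step = 2.0 ** (max(exp - 1, -6) - 3)
    return math.copysign(min(round(abs(x) / step) * step, E4M3_MAX), x)


def amax_scale(values):
    return E4M3_MAX / max(max(abs(v) for v in values), 1e-12)


def linear_forward(x, w_fp8, w_scale, mode, static_x_scale=None):
    """y[i] = sum_k W[i][k] * x[k] for one activation vector x.

    w_fp8 holds FP8-representable values; the weight is w_fp8 / w_scale.
    `mode` stands in for ActivationCasting: "weight_only", "static", "dynamic".
    """
    if mode == "weight_only":
        # Dequantize the weight and run the matmul in the activation's precision.
        return [sum(w / w_scale * xv for w, xv in zip(row, x)) for row in w_fp8]
    if mode == "static":
        x_scale = static_x_scale      # precomputed, e.g. by a calibration pass
    elif mode == "dynamic":
        x_scale = amax_scale(x)       # computed from this activation at run time
    else:
        raise ValueError(f"unknown mode: {mode}")
    x_fp8 = [round_to_e4m3(v * x_scale) for v in x]
    # Multiply FP8 operands, accumulate, then undo both scales
    # (conceptually what _scaled_mm computes).
    return [sum(w * xv for w, xv in zip(row, x_fp8)) / (w_scale * x_scale)
            for row in w_fp8]


w_scale = 224.0  # the weight below has amax 2.0
w_fp8 = [[round_to_e4m3(v * w_scale) for v in row] for row in [[0.5, -1.0], [0.25, 2.0]]]
x = [1.0, 0.5]
for mode in ("weight_only", "static", "dynamic"):
    print(mode, linear_forward(x, w_fp8, w_scale, mode, static_x_scale=224.0))
```

In this toy case every value is exactly representable, so all three modes return the same result. With real activations, STATIC and DYNAMIC add activation rounding error that WEIGHT_ONLY avoids, which is consistent with WEIGHT_ONLY having the highest SQNR in the benchmark tables.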

Top-level API

We propose a top-level API for quantizing models:

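```python
from __future__ import annotations  # QuantConfig/ScalingGranularity are defined elsewhere
from typing import List, Optional
from torch import nn

def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    use_fast_accum: bool = True,
    scaling_granularity: Optional[ScalingGranularity] = None,  # Part of future proposal
) -> Optional[nn.Module]:
    # ... implementation ...
    ...
```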

This function allows users to easily convert their models to use Float8 inference.

An example of using this on a Hugging Face model can be found in this PR in TorchAO.
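One plausible way for a pass like quantize_to_float8 to work is to walk `module.named_modules()` and swap each eligible `nn.Linear`. The selection step is sketched below, with (FQN, class name) pairs standing in for the module tree. The helper name, the example FQNs, and exact-match semantics for `skip_fqn_list` are assumptions.

```python
from typing import Iterable, List, Optional, Tuple


def linear_fqns_to_convert(
    named_modules: Iterable[Tuple[str, str]],
    skip_fqn_list: Optional[List[str]] = None,
) -> List[str]:
    """Pick the Linear layers a quantize_to_float8-style pass would swap.

    named_modules mimics nn.Module.named_modules() as (fqn, class name) pairs.
    Assumption: skip_fqn_list entries are matched exactly against FQNs.
    """
    skip = set(skip_fqn_list or [])
    return [fqn for fqn, cls in named_modules if cls == "Linear" and fqn not in skip]


# Hypothetical module tree of a small transformer.
modules = [
    ("", "Transformer"),
    ("layers.0.attn.wqkv", "Linear"),
    ("layers.0.attn.wo", "Linear"),
    ("layers.0.norm", "RMSNorm"),
    ("output", "Linear"),
]
print(linear_fqns_to_convert(modules, skip_fqn_list=["output"]))
```

Skipping by FQN lets a user keep, for example, a sensitive output projection in BF16 while converting the rest of the model.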

Proposed Extensions

Scaling Granularity

Currently, only TensorWise scaling is supported: we compute max(abs(tensor)) and use this value to derive the Float8Tensor scale. However, because activations contain outlier values, a single scale can lead to large quantization error. In addition, computing a global reduction across the entire activation tensor can be relatively slow.

Therefore, we want to add the option to specify different types of scaling granularities.

The scaling_granularity parameter determines how scales are computed:

  • TensorWise: A single scale is computed for the entire tensor.
  • AxisWise: Scales are computed along a specified axis of the tensor.

We recently added AxisWise scaling support to _scaled_mm in PyTorch PR #128989. In addition, I have a PR stack (#305) showing how AxisWise scaling can be implemented in Float8Experimental.
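The effect of an outlier can be illustrated on a toy 2 x 4 activation in which one token (row) contains a large value. With a single TensorWise scale, the small row lands near the bottom of E4M3's range, where few distinct values are available; AxisWise gives each row its own scale. This is a framework-free sketch with a simulated E4M3 cast; the values, the helper names, and the `E4M3_MAX / amax` scale convention are assumptions.

```python
import math

E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Simulated saturating, round-to-nearest-even cast to float8 E4M3."""
    if x == 0.0:
        return 0.0
    _, exp = math.frexp(abs(x))
    step = 2.0 ** (max(exp - 1, -6) - 3)
    return math.copysign(min(round(abs(x) / step) * step, E4M3_MAX), x)


def fake_quant(values, scale):
    """Quantize to simulated E4M3 and immediately dequantize."""
    return [round_to_e4m3(v * scale) / scale for v in values]


def amax_scale(values):
    return E4M3_MAX / max(max(abs(v) for v in values), 1e-12)


def tensorwise(matrix):
    scale = amax_scale([v for row in matrix for v in row])  # one global reduction
    return [fake_quant(row, scale) for row in matrix]


def axiswise(matrix):
    return [fake_quant(row, amax_scale(row)) for row in matrix]  # one scale per row


def sqnr_db(ref, approx):
    noise = math.sqrt(sum((r - a) ** 2 for r, a in zip(ref, approx)))
    signal = math.sqrt(sum(r * r for r in ref))
    return math.inf if noise == 0 else 20 * math.log10(signal / noise)


# Row 0: typical small activations; row 1: a token with a large outlier.
activations = [[0.011, -0.023, 0.037, 0.052], [1000.0, -3.0, 5.0, 2.0]]
tw, aw = tensorwise(activations), axiswise(activations)
print(f"row 0 SQNR  TensorWise: {sqnr_db(activations[0], tw[0]):.1f} dB"
      f"  AxisWise: {sqnr_db(activations[0], aw[0]):.1f} dB")
```

The outlier row is quantized identically either way; AxisWise only changes the result for rows that would otherwise share the outlier's scale. Per-row scales also replace the global reduction with independent row reductions.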

We would like to continue generalizing the scaling granularity to:

  • GroupWise: Similar to AxisWise, but instead of one scale per axis, there are multiple scales along each axis, one per group of elements.
  • BlockWise: One scale per 2D tile of the activation and weight. All other forms can be seen as special cases of this.
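The claim that the other forms are special cases of BlockWise can be made concrete: a single M x K tile gives TensorWise, 1 x K tiles give AxisWise (one scale per row), and 1 x g tiles give GroupWise. The sketch assumes groups run along the K (reduction) dimension and the `E4M3_MAX / amax` scale convention; the function names are hypothetical.

```python
E4M3_MAX = 448.0


def _scale(values):
    return E4M3_MAX / max(max(abs(v) for v in values), 1e-12)


def blockwise_scales(matrix, block_rows, block_cols):
    """One scale per (block_rows x block_cols) tile of an M x K matrix."""
    m, k = len(matrix), len(matrix[0])
    if m % block_rows or k % block_cols:
        raise ValueError("matrix shape must be divisible by the block shape")
    return [
        [
            _scale([matrix[r][c]
                    for r in range(r0, r0 + block_rows)
                    for c in range(c0, c0 + block_cols)])
            for c0 in range(0, k, block_cols)
        ]
        for r0 in range(0, m, block_rows)
    ]


def tensorwise_scales(matrix):
    return blockwise_scales(matrix, len(matrix), len(matrix[0]))  # one M x K tile


def axiswise_scales(matrix):
    return blockwise_scales(matrix, 1, len(matrix[0]))            # one 1 x K tile per row


def groupwise_scales(matrix, group_size):
    return blockwise_scales(matrix, 1, group_size)                # 1 x g tiles


w = [[1.0, 2.0, 4.0, 8.0],
     [0.5, 0.25, 16.0, 32.0]]
print(tensorwise_scales(w))       # [[14.0]]
print(axiswise_scales(w))         # [[56.0], [14.0]]
print(groupwise_scales(w, 2))     # [[224.0, 56.0], [896.0, 14.0]]
print(blockwise_scales(w, 2, 2))  # [[224.0, 14.0]]
```

Finer tiles track local magnitudes more closely at the cost of storing and applying more scales, which is why kernel support (see Limitations) determines which granularities are practical.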

Design Details

Tensor Subclass Usage

The implementation uses Float8Tensor to encapsulate the scaling and to dispatch to _scaled_mm instead of torch.mm. This is not the only possible design. Because inference has no autograd constraint requiring gradients in the backward pass to match the dtype of the tensor in the forward, we are free to desugar the Float8Tensor into its constituents (FP8 data and scale), store them on the module, and use them directly in the forward. Using the tensor subclass lets us reuse components between training and inference, but it has downsides:

  • For people less familiar with tensor subclasses, this indirection can be confusing.
  • It makes supporting torch.export more challenging.

Performance

Compile

As with the rest of this project, we rely heavily on the compile stack to generate efficient, fused casting code. Even without it, we see some performance gains on heavily compute-bound models, but in general torch.compile is required for competitive performance.

Export

Currently, torch.export + AOTI cannot be run with the publicly available export APIs. However, #295 demonstrates that it is possible. The export team plans to make exporting nn.Modules with tensor-subclass weights available in the public API this half.

Limitations and Future Work

Extend ScalingGranularity
  • Limited Dtype Support: Currently, AxisWise scaling is only supported for bfloat16.
  • Kernel Support: _scaled_mm only supports TensorWise and AxisWise scaling; work is needed to extend to other granularities.
  • Existing Kernel Improvement: In early experiments, the AxisWise kernel is less performant than the TensorWise kernel. Investigation is needed here.
  • Activation Folding: The current implementation doesn't support folding the leading dimensions of activations, which may be necessary for certain model architectures. Because scales must be calculated before the Float8Tensor is constructed, we cannot rely on the Linear decomposition to do the unfolding for us.
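Folding here means viewing an activation of shape (..., K) as a 2D (rows, K) matrix so that AxisWise scales can be computed per row (per token) before the Float8Tensor is built, then restoring the leading dimensions on the output. This is a shape-only sketch; the helper names and example shapes are hypothetical.

```python
from math import prod
from typing import Tuple


def fold_leading_dims(shape: Tuple[int, ...]) -> Tuple[int, int]:
    """View an N-D activation (..., K) as a 2D (rows, K) matrix."""
    return prod(shape[:-1]), shape[-1]


def unfold_output(act_shape: Tuple[int, ...], out_features: int) -> Tuple[int, ...]:
    """Restore the activation's leading dims on the (rows, out_features) result."""
    return tuple(act_shape[:-1]) + (out_features,)


act_shape = (8, 128, 4096)              # hypothetical (batch, tokens, in_features)
rows, k = fold_leading_dims(act_shape)  # 2D view fed to the scaled matmul
num_axiswise_scales = rows              # one scale per token once folded
print(rows, k, num_axiswise_scales, unfold_output(act_shape, 11008))
```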

Composition with other dtypes/techniques

  • There are various low-bit dtypes that users may want to use, and TorchAO has a number of these (AffineQuantized, NF4Tensor, int4_weight_only, etc.). Users might also combine different types within the same model. Work is needed to ensure that the top-level UX is expressive enough to handle these cases.

Standardize on TorchAO APIs

  • Static Calibration Flow: This RFC does not aim to provide a static calibration flow; we would like to share this API with TorchAO. An early prototype can be found in pytorch/ao#487 ("Add static quantization as an example for calibration flow"), and we should ensure the two are composable.
  • TorchAO provides an autoquant API, which relies on all of the inference logic being encapsulated in the tensor subclass. That is not the case today, and work is needed to ensure that we compose well with it.

Non-H100 GPU Support

  • TensorWise scaling in _scaled_mm is enabled on sm89 and MI300x+ GPUs. However, the AxisWise kernel is based on CUTLASS and is currently supported only on H100.
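The support matrix above could be encoded as a small dispatch helper. This sketch rests on assumptions: it maps "sm89" to CUDA compute capability (8, 9) and H100 to (9, 0), treats MI300x-class GPUs through a separate flag rather than inferring them, and returns granularity names as strings. On CUDA, the capability tuple can be obtained with `torch.cuda.get_device_capability()`.

```python
from typing import List, Tuple


def supported_granularities(capability: Tuple[int, int],
                            is_mi300_or_newer: bool = False) -> List[str]:
    """Scaling granularities _scaled_mm can serve, per the matrix in this RFC.

    Assumptions: sm89 == compute capability (8, 9); H100 == (9, 0).
    """
    granularities = []
    if is_mi300_or_newer or capability >= (8, 9):
        granularities.append("TensorWise")
    # The AxisWise kernel is CUTLASS-based and, per this RFC, H100-only.
    if not is_mi300_or_newer and capability == (9, 0):
        granularities.append("AxisWise")
    return granularities


print(supported_granularities((9, 0)))  # ['TensorWise', 'AxisWise']
print(supported_granularities((8, 9)))  # ['TensorWise']
print(supported_granularities((8, 0)))  # []
```

A check like this would let the top-level API fall back to TensorWise (or refuse to convert) instead of failing inside the kernel.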

Dynamic Shapes

  • Work is needed to validate that dynamic shapes work as expected.

Other Module Support

While Linear weights account for the majority of model size and compute, other operations can still benefit from FP8 compute gains.

  • For transformer models, we will likely need to support a fused FP8 SDPA variant. With the recent addition of cuDNN SDPA to PyTorch core and the upgrade to cuDNN 9.1, we could use this library for it. However, work is needed to explore the various options.
  • The KV cache can account for a significant proportion of total memory usage during autoregressive decoding. We can investigate storing quantized KV entries in FP8.
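To gauge the potential savings, KV-cache size can be estimated as 2 (K and V) x batch x sequence length x layers x KV heads x head dim x bytes per element. The model configuration below is hypothetical, and the estimate ignores any scales an FP8 cache would also need to store.

```python
def kv_cache_bytes(batch: int, seq_len: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_element: int) -> int:
    # Factor of 2: each layer caches one K and one V tensor.
    return 2 * batch * seq_len * num_layers * num_kv_heads * head_dim * bytes_per_element


# Hypothetical decoder configuration.
config = dict(batch=8, seq_len=4096, num_layers=32, num_kv_heads=8, head_dim=128)
bf16 = kv_cache_bytes(**config, bytes_per_element=2)
fp8 = kv_cache_bytes(**config, bytes_per_element=1)
print(f"BF16 KV cache: {bf16 / 2**30:.1f} GiB, FP8 KV cache: {fp8 / 2**30:.1f} GiB")
```

Because the cache grows linearly with batch size and sequence length, halving the bytes per element frees memory that can be spent on larger batches or longer contexts.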

Examples

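```python
# Example usage of the proposed API
import torch

# QuantConfig, ActivationCasting, and quantize_to_float8 come from float8_experimental;
# ScalingGranularity is part of the proposed extension. MyLargeModel is a placeholder.
model = MyLargeModel().to(device="cuda", dtype=torch.bfloat16)
quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantized_model = quantize_to_float8(
    model,
    quant_config,
    scaling_granularity=ScalingGranularity.AxisWise,
)

quantized_model = torch.compile(quantized_model)
# Use the quantized model for inference
input_tensor = torch.randn(1, 1024, 1024, dtype=torch.bfloat16, device="cuda")
with torch.inference_mode():
    output = quantized_model(input_tensor)
```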

Open Questions

  1. Should we provide more granular control over which layers are quantized? This is possible today using FQNs, but it is unclear whether TorchAO has ideas for the top-level UX.
  2. How can we best handle models with custom or non-standard linear layers?
  3. What additional tools or utilities might be needed to help users debug and optimize their quantized models?
  4. Quantization-Error Reduction Techniques: Techniques like HQQ are used to reduce quantization error. It is unlikely that the existing _scaled_mm kernel can support this use case. Is that a problem?

Conclusion

This RFC proposes significant enhancements to Float8 inference in PyTorch, aiming to provide a more flexible, efficient, and user-friendly framework for quantization. By supporting various scaling granularities and quantization strategies, we can cater to a wide range of use cases and potentially unlock substantial performance improvements for many models.

Additional Details

The results below were collected using this script: https://gist.github.com/drisspg/d7ae2134fbb6ca369c4817853c3352fa

Results for batch_size=1, num_tokens=128:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |      211.2  | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |      151.04 | 1.40x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |      138.7  | 1.52x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |      460.01 | 0.46x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |      137.68 | 1.53x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |      131.39 | 1.61x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |      459.72 | 0.46x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=1, num_tokens=1024:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |      642.22 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |      396.68 | 1.62x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |      364.04 | 1.76x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |      871.38 | 0.74x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |      390.63 | 1.64x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |      369.72 | 1.74x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |      868.9  | 0.74x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=32, num_tokens=128:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |     2567.15 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |     1535.65 | 1.67x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |     1405.36 | 1.83x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |     2783.9  | 0.92x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |     1487.35 | 1.73x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |     1420.56 | 1.81x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |     2786.66 | 0.92x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=32, num_tokens=1024:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |     21087.9 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |     12172.4 | 1.73x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |     11220.7 | 1.88x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |     21209.9 | 0.99x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |     12393.6 | 1.70x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |     11853.9 | 1.78x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |     21227.7 | 0.99x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+

Results for batch_size=64, num_tokens=2048:
+----------------------------+-------------+-------------------+----------------+
| Variant                    |   Time (μs) | Speedup vs BF16   |   SQNR vs BF16 |
+============================+=============+===================+================+
| BF16                       |     86532.6 | 1.00x             |         inf    |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_TensorWise     |     49520.9 | 1.75x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_TensorWise      |     47816.8 | 1.81x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_TensorWise |     86674.2 | 1.00x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Dynamic_AxisWise       |     68645.7 | 1.26x             |          23.75 |
+----------------------------+-------------+-------------------+----------------+
| FP8_Static_AxisWise        |     54025.8 | 1.60x             |          23.5  |
+----------------------------+-------------+-------------------+----------------+
| FP8_Weight_Only_AxisWise   |     85562.3 | 1.01x             |          26.75 |
+----------------------------+-------------+-------------------+----------------+
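The derived columns in the tables above can be reproduced from the raw numbers. The sketch assumes that Speedup vs BF16 is the BF16 time divided by the variant's time, which matches the tables (e.g. 211.2 / 151.04 is about 1.40). It also assumes that SQNR uses the common 20 log10(||ref|| / ||ref - out||) definition, which yields inf when the output matches the reference exactly, as in the BF16 rows. The linked script may compute SQNR differently.

```python
import math


def speedup(bf16_time_us: float, variant_time_us: float) -> float:
    return bf16_time_us / variant_time_us


def sqnr_db(reference, output):
    """Signal-to-quantization-noise ratio in dB (assumed definition)."""
    signal = math.sqrt(sum(r * r for r in reference))
    noise = math.sqrt(sum((r - o) ** 2 for r, o in zip(reference, output)))
    return math.inf if noise == 0 else 20 * math.log10(signal / noise)


# batch_size=1, num_tokens=128: FP8_Dynamic_TensorWise vs BF16
print(f"{speedup(211.2, 151.04):.2f}x")
# batch_size=64, num_tokens=2048: FP8_Dynamic_AxisWise vs BF16
print(f"{speedup(86532.6, 68645.7):.2f}x")
print(sqnr_db([1.0, 2.0], [1.0, 2.0]))  # identical outputs -> inf
```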
