import torch
import torch.nn as nn
from quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
)

__all__ = ["DynamicallyPerAxisQuantizedLinear"]


class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    This class is a replacement for `torch.nn.Linear`, implementing dynamic
    quantization of the input across all axes except for the last axis.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_fused_int_mm: bool = False,
    ) -> None:
        super().__init__(in_features, out_features, bias)
        self.use_fused_int_mm = use_fused_int_mm
        # note: enabling use_fused_int_mm = True gives the best performance when
        # torch._inductor.config.force_fuse_int_mm_with_mul = True is also set

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the quantized linear layer.

        This method applies dynamic quantization to the input tensor across all axes
        except the last axis using the `quant_int8_dynamic_per_token_linear` function.

        Args:
            X (torch.Tensor): The input tensor to the quantized linear layer.

        Returns:
            torch.Tensor: The output tensor after the quantized matmul and rescale.
        """
        # The following line mimics the behavior of SmoothFakeDynamicallyQuantizedLinear
        if not self.use_fused_int_mm:
            X = X / self.fake_rescale
            # Somehow the inductor fusion that occurs for most transformer models when this
            # module has the additional div op is faster than when it doesn't, although the
            # memory usage is slightly higher. fake_rescale is the scalar 1, so it does not
            # affect accuracy.
        Y = quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
        )
        return Y
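
    # A rough sketch, as an assumption for illustration only (see quant_primitives for the
    # actual behavior), of the per-token dynamic quantization that
    # `quant_int8_dynamic_per_token_linear` is expected to perform in the forward pass:
    #   x_scales = X.abs().amax(dim=-1, keepdim=True) / 127    # one scale per token
    #   X_int8 = torch.clamp(torch.round(X / x_scales), -128, 127).to(torch.int8)
    #   Y = (int8 matmul of X_int8 with self.W_int_repr_t) rescaled by
    #       x_scales * self.W_scales, cast back to X.dtype, plus self.bias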

    @classmethod
    def from_float(
        cls, mod: torch.nn.Linear, use_fused_int_mm: bool = False
    ) -> "DynamicallyPerAxisQuantizedLinear":
        """
        Converts a `mod` of class `torch.nn.Linear` to the dynamically quantized version of it.

        Note: this class does not require calibration.

        Args:
            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.

        Returns:
            DynamicallyPerAxisQuantizedLinear: The converted quantized linear module.
        """
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            use_fused_int_mm=use_fused_int_mm,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        # quantize the float weight per output channel to int8 and keep the per-channel scales
        W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel(
            mod.weight, -128, 127, torch.int8
        )
        # store the transposed int8 weight for the quantized matmul
        new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t())
        new_mod.W_scales = nn.Parameter(W_scales)
        new_mod.bias = mod.bias
        if not use_fused_int_mm:
            new_mod.fake_rescale = torch.tensor(
                [1.0], dtype=mod.weight.dtype, device=mod.weight.device
            )
        # the original float weight is no longer needed
        del new_mod.weight

        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod
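

# A minimal usage sketch, assuming the `quant_primitives` helpers imported above are
# available: convert a float `nn.Linear` with `from_float` and run a forward pass.
# The layer sizes and input shape below are made up for illustration.
if __name__ == "__main__":
    lin = nn.Linear(1024, 4096, bias=True)
    qlin = DynamicallyPerAxisQuantizedLinear.from_float(lin, use_fused_int_mm=False)

    x = torch.randn(8, 128, 1024)  # (batch, tokens, in_features)
    with torch.no_grad():
        y = qlin(x)
    print(y.shape)  # expected: torch.Size([8, 128, 4096])

    # per the note in __init__, use_fused_int_mm=True performs best under torch.compile
    # with torch._inductor.config.force_fuse_int_mm_with_mul = True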