@@ -1,20 +1,30 @@
+import types
+from dataclasses import dataclass
 from typing import Dict, Optional
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
 from torchao.prototype.smoothquant.core import (
     SmoothQuantObservedLinear,
     SmoothQuantObserver,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
 )
 from torchao.quantization.linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.utils import _get_per_token_block_size
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
@@ -53,32 +63,6 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
 
-def _observed_linear_subclass_inserter(constructor):
-    """
-    Replaces unquantized observed linear instances with quantized linear instances.
-
-    Args:
-        constructor: the function which applies quantization to the observed linear layer
-    """
-
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias is not None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
-
 def save_smooth_quant_recipe(
     model: torch.nn.Module, save_path: str
 ) -> Dict[str, torch.Tensor]:
@@ -121,7 +105,14 @@ def recurse(module: torch.nn.Module, name: str = ""):
             # act_scales is None for dynamic quantization
             if any(x is None for x in (smoothing_factor, wei_scales)):
                 return module
-            return smooth_quant(smoothing_factor, act_scales, wei_scales)(module)
+            is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+            wrapper = torch.nn.Sequential(module)
+            quantize_(
+                wrapper,
+                SmoothQuantConfig(smoothing_factor, act_scales, wei_scales),
+                is_observed_linear,
+            )
+            return wrapper[0]
 
         mod_new = module
 
@@ -158,54 +149,73 @@ def static_quantize(self, input, scale, zero_point):
         )
 
 
-def smooth_quant(
-    smoothing_factor: Optional[torch.Tensor] = None,
-    act_scales: Optional[torch.Tensor] = None,
-    wei_scales: Optional[torch.Tensor] = None,
-):
+@dataclass
+class SmoothQuantConfig(AOBaseConfig):
     """
-    Quantizes linear layers when passed into quantize_()
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
         smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
        act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
        wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
    """
 
-    def quantize_weight(observed_linear):
-        target_dtype = torch.int8
-        # act_scales is None for dynamic quantization thus not checked
-        if any(x is None for x in (smoothing_factor, wei_scales)):
-            factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
-            weight = observed_linear.obs.weight * factor
-        else:
-            factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
-            weight = observed_linear.weight * factor
-        weight = weight.to(observed_linear.weight.dtype)
-        block_size = (1, weight.size(1))
-        wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
-        qw = to_affine_quantized_intx_static(
-            weight,
-            w_scales,
-            wei_zero_points,
-            block_size,
-            target_dtype,
-        )
+    smoothing_factor: Optional[torch.Tensor] = None
+    act_scales: Optional[torch.Tensor] = None
+    wei_scales: Optional[torch.Tensor] = None
 
-        if x_scale is None:
-            # dynamic quant
-            qw = to_linear_activation_quantized(
-                qw, _ActQuantizer(target_dtype).dynamic_quantize
-            )
-        else:
-            # static quant
-            x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
-            qw = to_weight_tensor_with_linear_activation_quantization_metadata(
-                qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
-            )
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, factor.to(qw.dtype)
+@register_quantize_module_handler(SmoothQuantConfig)
+def _smooth_quant_transform(
+    module: torch.nn.Module,
+    config: SmoothQuantConfig,
+):
+    smoothing_factor = config.smoothing_factor
+    act_scales = config.act_scales
+    wei_scales = config.wei_scales
+    observed_linear = module
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.bias = observed_linear.bias
+
+    target_dtype = torch.int8
+    # act_scales is None for dynamic quantization thus not checked
+    if any(x is None for x in (smoothing_factor, wei_scales)):
+        factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
+        weight = observed_linear.obs.weight * factor
+    else:
+        factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
+        weight = observed_linear.weight * factor
+    weight = weight.to(observed_linear.weight.dtype)
+    block_size = (1, weight.size(1))
+    wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
+    qw = to_affine_quantized_intx_static(
+        weight,
+        w_scales,
+        wei_zero_points,
+        block_size,
+        target_dtype,
+    )
+
+    if x_scale is None:
+        # dynamic quant
+        qw = to_linear_activation_quantized(
+            qw, _ActQuantizer(target_dtype).dynamic_quantize
+        )
+    else:
+        # static quant
+        x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
+        qw = to_weight_tensor_with_linear_activation_quantization_metadata(
+            qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
        )
 
-    return _observed_linear_subclass_inserter(quantize_weight)
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype))
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return linear
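With this change SmoothQuant goes through the standard quantize_() flow: register_quantize_module_handler(SmoothQuantConfig) registers _smooth_quant_transform, so quantize_() swaps every matching observed linear for a plain nn.Linear whose weight carries the smoothed, quantized tensor subclass. A minimal usage sketch follows; it is not part of the diff, the import path for SmoothQuantConfig is assumed, and the model is assumed to already contain calibrated SmoothQuantObservedLinear modules.

import torch

from torchao.prototype.smoothquant import SmoothQuantConfig  # assumed export path
from torchao.prototype.smoothquant.core import SmoothQuantObservedLinear
from torchao.quantization import quantize_


def apply_smoothquant(model: torch.nn.Module) -> torch.nn.Module:
    # Only transform the observed linears; everything else is left untouched.
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
    # Leaving all three config fields as None makes _smooth_quant_transform pull
    # the smoothing factor and scales from each layer's observer.
    quantize_(model, SmoothQuantConfig(), is_observed_linear)
    return model

Compared with the removed _observed_linear_subclass_inserter closure, the config object is a plain dataclass, so it can be rebuilt from a saved recipe (as the load path above does via a one-element nn.Sequential wrapper) and dispatched through the handler registry instead of a module-swap helper.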