|
| 1 | +""" |
| 2 | +Demo for static quantization flow |
| 3 | +""" |
| 4 | +import torch |
| 5 | +import copy |
| 6 | + |
| 7 | +# TODO: use the generalized observer for affine qunatization in the future |
| 8 | +from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver |
| 9 | +import torch.nn.functional as F |
| 10 | +from torch import Tensor |
| 11 | +from torchao.dtypes import to_affine_quantized_static |
| 12 | +from torchao.quantization.utils import compute_error |
| 13 | +from torchao.quantization import quantize_ |
| 14 | +from torchao.quantization.subclass import to_linear_act_quantized |
| 15 | +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter |
| 16 | + |
| 17 | + |
| 18 | + |
| 19 | +class ObservedLinear(torch.nn.Linear): |
| 20 | + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): |
| 21 | + super().__init__(in_features, out_features, bias, device, dtype) |
| 22 | + self.act_obs = act_obs |
| 23 | + self.weight_obs = weight_obs |
| 24 | + |
| 25 | + def forward(self, input: Tensor): |
| 26 | + observed_input = self.act_obs(input) |
| 27 | + observed_weight = self.weight_obs(self.weight) |
| 28 | + return F.linear(observed_input, observed_weight, self.bias) |
| 29 | + |
| 30 | + @classmethod |
| 31 | + def from_float(cls, float_linear, act_obs, weight_obs): |
| 32 | + observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) |
| 33 | + observed_linear.weight = float_linear.weight |
| 34 | + observed_linear.bias = float_linear.bias |
| 35 | + return observed_linear |
| 36 | + |
| 37 | +def insert_observers_(model, act_obs, weight_obs): |
| 38 | + _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) |
| 39 | + replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs) |
| 40 | + act_obs = copy.deepcopy(act_obs) |
| 41 | + weight_obs = copy.deepcopy(weight_obs) |
| 42 | + _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) |
| 43 | + |
| 44 | +# converting observed linear module to linear module with quantzied weights (and quantized activations) |
| 45 | +# with tensor subclasses |
| 46 | +def apply_static_quant(observed_linear): |
| 47 | + target_dtype = torch.uint8 |
| 48 | + |
| 49 | + # weight quantization |
| 50 | + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() |
| 51 | + def weight_quant_func(weight): |
| 52 | + block_size = (1, weight.shape[1]) |
| 53 | + return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype) |
| 54 | + linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype) |
| 55 | + linear.weight = observed_linear.weight |
| 56 | + linear.bias = observed_linear.bias |
| 57 | + |
| 58 | + linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False) |
| 59 | + |
| 60 | + # activation quantization |
| 61 | + act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() |
| 62 | + input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype) |
| 63 | + linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False) |
| 64 | + |
| 65 | + return linear |
| 66 | + |
| 67 | + |
| 68 | +# alternative for converting observed linear module to quantized linear module |
| 69 | +class QuantizedLinear(torch.nn.Module): |
| 70 | + def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor): |
| 71 | + super().__init__() |
| 72 | + self.act_scale, self.act_zero_point = act_obs.calculate_qparams() |
| 73 | + weight_scale, weight_zero_point = weight_obs.calculate_qparams() |
| 74 | + assert weight.dim() == 2 |
| 75 | + block_size = (1, weight.shape[1]) |
| 76 | + target_dtype = torch.uint8 |
| 77 | + self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype) |
| 78 | + self.bias = bias |
| 79 | + |
| 80 | + def forward(self, input: Tensor): |
| 81 | + block_size = input.shape |
| 82 | + target_dtype = torch.uint8 |
| 83 | + qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype) |
| 84 | + return F.linear(qinput, self.qweight, self.bias) |
| 85 | + |
| 86 | + @classmethod |
| 87 | + def from_observed(cls, observed_linear): |
| 88 | + quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias) |
| 89 | + return quantized_linear |
| 90 | + |
| 91 | +def apply_static_quant2(observed_linear): |
| 92 | + return QuantizedLinear.from_observed(observed_linear) |
| 93 | + |
| 94 | +class ToyLinearModel(torch.nn.Module): |
| 95 | + def __init__(self, m=64, n=32, k=64): |
| 96 | + super().__init__() |
| 97 | + self.linear1 = torch.nn.Linear(m, n, bias=False) |
| 98 | + self.linear2 = torch.nn.Linear(n, k, bias=False) |
| 99 | + |
| 100 | + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): |
| 101 | + return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) |
| 102 | + |
| 103 | + def forward(self, x): |
| 104 | + x = self.linear1(x) |
| 105 | + x = self.linear2(x) |
| 106 | + return x |
| 107 | + |
| 108 | +dtype = torch.bfloat16 |
| 109 | +m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") |
| 110 | +m_bf16 = copy.deepcopy(m) |
| 111 | +example_inputs = m.example_inputs(dtype=dtype, device="cuda") |
| 112 | + |
| 113 | +m_bf16 = torch.compile(m_bf16, mode='max-autotune') |
| 114 | + |
| 115 | +# TODO: use the generalized observer for affine qunatization in the future |
| 116 | +act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda") |
| 117 | +weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda") |
| 118 | + |
| 119 | +before_quant = m(*example_inputs) |
| 120 | + |
| 121 | +insert_observers_(m, act_obs, weight_obs) |
| 122 | +# calibrating / training |
| 123 | +for _ in range(10): |
| 124 | + m(*example_inputs) |
| 125 | + |
| 126 | +after_obs = m(*example_inputs) |
| 127 | + |
| 128 | +m2 = copy.deepcopy(m) |
| 129 | + |
| 130 | +is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) |
| 131 | + |
| 132 | +# quantized linear represented as an nn.Linear with modified tensor subclass weights |
| 133 | +# for both activation and weight quantization |
| 134 | +quantize_(m, apply_static_quant, is_observed_linear) |
| 135 | +print("quantized model (applying tensor subclass to weight):", m) |
| 136 | +after_quant = m(*example_inputs) |
| 137 | +assert compute_error(before_quant, after_quant) > 30 |
| 138 | +print("test passed") |
| 139 | + |
| 140 | +# quantized linear as a standalone module |
| 141 | +quantize_(m2, apply_static_quant2, is_observed_linear) |
| 142 | +print("quantized model (quantized module):", m2) |
| 143 | +after_quant = m2(*example_inputs) |
| 144 | +assert compute_error(before_quant, after_quant) > 30 |
| 145 | +print("test passed") |
0 commit comments