Add static quantization as an example for calibration flow #487
Merged
torchao/quantization/quant_api.py
@@ -259,12 +259,12 @@ def insert_subclass(lin):

     return insert_subclass

-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
[Inline review comment on the `quantize_` signature] @msaroufim I made some changes to
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace | ||
|
||
Args: | ||
model (torch.nn.Module): input model | ||
apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance) | ||
apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor) | ||
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on | ||
the weight of the module | ||
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) | ||
|
@@ -300,19 +300,24 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Ten | |
x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6, | ||
zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float") | ||
|
||
def apply_weight_quant_to_linear(linear): | ||
linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False) | ||
return linear | ||
|
||
# apply to modules under block0 submodule | ||
def filter_fn(module: nn.Module, fqn: str) -> bool: | ||
return isinstance(module, nn.Linear) | ||
|
||
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) | ||
quantize_(m, apply_weight_quant, filter_fn) | ||
quantize_(m, apply_weight_quant_to_linear, filter_fn) | ||
|
||
""" | ||
if set_inductor_config: | ||
torchao.quantization.utils.recommended_inductor_config_setter() | ||
|
||
_replace_with_custom_fn_if_matches_filter( | ||
model, | ||
_get_linear_subclass_inserter(apply_tensor_subclass), | ||
apply_tensor_subclass, | ||
_is_linear if filter_fn is None else filter_fn, | ||
) | ||
|
||
|
@@ -356,7 +361,7 @@ def get_per_token_block_size(x): | |
weight = to_linear_act_quantized(weight, input_quant_func) | ||
return weight | ||
|
||
return apply_int8_dynamic_activation_int4_weight_quant | ||
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant) | ||
|
||
|
||
def int4_weight_only(group_size=128, inner_k_tiles=8): | ||
|
@@ -394,7 +399,7 @@ def apply_int4_weight_only_quant(weight): | |
layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles) | ||
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type) | ||
|
||
return apply_int4_weight_only_quant | ||
return _get_linear_subclass_inserter(apply_int4_weight_only_quant) | ||
|
||
|
||
def int8_weight_only(): | ||
|
@@ -412,7 +417,7 @@ def apply_int8wo_quant(weight): | |
block_size = (1, weight.shape[1]) | ||
return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) | ||
|
||
return apply_int8wo_quant | ||
return _get_linear_subclass_inserter(apply_int8wo_quant) | ||
|
||
def int8_dynamic_activation_int8_weight(): | ||
""" | ||
|
@@ -454,4 +459,4 @@ def get_per_token_block_size(x): | |
weight = to_linear_act_quantized(weight, input_quant_func) | ||
return weight | ||
|
||
return apply_int8_dynamic_activation_int8_weight_quant | ||
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) |
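In short, `quantize_` now takes a module-to-module callable, and the wrapping of tensor-level quantization functions via `_get_linear_subclass_inserter` moves into the individual quantizer constructors (`int8_weight_only()` and friends). A minimal sketch of the new contract for a custom callable, assuming only the `quantize_` import used in the new example file below; the weight transform here is a placeholder (not a torchao API) standing in for a real tensor-subclass conversion such as `to_affine_quantized`:

import torch
from torchao.quantization import quantize_

def apply_transform_to_linear(linear: torch.nn.Linear) -> torch.nn.Linear:
    # placeholder transform; a real callable would produce a tensor subclass here
    new_weight = linear.weight.detach().to(torch.bfloat16)
    linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    return linear  # quantize_ now expects the (modified) module back, not a tensor

m = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.Linear(64, 32))
quantize_(m, apply_transform_to_linear, lambda mod, fqn: isinstance(mod, torch.nn.Linear))

Because the callable receives the whole module, it can also replace the module outright, which is what the static quantization example below relies on.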
New file (145 added lines): demo for the static quantization calibration flow.
""" | ||
Demo for static quantization flow | ||
""" | ||
import torch | ||
import copy | ||
|
||
# TODO: use the generalized observer for affine qunatization in the future | ||
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver | ||
import torch.nn.functional as F | ||
from torch import Tensor | ||
from torchao.dtypes import to_affine_quantized_static | ||
from torchao.quantization.utils import compute_error | ||
from torchao.quantization import quantize_ | ||
from torchao.quantization.subclass import to_linear_act_quantized | ||
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter | ||
|
||
|
||
|
||
class ObservedLinear(torch.nn.Linear): | ||
def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None): | ||
super().__init__(in_features, out_features, bias, device, dtype) | ||
self.act_obs = act_obs | ||
self.weight_obs = weight_obs | ||
|
||
def forward(self, input: Tensor): | ||
observed_input = self.act_obs(input) | ||
observed_weight = self.weight_obs(self.weight) | ||
return F.linear(observed_input, observed_weight, self.bias) | ||
|
||
@classmethod | ||
def from_float(cls, float_linear, act_obs, weight_obs): | ||
observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype) | ||
observed_linear.weight = float_linear.weight | ||
observed_linear.bias = float_linear.bias | ||
return observed_linear | ||
|
||
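# insert_observers_ below swaps every nn.Linear for an ObservedLinear; note that the
# replacement_fn closure sees the rebound (deep-copied) observers, so all replaced
# linears share the same act_obs/weight_obs instances and calibration statistics are
# pooled across them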
def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs)
    act_obs = copy.deepcopy(act_obs)
    weight_obs = copy.deepcopy(weight_obs)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_static_quant(observed_linear):
    target_dtype = torch.uint8

    # weight quantization
    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
    def weight_quant_func(weight):
        block_size = (1, weight.shape[1])
        return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
    linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
    linear.weight = observed_linear.weight
    linear.bias = observed_linear.bias

    linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)

    # activation quantization
    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
    input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
    linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False)

    return linear


# alternative for converting observed linear module to quantized linear module
class QuantizedLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
        target_dtype = torch.uint8
        self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
        self.bias = bias

    def forward(self, input: Tensor):
        block_size = input.shape
        target_dtype = torch.uint8
        qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear):
        quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
        return quantized_linear

def apply_static_quant2(observed_linear):
    return QuantizedLinear.from_observed(observed_linear)

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

# TODO: use the generalized observer for affine quantization in the future
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")

before_quant = m(*example_inputs)

insert_observers_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
    m(*example_inputs)

after_obs = m(*example_inputs)

m2 = copy.deepcopy(m)

is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant, is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")

# quantized linear as a standalone module
quantize_(m2, apply_static_quant2, is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")