
Commit a753e3f

Adding GPU quantization workflows and APIs

Summary: APIs and workflows used for quantization and pruning in the segment-anything-fast and gpt-fast repos.

Test Plan: python /home/cdhernandez/local/ao/ao/quantization/test.py

ghstack-source-id: 31191a786cb43d31f37b6d77121c8e4882ded037
Pull Request resolved: #1

1 parent 7b3330c commit a753e3f

17 files changed: +2132 −0 lines

ao/quantization/__init__.py (+39)

@@ -0,0 +1,39 @@
from smoothquant import *  # noqa: F403
from quant_api import *  # noqa: F403
from subclass import *  # noqa: F403
from quant_primitives import *  # noqa: F403
from utils import *  # noqa: F403
from weight_only import *  # noqa: F403

__all__ = [
    "DynamicallyPerAxisQuantizedLinear",
    "replace_with_custom_fn_if_matches_filter",
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_dqtensors",
    "insert_subclass",
    "safe_int_mm",
    "dynamically_quantize_per_tensor",
    "quantize_activation_per_token_absmax",
    "dynamically_quantize_per_channel",
    "dequantize_per_tensor",
    "dequantize_per_channel",
    "quant_int8_dynamic_linear",
    "quant_int8_matmul",
    "quant_int8_dynamic_per_token_linear",
    "quant_int8_per_token_matmul",
    "get_scale",
    "SmoothFakeDynQuantMixin",
    "SmoothFakeDynamicallyQuantizedLinear",
    "swap_linear_with_smooth_fq_linear",
    "smooth_fq_linear_to_inference",
    "set_smooth_fq_attribute",
    "DynamicallyQuantizedLinearWeight",
    "log_with_rank",
    "clear_logs",
    "compute_error",
    "forward_hook",
    "apply_logging_hook",
    "get_model_size_in_bytes",
    "WeightOnlyInt8QuantLinear",
]
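For orientation, here is a minimal usage sketch of the exported API (not part of the commit). It assumes the ao/quantization directory is on sys.path, since the flat `from smoothquant import *` style suggests the modules are imported from inside that directory, and it assumes `get_model_size_in_bytes` lives in `utils` and takes the model as its only argument:

# Hedged sketch, not part of the commit: exercising two of the exports above.
# Assumes the ao/quantization directory is on sys.path so the flat imports resolve.
import torch
from quant_api import apply_weight_only_int8_quant
from utils import get_model_size_in_bytes  # assumed location and signature

model = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
print("before:", get_model_size_in_bytes(model))
apply_weight_only_int8_quant(model)  # swaps each nn.Linear for WeightOnlyInt8QuantLinear
print("after:", get_model_size_in_bytes(model))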
7 binary files changed (contents not shown).

ao/quantization/dynamic_quant.py (+96)

@@ -0,0 +1,96 @@
import torch
import torch.nn as nn

from quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
)

__all__ = ["DynamicallyPerAxisQuantizedLinear"]


class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    This class is a replacement for `torch.nn.Linear`, implementing dynamic
    quantization of the input across all axes except for the last axis.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_fused_int_mm=False,
    ) -> None:
        super().__init__(in_features, out_features, bias)
        self.use_fused_int_mm = use_fused_int_mm
        # note: enabling use_fused_int_mm = True has best perf when additionally setting
        # torch._inductor.config.force_fuse_int_mm_with_mul = True

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the quantized linear layer.

        This method applies dynamic quantization to the input tensor across all axes
        except the last axis using the `quant_int8_dynamic_per_token_linear` function.

        Args:
            X (torch.Tensor): The input tensor to the quantized linear layer.

        Returns:
            torch.Tensor: The output tensor after the quantized matmul and rescale.
        """
        # The following line mimics the behavior of SmoothFakeDynamicallyQuantizedLinear
        if not self.use_fused_int_mm:
            X = X / self.fake_rescale
        # Counterintuitively, for most transformer models the inductor fusion that
        # occurs when this module has the additional div op above is faster than
        # without it, although memory usage is slightly higher. fake_rescale is the
        # scalar 1, so the division does not affect accuracy.
        Y = quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
        )
        return Y

    @classmethod
    def from_float(
        cls, mod: torch.nn.Linear, use_fused_int_mm=False
    ) -> "DynamicallyPerAxisQuantizedLinear":
        """
        Converts a `mod` of class `torch.nn.Linear` to the dynamically quantized
        version of it.

        Note: this class does not require calibration.

        Args:
            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.

        Returns:
            DynamicallyPerAxisQuantizedLinear: The converted quantized linear module.
        """
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            use_fused_int_mm=use_fused_int_mm,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel(
            mod.weight, -128, 127, torch.int8
        )
        new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t())
        new_mod.W_scales = nn.Parameter(W_scales)
        new_mod.bias = mod.bias
        if not use_fused_int_mm:
            new_mod.fake_rescale = torch.tensor(
                [1.0], dtype=mod.weight.dtype, device=mod.weight.device
            )
        del new_mod.weight

        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod
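A hedged usage sketch (not part of the commit) of the conversion path above. It assumes `dynamic_quant.py` and its `quant_primitives` dependency are importable, and it relies only on what the call site implies: `dynamically_quantize_per_channel` returns int8 weights with one scale per output channel.

# Hedged sketch, not part of the commit: converting and running one layer.
import torch
from dynamic_quant import DynamicallyPerAxisQuantizedLinear

lin = torch.nn.Linear(64, 64).cuda().half()
qlin = DynamicallyPerAxisQuantizedLinear.from_float(lin)

x = torch.randn(8, 64, device="cuda", dtype=torch.half)
with torch.no_grad():
    y = qlin(x)  # per-token int8 dynamic quant, int8 matmul, rescale back to x.dtype
print(qlin.W_int_repr_t.dtype)  # torch.int8; weights stay quantized between calls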

ao/quantization/quant_api.py (+74)

@@ -0,0 +1,74 @@
"""
Quantization APIs which are not specific to SmoothQuant.

Note: this is throwaway code for fast results on Blueberry; it is not
intended to be the actual long-term quantization API for server GPUs.
"""

import torch
from dynamic_quant import (
    DynamicallyPerAxisQuantizedLinear,
)
from subclass import (
    DynamicallyQuantizedLinearWeight,
)
from weight_only import (
    WeightOnlyInt8QuantLinear,
)

__all__ = [
    "replace_with_custom_fn_if_matches_filter",
    "apply_weight_only_int8_quant",
    "apply_dynamic_quant",
    "change_linear_weights_to_dqtensors",
]


def replace_with_custom_fn_if_matches_filter(
    model, replacement_fn, filter_fn, cur_fqn=""
) -> None:
    """
    For each `child` in `model`, replaces it with `replacement_fn(child)`
    if `filter_fn(child, fqn)` is `True`.
    """
    name_to_child = dict(model.named_children())
    for name, child in name_to_child.items():
        if cur_fqn == "":
            new_fqn = name
        else:
            new_fqn = f"{cur_fqn}.{name}"
        if filter_fn(child, new_fqn):
            new_child = replacement_fn(child)
            setattr(model, name, new_child)
        else:
            replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, new_fqn
            )


def apply_weight_only_int8_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        WeightOnlyInt8QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )


def apply_dynamic_quant(model, use_fused_int_mm=False):
    replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod, use_fused_int_mm),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )


def change_linear_weights_to_dqtensors(model):
    def insert_subclass(lin):
        lin.weight = torch.nn.Parameter(
            DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
        )
        return lin

    replace_with_custom_fn_if_matches_filter(
        model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
    )
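A hedged usage sketch (not part of the commit) of the two swap styles above: the blanket `apply_dynamic_quant` entry point, and the generic helper with a name-based filter. The toy models and the "2" name filter are illustrative.

# Hedged sketch, not part of the commit: quantizing every nn.Linear in a toy
# model, then a targeted swap using the generic helper directly.
import torch
from quant_api import apply_dynamic_quant, replace_with_custom_fn_if_matches_filter
from weight_only import WeightOnlyInt8QuantLinear

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
).cuda()
apply_dynamic_quant(model)  # each nn.Linear becomes DynamicallyPerAxisQuantizedLinear

# Targeted variant: swap only modules whose fully qualified name starts with "2",
# i.e. the second Linear in this Sequential.
model2 = torch.nn.Sequential(
    torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16)
).cuda()
replace_with_custom_fn_if_matches_filter(
    model2,
    WeightOnlyInt8QuantLinear.from_float,
    lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn.startswith("2"),
)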
