Add Int8DynActInt8WeightLinear module (#5605)
Summary:
Pull Request resolved: #5605

Adding Int8DynActInt8WeightLinear for Per Channel DQ Linear

Reviewed By: mergennachin

Differential Revision: D63339550
mcr229 authored and facebook-github-bot committed Sep 25, 2024
1 parent 5c56f96 commit c8267ba
Showing 1 changed file with 93 additions and 0 deletions.

examples/models/llama2/source_transformation/quantize.py

@@ -379,6 +379,99 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
# return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


def linear_forward_8da8w(
    x,
    weight_int8,
    scales,
    zeros,
    out_features,
    precision,
):
    from torchao.quantization.utils import per_token_dynamic_quant

    # Dynamically quantize the activation per token at 8 bits.
    x = per_token_dynamic_quant(x)
    n_bit = 8
    quant_min = -(2 ** (n_bit - 1))
    quant_max = 2 ** (n_bit - 1) - 1
    # Dequantize the int8 weight back to `precision`, using one scale and
    # one zero point per output channel (axis 0), then run a float linear.
    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
        weight_int8,
        scales,
        zeros,
        0,
        quant_min,
        quant_max,
        torch.int8,
        out_dtype=precision,
    )
    c = torch.nn.functional.linear(x, w_dq)

    return c
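For reference, the per-channel dequantization performed above reduces to (w_int8 - zero) * scale for each output row. A minimal pure-PyTorch sketch of that math, not part of this diff; the helper name reference_dequant_per_channel and the example shapes are illustrative assumptions:

import torch

def reference_dequant_per_channel(
    weight_int8: torch.Tensor,  # (out_features, in_features), int8
    scales: torch.Tensor,  # (out_features,), float32
    zeros: torch.Tensor,  # (out_features,), float32
    precision: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Per-channel dequantization along axis 0:
    # w_dq[i, :] = (w_int8[i, :] - zeros[i]) * scales[i]
    w = weight_int8.to(precision)
    return (w - zeros.unsqueeze(1)) * scales.unsqueeze(1)

# Illustrative check of output shape and dtype.
w_q = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
s = torch.rand(8) + 0.01
z = torch.zeros(8)
w_dq = reference_dequant_per_channel(w_q, s, z)
assert w_dq.shape == (8, 16) and w_dq.dtype == torch.float32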


class Int8DynActInt8WeightLinear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]

    in_features: int
    out_features: int
    weight: torch.Tensor

    """
    This module implements a dynamically quantized linear layer with int8
    weights. Weights are quantized per channel.

    Parameters of importance:
        precision: precision of input and output, e.g. torch.float32 means
            the input activation and the output are both float32.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        precision: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.precision = precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        # Currently storing unpacked int8 weights.
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8),
        )
        # One scale and one zero point per output channel.
        self.register_buffer(
            "scales",
            torch.empty(
                (out_features,),
                dtype=torch.float32,
            ),
        )
        self.register_buffer(
            "zeros",
            torch.empty(
                (out_features,),
                dtype=torch.float32,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.precision)
        return linear_forward_8da8w(
            input,
            self.weight,
            self.scales,
            self.zeros,
            self.out_features,
            self.precision,
        )
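A minimal usage sketch for the module above; the symmetric per-channel quantization recipe and all tensor values are illustrative assumptions, and running the forward pass additionally requires torchao and the quantized_decomposed ops to be available:

import torch

in_features, out_features = 16, 8
float_weight = torch.randn(out_features, in_features)

# Symmetric per-channel quantization: one scale per output row, zero point 0.
scales = float_weight.abs().amax(dim=1) / 127.0
weight_int8 = (
    torch.round(float_weight / scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)
)

linear = Int8DynActInt8WeightLinear(in_features, out_features, bias=False)
linear.weight.copy_(weight_int8)
linear.scales.copy_(scales)
linear.zeros.copy_(torch.zeros(out_features))

x = torch.randn(2, in_features)
y = linear(x)  # shape (2, out_features), in the module's `precision` dtype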


#########################################################################
##### embedding table quantization ######

