
Conversation

@metascroy
Contributor

@metascroy metascroy commented Mar 26, 2025

This PR moves Int8DynamicActivationIntxWeightConfig and its quantizer into torchao.quantization.quant_api (out of experimental). Int8DynamicActivationIntxWeightConfig is refactored to closely mirror Int8DynamicActivationInt4WeightConfig when weight_dtype=torch.int4 and layout=QDQLayout().

Quantization in Int8DynamicActivationIntxWeightConfig is done with QDQLayout, and packing is then done separately with make_packed_linear_int8_dynamic_activation_intx_weight_tensor. This separates the quantization algorithm from the storage format. Both the packed layout and QDQLayout quantize with the same algorithm, and this is made explicit.

Example API usage:

quantize_(
      model,
      Int8DynamicActivationIntxWeightConfig(
          weight_dtype=torch.int4,
          weight_granularity=PerGroup(32),
          weight_mapping_type=MappingType.ASYMMETRIC,
          weight_zero_point_domain=ZeroPointDomain.NONE,
          layout=QDQLayout(),
      ),
  )
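For comparison, here is a hedged sketch of the packed CPU path using the same config (assuming PackedLinearInt8DynamicActivationIntxWeightLayout can be constructed with its defaults; illustrative, not copied from the PR):

quantize_(
      model,
      Int8DynamicActivationIntxWeightConfig(
          weight_dtype=torch.int4,
          weight_granularity=PerGroup(32),
          weight_mapping_type=MappingType.ASYMMETRIC,
          weight_zero_point_domain=ZeroPointDomain.NONE,
          # CPU-only packed layout; QDQLayout() above is the portable reference path
          layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
      ),
  )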

@pytorch-bot

pytorch-bot bot commented Mar 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1968

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0763e90 with merge base 9516764:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 26, 2025
@metascroy metascroy changed the title from "Create utility to construct QuantizedLinear from plain data" to "Move Int8DynamicActivationIntxWeightConfig out of experimental" Apr 1, 2025
@metascroy metascroy requested a review from jerryzh168 April 1, 2025 03:43
@metascroy metascroy force-pushed the add-torchao-module-constructors branch 2 times, most recently from c80e35e to fb1eca3 on April 6, 2025 00:02
@metascroy metascroy force-pushed the add-torchao-module-constructors branch from f63ca78 to 90bc3aa on April 6, 2025 04:25

@dataclass
class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
"""
Contributor Author

@andrewor14 can you take a look at this and flag any issues with it working well with the QAT workflow and FakeQuantizeConfig?

Contributor

@andrewor14 andrewor14 left a comment

Thanks @metascroy, looks great overall. I pointed out some fields that appear to be different/missing in the comments but I think the new config will work well with QAT. Either way we'll probably need an end-to-end QAT test to confirm that prepare vs convert numerics match exactly (can be future PR). Also left some questions about the new layout.

"""

weight_dtype: torch.dtype = torch.int8
weight_granularity: Union[PerRow, PerGroup] = PerGroup(32)
Contributor

I feel we should just make the type here Granularity and throw an error for unsupported types, so we don't tie ourselves to a specific granularity in the signature itself.
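A minimal sketch of the suggested check (hypothetical code, not from the PR):

weight_granularity: Granularity = PerGroup(32)

# in the config handler, reject granularities the kernels do not support
if not isinstance(weight_granularity, (PerGroup, PerRow)):
    raise ValueError(
        f"Unsupported weight_granularity {weight_granularity}; expected PerGroup or PerRow"
    )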

if isinstance(weight_granularity, PerGroup):
group_size = weight_granularity.group_size
elif isinstance(weight_granularity, PerRow):
group_size = weight.shape[-1]
Contributor

As discussed offline, this seems more like per channel to me, which is expressed in terms of PerAxis. PerRow seems like an unrelated float8 thing according to the docstrings here:

class PerRow(Granularity):
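For context, a hedged sketch of the equivalence being discussed (assuming torchao's PerAxis granularity):

weight = torch.randn(64, 128)  # (out_features, in_features)
# Using the full row as one group yields one scale per output channel,
# i.e. per-channel quantization, which is what PerAxis(axis=0) expresses directly:
per_channel_like = PerGroup(weight.shape[-1])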

weight_granularity = config.weight_granularity
weight_zero_point_domain = config.weight_zero_point_domain
weight_mapping_type = config.weight_mapping_type
weight_scale_dtype = config.weight_scale_dtype
Contributor

Can you add a TODO for Int8DynamicActivationInt4WeightConfig to add scale dtype there as well?
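A hedged sketch of what that TODO could look like (other fields elided; the field name mirrors this PR's config and is an assumption for Int8DynamicActivationInt4WeightConfig):

@dataclass
class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
    ...
    # TODO: expose the weight scale dtype here as well, mirroring
    # Int8DynamicActivationIntxWeightConfig.weight_scale_dtype
    weight_scale_dtype: Optional[torch.dtype] = None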



@dataclass
class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
Contributor

Does this match the numerics of Int8DynamicActivationInt4WeightConfig exactly if we choose weight_dtype = torch.int4? Is that a goal of this?

Contributor

Whether or not this is the goal, maybe we should document this somewhere, either here or in Int8DynamicActivationInt4WeightConfig's docstring, because users may be confused about this.

Contributor Author

It does match numerics exactly when weight_dtype = torch.int4

weight_granularity: Union[PerRow, PerGroup] = PerGroup(32)
weight_zero_point_domain: ZeroPointDomain = ZeroPointDomain.NONE
weight_mapping_type: MappingType = MappingType.SYMMETRIC
weight_scale_dtype: Optional[torch.dtype] = None
Contributor

I notice there's no weight_zero_point_dtype here. Is this assuming weight will always be symmetric? FWIW in the corresponding QAT FakeQuantizeConfig we do have zero_point_precision as well as scale_precision

Contributor Author

Asymmetric is supported. weight_zero_point_dtype is set to torch.int8 when weight_zero_point_domain=ZeroPointDomain.INT, and to None when weight_zero_point_domain=ZeroPointDomain.NONE.
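A minimal sketch of that rule (illustrative, not the exact source):

if weight_zero_point_domain == ZeroPointDomain.INT:
    weight_zero_point_dtype = torch.int8
elif weight_zero_point_domain == ZeroPointDomain.NONE:
    weight_zero_point_dtype = None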

weight_mapping_type: MappingType = MappingType.SYMMETRIC
weight_scale_dtype: Optional[torch.dtype] = None
act_mapping_type: MappingType = MappingType.ASYMMETRIC
layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
Contributor

What happens if we use this layout for cuda or other non-CPU backends? Are the numerics the same / is it still optimized? Does this also work with PlainLayout? Just wondering if this is the right default

Contributor Author

It works with QDQLayout, which subclasses PlainLayout(), but explicitly defines the linear impl.

Today this is done via a fallback path with PlainLayout() that @jerryzh168 mentioned might be removed.

We could make QDQLayout the default. PackedLinearInt8DynamicActivationIntxWeightLayout only works on CPU.
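A hedged sketch of the device-based choice described above (the device check is illustrative only):

# PackedLinearInt8DynamicActivationIntxWeightLayout only has CPU kernels;
# QDQLayout (a PlainLayout subclass with an explicit linear impl) works elsewhere.
layout = (
    PackedLinearInt8DynamicActivationIntxWeightLayout()
    if device == "cpu"
    else QDQLayout()
)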

preserve_zero=has_weight_zeros
or (weight_mapping_type == MappingType.SYMMETRIC),
zero_point_domain=weight_zero_point_domain,
_layout=QDQLayout(),
Contributor

Do we just ignore config.layout here? Or does that refer to activation layout only?

Contributor Author

See the comment on line 775: QDQLayout is used for the quantization algorithm.

Packing for the PackedLinearInt8DynamicActivationIntxWeightLayout layout is handled in the block at line 804.

@metascroy
Contributor Author

@andrewor14 @jerryzh168 any more concerns on this PR?



class _AffineQuantizedTensor(AffineQuantizedTensor):
def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
Contributor

@jerryzh168 jerryzh168 Apr 8, 2025

not for this PR, but would it make sense to have a separate tensor subclass for this layout that inherits from AQT?

Contributor Author

Perhaps that would be OK, but I'm wondering how that intersects with the quantize_ config (which has layout). Note that this layout does not override anything in AQT, so it doesn't need to be a separate class.

I just wanted a way for users to construct this AQT from plain data, which is kind of useful if people want to use them outside of the quantize_ API. There is no easy, packaged up way for users to do that with AQT right now.

Contributor

> Perhaps that would be OK, but I'm wondering how that intersects with the quantize_ config (which has layout).

you can continue to use layout as an abstraction if you feel that is useful

> Note that this layout does not override anything in AQT, so it doesn't need to be a separate class.

OK, then it seems fine to keep using AQT

> I just wanted a way for users to construct this AQT from plain data, which is kind of useful if people want to use them outside of the quantize_ API. There is no easy, packaged up way for users to do that with AQT right now.

it seems that using the default AQT constructor is fine for now? or do you feel we should add another constructor function?

Contributor Author

> or do you feel we should add another constructor function

let's wait and see if others want that functionality. The current way to construct an AQT from plain data would be something like:

layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target)
tensor_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
        int_data, scale, zero_point, layout, bias
)
aqt = AffineQuantizedTensor(
    tensor_impl,
    block_size=(1, group_size),
    shape=int_data.shape,
    quant_min=qmin,
    quant_max=qmax,
    zero_point_domain=ZeroPointDomain.INT
    if has_weight_zeros
    else ZeroPointDomain.NONE,
)

which isn't the most intuitive IMO.

Contributor

this actually looks reasonable to me, similar to torch.Tensor, where people can construct a TensorImpl and pass it to the Tensor constructor

Comment on lines 32 to 33
for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]:
    for has_weight_zeros in [True, False]:
        for has_bias in [True, False]:
            idx = len(layers)
            layer_to_weight_dtype[idx] = weight_dtype
            layer_to_has_weight_zeros[idx] = has_weight_zeros
            layers.append(torch.nn.Linear(64, 64, bias=has_bias))
activations = torch.randn(2, 1, 64, dtype=torch.float32)
for weight_mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]:
Contributor

you can use

for algo, layer_size, input_shape, high_precision_dtype in itertools.product(
to reduce indentation
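Applied to the hunk above, the suggested refactor would look roughly like this (a sketch, not code from the PR):

import itertools

for weight_dtype, has_weight_zeros, has_bias in itertools.product(
    [getattr(torch, f"int{i}") for i in range(1, 9)],
    [True, False],
    [True, False],
):
    idx = len(layers)
    layer_to_weight_dtype[idx] = weight_dtype
    layer_to_has_weight_zeros[idx] = has_weight_zeros
    layers.append(torch.nn.Linear(64, 64, bias=has_bias))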

Contributor

@jerryzh168 jerryzh168 left a comment

LGTM

@jerryzh168 jerryzh168 added the topic: new feature label Apr 8, 2025
@jerryzh168
Contributor

also labeling this as a new feature; it might be helpful to write down how to use the API in the summary so we can copy-paste it for the release notes

@metascroy metascroy merged commit 04d1186 into main Apr 8, 2025
18 of 19 checks passed
@andrewor14
Contributor

Looks great! @metascroy can you add a unit test in a separate PR showing this new config matches Int8DynamicActivationInt4WeightConfig numerics exactly if weight_dtype = torch.int4?
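A hedged sketch of such a test; the import locations and the exact Int8DynamicActivationInt4WeightConfig arguments are assumptions and may need adjusting:

import copy
import torch
from torchao.quantization import (
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    PerGroup,
    ZeroPointDomain,
    quantize_,
)
from torchao.dtypes import QDQLayout  # import location assumed

def test_intx_matches_int4_numerics():
    m_int4 = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()
    m_intx = copy.deepcopy(m_int4)
    x = torch.randn(2, 256)

    quantize_(m_int4, Int8DynamicActivationInt4WeightConfig(group_size=32))
    quantize_(
        m_intx,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.ASYMMETRIC,
            weight_zero_point_domain=ZeroPointDomain.NONE,
            layout=QDQLayout(),
        ),
    )
    # the PR states numerics match exactly for weight_dtype=torch.int4
    torch.testing.assert_close(m_intx(x), m_int4(x), rtol=0, atol=0)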

liangel-02 pushed a commit that referenced this pull request Aug 25, 2025
