Move Int8DynamicActivationIntxWeightConfig out of experimental #1968
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1968
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 0763e90 with merge base 9516764.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@dataclass
class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
    """
@andrewor14 can you have a look at this comment and check whether there are any issues with it working well with the QAT workflow with FakeQuantizeConfig?
andrewor14 left a comment:
Thanks @metascroy, looks great overall. I pointed out some fields that appear to be different/missing in the comments but I think the new config will work well with QAT. Either way we'll probably need an end-to-end QAT test to confirm that prepare vs convert numerics match exactly (can be future PR). Also left some questions about the new layout.
torchao/quantization/quant_api.py (Outdated)
| """ | ||
|
|
||
| weight_dtype: torch.dtype = torch.int8 | ||
| weight_granularity: Union[PerRow, PerGroup] = PerGroup(32) |
I feel we should just make the type here Granularity and throw an error for unsupported types, so we don't tie ourselves to specific granularities in the signature itself.
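A minimal sketch of what that could look like (hypothetical class name, for illustration only; not the PR's code):

```python
from dataclasses import dataclass, field

from torchao.quantization.granularity import Granularity, PerAxis, PerGroup


@dataclass
class _GranularityCheckedConfig:  # hypothetical, only illustrates the suggestion
    # Accept the general Granularity type in the signature...
    weight_granularity: Granularity = field(default_factory=lambda: PerGroup(32))

    def __post_init__(self):
        # ...and reject unsupported granularities at construction time.
        if not isinstance(self.weight_granularity, (PerGroup, PerAxis)):
            raise ValueError(
                f"Unsupported weight_granularity: {type(self.weight_granularity)}"
            )
```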
if isinstance(weight_granularity, PerGroup):
    group_size = weight_granularity.group_size
elif isinstance(weight_granularity, PerRow):
    group_size = weight.shape[-1]
As discussed offline, this seems more like per-channel to me, which is expressed in terms of PerAxis. PerRow seems like an unrelated float8 thing according to the docstring of class PerRow(Granularity) at ao/torchao/quantization/granularity.py, line 74 (commit 5802d2d).
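For reference, per-channel weight granularity would be spelled with PerAxis; a sketch of the usage, assuming the granularity classes in the file linked above:

```python
from torchao.quantization.granularity import PerAxis

# One scale (and zero point) per output channel of a [out_features, in_features]
# Linear weight, i.e. per row of the weight matrix.
weight_granularity = PerAxis(axis=0)
```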
weight_granularity = config.weight_granularity
weight_zero_point_domain = config.weight_zero_point_domain
weight_mapping_type = config.weight_mapping_type
weight_scale_dtype = config.weight_scale_dtype
Can you add a TODO for Int8DynamicActivationInt4WeightConfig to add scale dtype there as well?
@dataclass
class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
Does this match the numerics of Int8DynamicActivationInt4WeightConfig exactly if we choose weight_dtype = torch.int4? Is that a goal of this?
Whether or not this is the goal, maybe we should document it somewhere, either here or in Int8DynamicActivationInt4WeightConfig's docstring, because users may be confused about this.
It does match numerics exactly when weight_dtype = torch.int4
    weight_granularity: Union[PerRow, PerGroup] = PerGroup(32)
    weight_zero_point_domain: ZeroPointDomain = ZeroPointDomain.NONE
    weight_mapping_type: MappingType = MappingType.SYMMETRIC
    weight_scale_dtype: Optional[torch.dtype] = None
I notice there's no weight_zero_point_dtype here. Is this assuming weight will always be symmetric? FWIW in the corresponding QAT FakeQuantizeConfig we do have zero_point_precision as well as scale_precision
Asymmetric is supported. weight_zero_point_dtype is set to torch.int8 if weight_zero_point_domain=ZeroPointDomain.INT, else it is None if weight_zero_point_domain=ZeroPointDomain.NONE.
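In other words (a small sketch of the behavior described above, not the torchao source; the helper name is made up):

```python
from typing import Optional

import torch
from torchao.quantization.quant_primitives import ZeroPointDomain


def _weight_zero_point_dtype(domain: ZeroPointDomain) -> Optional[torch.dtype]:
    # INT zero points are stored as int8; NONE means no zero point is kept.
    return torch.int8 if domain == ZeroPointDomain.INT else None
```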
torchao/quantization/quant_api.py (Outdated)
    weight_mapping_type: MappingType = MappingType.SYMMETRIC
    weight_scale_dtype: Optional[torch.dtype] = None
    act_mapping_type: MappingType = MappingType.ASYMMETRIC
    layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
What happens if we use this layout for CUDA or other non-CPU backends? Are the numerics the same / is it still optimized? Does this also work with PlainLayout? Just wondering if this is the right default.
It works with QDQLayout, which subclasses PlainLayout(), but explicitly defines the linear impl.
Today this is done via a fallback path with PlainLayout() that @jerryzh168 mentioned might be removed.
We could make QDQLayout the default. PackedLinearInt8DynamicActivationIntxWeightLayout only works on CPU.
preserve_zero=has_weight_zeros
or (weight_mapping_type == MappingType.SYMMETRIC),
zero_point_domain=weight_zero_point_domain,
_layout=QDQLayout(),
Do we just ignore config.layout here? Or does that refer to activation layout only?
See the comment on line 775. QDQLayout is used for the quantization algorithm.
The packing for the PackedLinearInt8DynamicActivationIntxWeightLayout layout is handled in the block at line 804.
@andrewor14 @jerryzh168 any more concerns on this PR?
class _AffineQuantizedTensor(AffineQuantizedTensor):
def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
Not for this PR, but would it make sense to have a separate tensor subclass for this layout that inherits from AQT?
Perhaps that would be OK, but I'm wondering how that intersects with the quantize_ config (which has layout). Note that this layout does not override anything in AQT, so it doesn't need to be a separate class.
I just wanted a way for users to construct this AQT from plain data, which is kind of useful if people want to use it outside of the quantize_ API. There is no easy, packaged-up way for users to do that with AQT right now.
> Perhaps that would be OK, but I'm wondering how that intersects with the quantize_ config (which has layout).

You can continue to use layout as an abstraction if you feel that is useful.

> Note that this layout does not override anything in AQT, so it doesn't need to be a separate class.

OK, then it seems fine to keep using AQT.

> I just wanted a way for users to construct this AQT from plain data, which is kind of useful if people want to use it outside of the quantize_ API. There is no easy, packaged-up way for users to do that with AQT right now.

It seems that using the default AQT constructor is fine for now? Or do you feel we should add another constructor function?
> or do you feel we should add another constructor function

Let's wait and see if others want that functionality. The current way to construct an AQT from plain data would be something like:
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target)
tensor_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
    int_data, scale, zero_point, layout, bias
)
aqt = AffineQuantizedTensor(
    tensor_impl,
    block_size=(1, group_size),
    shape=int_data.shape,
    quant_min=qmin,
    quant_max=qmax,
    zero_point_domain=ZeroPointDomain.INT
    if has_weight_zeros
    else ZeroPointDomain.NONE,
)
which isn't the most intuitive IMO.
This actually looks reasonable to me, similar to torch.Tensor, where people can construct a TensorImpl and pass it to the Tensor constructor.
for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]:
    for has_weight_zeros in [True, False]:
        for has_bias in [True, False]:
            idx = len(layers)
            layer_to_weight_dtype[idx] = weight_dtype
            layer_to_has_weight_zeros[idx] = has_weight_zeros
            layers.append(torch.nn.Linear(64, 64, bias=has_bias))
activations = torch.randn(2, 1, 64, dtype=torch.float32)
for weight_mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]:
You can use itertools.product here, as in this existing test line:
for algo, layer_size, input_shape, high_precision_dtype in itertools.product(
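For illustration, the nested loops in the snippet above could collapse to something like this (a sketch reusing the names from that snippet; the surrounding test setup is assumed):

```python
import itertools

import torch

# Assumed setup from the test snippet above.
layers = []
layer_to_weight_dtype = {}
layer_to_has_weight_zeros = {}

int_dtypes = [getattr(torch, f"int{i}") for i in range(1, 9)]
for weight_dtype, has_weight_zeros, has_bias in itertools.product(
    int_dtypes, [True, False], [True, False]
):
    idx = len(layers)
    layer_to_weight_dtype[idx] = weight_dtype
    layer_to_has_weight_zeros[idx] = has_weight_zeros
    layers.append(torch.nn.Linear(64, 64, bias=has_bias))
```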
jerryzh168 left a comment:
LGTM
Also labeling this as a new feature; it might be helpful to write down how to use the API in the summary so we can copy-paste it for release notes.
Looks great! @metascroy can you add a unit test in a separate PR showing this new config matches Int8DynamicActivationInt4WeightConfig numerics exactly if weight_dtype = torch.int4?
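A hedged sketch of what such a test might look like (constructor arguments and import paths are assumptions based on this PR's diff and the existing config; per the PR description, the exact match is expected with weight_dtype=torch.int4 and layout=QDQLayout()):

```python
import copy

import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)


def test_intx_matches_int4_numerics():
    torch.manual_seed(0)
    m_int4 = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
    m_intx = copy.deepcopy(m_int4)
    x = torch.randn(2, 64)

    quantize_(m_int4, Int8DynamicActivationInt4WeightConfig(group_size=32))
    # Field names follow the diff above; layout=QDQLayout() may also be needed
    # for an exact match, per the PR description.
    quantize_(m_intx, Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    ))

    torch.testing.assert_close(m_int4(x), m_intx(x), rtol=0, atol=0)
```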
This PR moves Int8DynamicActivationIntxWeightConfig and its quantizer into torchao.quantization.quant_api (out of experimental). Int8DynamicActivationIntxWeightConfig is refactored to closely mirror Int8DynamicActivationInt4WeightConfig when weight_dtype=torch.int4 and layout=QDQLayout().
Quantization in Int8DynamicActivationIntxWeightConfig is done with QDQLayout, and packing is then done separately with make_packed_linear_int8_dynamic_activation_intx_weight_tensor. This separates the quantization algorithm from the storage. Both the packed layout and QDQLayout quantize with the same algorithm, and this is made explicit.
Example API usage:
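A minimal sketch (field names follow the diff above; import locations and defaults should be checked against the merged quant_api.py):

```python
import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Values here are illustrative; weight_dtype can be any of torch.int1..torch.int8.
config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
quantize_(model, config)

# The linear layers now use int8 dynamic activation / intx weight quantization.
out = model(torch.randn(2, 64))
```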