
skip quant/dequant decomposed #2299

Closed
shiyang-weng wants to merge 20 commits

Conversation

shiyang-weng
Contributor

@shiyang-weng shiyang-weng commented Jun 4, 2025

Fix #2228

What we want to do now is enable FP8 quantization in PyTorch. As with INT8 quantization, we need to insert quantize and dequantize ops into the graph.
However, we ran into problems with these q/dq ops in both PyTorch core and Torchao.

PyTorch core:

The quantize_per_tensor op does not support FP8. We wanted to fix that via pytorch/pytorch#153601, but as you commented, the op is deprecated.
Torchao:

In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it as fp8_weight -> weight_pack -> fp8_op. We have done so for INT8 PT2E quantization. However, the pattern matching pass is applied after a constant folding pass in Inductor:
https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1
After constant_fold(gm), the pattern is folded into fp32_weight -> fp32_op. The original pattern can no longer be found, and the FP8 semantics are lost because the pattern is now entirely in fp32.
For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern is not folded because we mark quantized_decomposed.dequantize_per_channel as impure, so the constant folder skips it (a rough sketch of that check follows this list): https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1 . But we cannot do the same for torchao.dequantize_affine_float8 because:
It is an op from Torchao, which is unknown to the constant folder.
It is decomposed into smaller ops, so we cannot add it to that list as a single op.
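For context, the INT8 behavior relies on a check along these lines. This is only a minimal sketch of the idea, assuming a simplified is_impure hook; the real logic lives in torch/_inductor/constant_folding.py and differs in detail:

    import torch
    import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

    # Ops the folder treats as impure so that the int8_weight -> dequant -> fp32_op
    # chain is not evaluated into an fp32 constant. (Illustrative list only.)
    _DONT_FOLD = {
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    }

    def is_impure(node: torch.fx.Node) -> bool:
        # An "impure" node is skipped by constant folding, which keeps the
        # quantized weight and its dequant op visible to later pattern matching.
        return node.op == "call_function" and node.target in _DONT_FOLD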

To address this, this PR makes the following changes:
  1. Align the dispatch_key with PyTorch core so that quantize_affine and dequantize_affine are no longer decomposed.
  2. Register meta functions for the q/dq ops (a rough sketch follows this list).
  3. Register dequantize for uintx. Previously, dequant was decomposed into smaller ops, so no explicit registration was needed.
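For item 2, a rough sketch of what a meta registration can look like with torch.library, assuming quantize_affine has already been defined under the torchao namespace; the parameter list here is approximate and not the final code in this PR:

    import torch
    from torch.library import Library, impl

    # Assumes torch.ops.torchao.quantize_affine was already defined via lib.define(...).
    _lib = Library("torchao", "FRAGMENT")

    @impl(_lib, "quantize_affine", "Meta")
    def _quantize_affine_meta(input, block_size, scale, zero_point, output_dtype,
                              quant_min=None, quant_max=None):
        # A meta kernel only computes the output shape/dtype, so the op can be
        # traced (e.g. by torch.export / Inductor) without running the real kernel.
        return torch.empty_like(input, dtype=output_dtype)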

pytorch-bot bot commented Jun 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2299

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Cancelled Job

As of commit 0eedbf3 with merge base 1239842:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

Hi @shiyang-weng!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@shiyang-weng shiyang-weng marked this pull request as draft June 4, 2025 08:27


_quantize_affine, _quantize_affine_meta = register_custom_op_with_meta(
    _quantize_affine, _quantize_affine_meta
)
Collaborator

Can we separate the registrations for meta and non-meta device and also use them as decorators instead of calling it? Then it would be better aligned with the rest of the file.

Contributor Author

Can we separate the registrations for meta and non-meta device and also use them as decorators instead of calling it? Then it would be better aligned with the rest of the file.

Done

@shiyang-weng shiyang-weng changed the title from "[WIP] skip quant/dequant decomposed" to "skip quant/dequant decomposed" Jun 9, 2025
@shiyang-weng shiyang-weng marked this pull request as ready for review June 9, 2025 08:14

  lib_namespace = lib.ns
  op = getattr(getattr(torch.ops, lib_namespace), op_name)
- register_decomposition([op])(fn)
+ if dispatch_key == "CompositeImplicitAutograd":
+     register_decomposition([op])(fn)
Collaborator

I feel that adding a flag in _register_custom_op to indicate whether this op should be decomposed would be clearer. cc @jerryzh168, what are your suggestions?

Contributor

yeah I feel adding a flag might be easier here
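One possible shape of that flag-based helper, sketched with a toy op under a made-up "_demo" namespace; the real torchao helper wraps more logic, so the names and details here are assumptions:

    import torch
    from torch._decomp import register_decomposition
    from torch.library import Library

    _lib = Library("_demo", "DEF")  # hypothetical namespace for illustration
    _lib.define("mul2(Tensor x) -> Tensor")

    def _register_custom_op(lib, op_name, decompose=True):
        def decorator(fn):
            lib.impl(op_name, fn, "CompositeImplicitAutograd")
            op = getattr(getattr(torch.ops, lib.ns), op_name)
            if decompose:
                # Only register the decomposition when the caller wants the op
                # broken into smaller ops; otherwise it stays a single graph node.
                register_decomposition([op])(fn)
            return fn
        return decorator

    @_register_custom_op(_lib, "mul2", decompose=False)
    def _mul2(x):
        return x * 2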

@implements(torch.ops.torchao.dequantize_affine)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func.default,  # work around
Contributor Author

@shiyang-weng shiyang-weng Jun 12, 2025

A better way may be:

@implements(torch.ops.torchao.dequantize_affine.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, 
        args,
        kwargs,
        _dequantize_affine_impl(*args),
    )

But then the following issue occurs: NotImplementedError: UintxTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.to', overload='dtype')>, types=(<class 'torchao.dtypes.uintx.uintx_layout.UintxTensor'>,)

Contributor

are you using uintx layout for float8? why is this changed?

Contributor Author

are you using uintx layout for float8? why is this changed?

dequantize_affine/quantize_affine are also used in uintx.
Without this patch, dequant was decomposed into smaller ops, so uintx was fine; now it is not decomposed, so the op must be handled via __torch_function__ dispatch.
So I use _implements to add it.

Contributor

But there will be following issue. NotImplementedError: UintxTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.to', overload='dtype')>, types=(<class 'torchao.dtypes.uintx.uintx_layout.UintxTensor'>,)

For this, you could implement to() in UintxTensor as well, I think.
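A rough sketch of that suggestion, assuming the module-level implements alias and the get_plain() unpacking helper from torchao/dtypes/uintx/uintx_layout.py; the exact body is illustrative only:

    import torch
    # Assumption: this alias exists in the uintx layout module.
    from torchao.dtypes.uintx.uintx_layout import implements

    @implements(torch.ops.aten.to.dtype)
    def _(func, types, args, kwargs):
        tensor, dtype = args[0], args[1]
        # Unpack the packed uintx data to a plain tensor, then convert dtype, so the
        # aten.to call made inside the now-undecomposed dequantize_affine dispatches cleanly.
        return tensor.get_plain().to(dtype)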

@shiyang-weng
Contributor Author

Even with this PR removed, this accuracy error still occurs, so it may not be caused by this patch:

$ conda list |grep torch
pytorch-triton            3.3.1+gitc8757738          pypi_0    pypi
torch                     2.8.0.dev20250608+cu126          pypi_0    pypi
torchao                   0.12.0+git3aa9361           dev_0    <develop>
torchaudio                2.8.0.dev20250609+cu126          pypi_0    pypi
torchvision               0.23.0.dev20250609+cu126          pypi_0    pypi
$ git log -1
commit 3aa93619466739c9d9845e1db3bfb2ff0f464857 (HEAD, origin/main, origin/HEAD, main)
$ rm -rf /tmp/torchinductor_pt-gpu/ && /data/wengshiy/conda/envs/ao/bin/pytest -s test/integration/test_integration.py::TestSubclass
       if q.is_Number and p.is_Number:
>           assert p >= 0, p
E           torch._inductor.exc.InductorError: AssertionError: -420769046757011/1000000000000000
E           
../conda/envs/ao/lib/python3.10/site-packages/torch/utils/_sympy/functions.py:488: InductorError

@shiyang-weng
Contributor Author

Could you help review this PR? @jerryzh168
The CI errors look unrelated.

Also, I need to fill out the CLA. Do you know who I should put as the "point of contact"?

Comment on lines +876 to +877
quant_min=torch.finfo(float8_dtype).min,
quant_max=torch.finfo(float8_dtype).max,
Contributor

@jerryzh168 jerryzh168 Jun 13, 2025

I remember now we have quantize_affine_float8?

Contributor Author

Without this patch, dequantize_affine/quantize_affine are decomposed for both int8 and fp8.
If you think it is better to handle the fp8 operators separately, I will create another PR to register them separately.

Contributor

we should use quantize_affine_float8 / dequantize_affine_float8 for float8 I think
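For reference, the float8 quantization these ops encode is roughly the following; this is an illustrative standalone helper, not torchao's actual quantize_affine_float8 signature:

    import torch

    def fp8_quantize(x: torch.Tensor, scale: torch.Tensor,
                     float8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
        # Scale, clamp to the representable float8 range (the quant_min/quant_max
        # taken from torch.finfo above), then cast to the float8 dtype.
        finfo = torch.finfo(float8_dtype)
        return (x / scale).clamp(finfo.min, finfo.max).to(float8_dtype)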

@jerryzh168
Contributor

for CLA I think you can just follow the process in the CI:

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
      return n + k - (n % k)


- def _register_custom_op(lib):
+ def _register_custom_op(lib, dispatch_key="CompositeImplicitAutograd"):
Contributor

Are you planning to change the dispatch key argument into a flag for whether to decompose or not?

@facebook-github-bot facebook-github-bot added the "CLA Signed" label Jun 19, 2025
@shiyang-weng
Contributor Author

Fixed in #2379

Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Successfully merging this pull request may close these issues.

[Quant] Can quant not be decomposed on inductor?
5 participants