
Relax dtype requirements for int4 and float8 quants in autoquant #1571


Merged
merged 8 commits on Jan 17, 2025

Conversation

@jerryzh168 (Contributor) commented Jan 16, 2025

Summary:
Some of the int4 and float8 quantization options only work with bfloat16/float16. Previously we required the model to already be in a compatible dtype before applying these in autoquant; this PR relaxes that constraint by converting the weight, bias, and activation to compatible dtypes (a rough sketch of the idea follows the description).

Test Plan:
python test/integration/test_integration.py -k test_autoquant_int4wo

Reviewers:

Subscribers:

Tasks:

Tags:
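To make the idea concrete, here is a minimal sketch of what "converting weight, bias and activation to compatible dtypes" can look like. This is a hypothetical helper, not the PR's actual code path: cast everything to a dtype the int4/float8 kernels support, run the op, then cast the result back to the original activation dtype.

```python
import torch

def dtype_relaxed_linear(act, weight, bias=None, target_dtype=torch.bfloat16):
    """Hypothetical helper illustrating the dtype relaxation: run a linear
    whose quantized kernel only supports bfloat16/float16 on inputs stored
    in some other dtype (e.g. float32)."""
    orig_dtype = act.dtype
    act = act.to(target_dtype)
    weight = weight.to(target_dtype)
    bias = bias.to(target_dtype) if bias is not None else None
    # in autoquant the int4/float8 quantized kernel would be dispatched here;
    # plain F.linear stands in for it in this sketch
    out = torch.nn.functional.linear(act, weight, bias)
    # cast back so the rest of the model sees its original dtype
    return out.to(orig_dtype)
```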


pytorch-bot bot commented Jan 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1571

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit c03760d with merge base e1cb44a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jan 16, 2025
@jerryzh168 (Contributor Author) commented:
@jcaip I can't run the sparse marlin kernel locally: Could not run 'torchao::marlin_24_gemm' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). How do we compile to get this op? Right now I'm using python setup.py develop.

@jerryzh168 added the topic: improvement and autoquant labels on Jan 16, 2025
@@ -227,6 +227,11 @@ def from_plain(
# Linear layers are (in_features, out_features) but the int_data that is reaching this point
# is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
q_w_24 = int_data.t()
# addressing the case when scale has dimension 1, happens when
# weight_shape[-1] == group_size == 128
if scale.ndim == 1:
@jerryzh168 (Contributor Author) commented on the diff:

Also, when one of the dimensions is 1 we do a squeeze, so the scale becomes 1-d; this is a corner case (a small shape sketch follows below).
cc @jcaip @Diogo-V let me know if you have other ideas for the fix; alternatively we could pass block_size around here as well, or maybe store group_size in _layout.
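For reference, a small shape sketch of the corner case (assumed shapes; the actual handling in from_plain may differ):

```python
import torch

out_features, in_features, group_size = 256, 128, 128
# groupwise scale normally has shape (out_features, in_features // group_size)
scale = torch.ones(out_features, in_features // group_size)  # (256, 1)
scale = scale.squeeze()  # squeezing the size-1 dim leaves a 1-d scale: (256,)

# the marlin packing code expects a 2-d scale, so restore the missing dimension
if scale.ndim == 1:
    scale = scale.unsqueeze(-1)  # back to (256, 1)
```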

A contributor replied:

Hey Jerry, I haven't been actively working on this repo lately as I have been more focused on work, so I am not entirely up to date. That said, please keep in mind that my input might not fully reflect the current state of things.

My two cents here is that this should work well, as the condition generalizes to other scenarios that might produce the same corner case in the future, e.g. weight_shape[-1] == group_size == 64 would also yield a scale with dimension 1.

@jcaip (Contributor) commented Jan 16, 2025

> @jcaip I can't run the sparse marlin kernel locally: Could not run 'torchao::marlin_24_gemm' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). How do we compile to get this op? Right now I'm using python setup.py develop.

You need to build with USE_CPP=1 and it should show up
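For example, assuming a source checkout of pytorch/ao, that would mean building with something like USE_CPP=1 python setup.py develop instead of plain python setup.py develop; per the comment above, USE_CPP=1 is what makes custom ops such as torchao::marlin_24_gemm available.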

@jerryzh168 changed the title from "Relax dtype requirements for int4 quants in autoquant" to "Relax dtype requirements for int4 and float8 quants in autoquant" on Jan 16, 2025
@jerryzh168 merged commit cf45336 into pytorch:main on Jan 17, 2025
9 checks passed