
[scaled grouped mm] integrate triton kernels into differentiable scaled grouped mm #2077


Merged
9 commits merged into main on Apr 22, 2025

Conversation

@danielvegamyhre (Contributor) commented on Apr 18, 2025

Prior PR in stack: #2064

Summary

Performance

TL;DR: the Triton kernels give a ~1.25x-37x speedup over the CPU-loop baseline (for most shapes the speedup fell between 2x and 6x).
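
For context, the "CPU loop" baseline below corresponds to computing the per-group dynamic fp8 scales with a host-side loop over groups, which the fused Triton kernels replace. A simplified sketch of that looped pattern, purely for illustration (the group layout, scale formula, and helper name here are assumptions, not the exact torchao code):

```python
import torch

# Illustrative sketch of the looped baseline pattern (assumed, simplified).
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def per_group_rowwise_fp8_quantize_loop(A: torch.Tensor, offs: torch.Tensor):
    """Quantize each group's rows of A to fp8 with rowwise dynamic scales,
    one group at a time from the host (the pattern the Triton kernels fuse)."""
    quantized, scales = [], []
    start = 0
    for end in offs.tolist():  # host-side loop: one chain of small kernels per group
        group = A[start:end]
        amax = group.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
        scale = (FP8_MAX / amax).to(torch.float32)
        quantized.append((group * scale).to(torch.float8_e4m3fn))
        scales.append(scale)
        start = end
    return torch.cat(quantized), torch.cat(scales)
```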

CPU loop:

A_shape        B_shape           high_precision_dtype      time_us
-------------  ----------------  ----------------------  ---------
(256, 4096)    (4, 4096, 4096)   torch.bfloat16            4324.35
(256, 4096)    (8, 4096, 4096)   torch.bfloat16            8197.2
(256, 4096)    (16, 4096, 4096)  torch.bfloat16           15830.9
(4096, 4096)   (4, 4096, 4096)   torch.bfloat16            5211.11
(4096, 4096)   (8, 4096, 4096)   torch.bfloat16            9003.3
(4096, 4096)   (16, 4096, 4096)  torch.bfloat16           16720.7
(65536, 4096)  (4, 4096, 4096)   torch.bfloat16           31257
(65536, 4096)  (8, 4096, 4096)   torch.bfloat16           34253.2
(65536, 4096)  (16, 4096, 4096)  torch.bfloat16           40141.3

Triton kernels:

A_shape        B_shape           high_precision_dtype      time_us
-------------  ----------------  ----------------------  ---------
(256, 4096)    (4, 4096, 4096)   torch.bfloat16            835.657
(256, 4096)    (8, 4096, 4096)   torch.bfloat16            835.657
(256, 4096)    (16, 4096, 4096)  torch.bfloat16            832.382
(4096, 4096)   (4, 4096, 4096)   torch.bfloat16            830.429
(4096, 4096)   (8, 4096, 4096)   torch.bfloat16           4666.3
(4096, 4096)   (16, 4096, 4096)  torch.bfloat16           7502.13
(65536, 4096)  (4, 4096, 4096)   torch.bfloat16            840.335
(65536, 4096)  (8, 4096, 4096)   torch.bfloat16          26359.4
(65536, 4096)  (16, 4096, 4096)  torch.bfloat16          27910.4
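
For reference, a minimal sketch of how a comparison like the one above could be driven (this is not the PR's actual bench script; `loop_impl` and `triton_impl` are hypothetical stand-ins for the two implementations, and the even row split per group via end offsets is an assumption):

```python
import torch
from triton.testing import do_bench  # times a callable with CUDA events, returns ms

def bench_us(fn) -> float:
    """Median runtime of fn() in microseconds."""
    return do_bench(fn) * 1e3

def run_case(loop_impl, triton_impl, M, K, G, N, dtype=torch.bfloat16):
    # Shapes mirror the tables above: A is (M, K), B is (G, K, N).
    A = torch.randn(M, K, device="cuda", dtype=dtype)
    B = torch.randn(G, K, N, device="cuda", dtype=dtype)
    # End offset of each group's rows in A (even split across G groups here).
    offs = torch.arange(M // G, M + 1, M // G, device="cuda", dtype=torch.int32)

    t_loop = bench_us(lambda: loop_impl(A, B, offs))
    t_triton = bench_us(lambda: triton_impl(A, B, offs))
    print(f"A=({M},{K}) B=({G},{K},{N}) loop={t_loop:.1f}us "
          f"triton={t_triton:.1f}us speedup={t_loop / t_triton:.2f}x")
```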

pytorch-bot (bot) commented on Apr 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2077

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit 6821e44 with merge base 9af2a45:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 18, 2025
@danielvegamyhre added the topic: improvement and topic: performance labels and removed the CLA Signed label on Apr 18, 2025
@facebook-github-bot added the CLA Signed label on Apr 18, 2025
@danielvegamyhre requested a review from drisspg on April 18, 2025 at 16:20
lint

update docstrings

add bench script

add bench script

bench against compile

comment

clean up

fix masks

lint

integrate triton kernels into scaled grouped mm

lint
@danielvegamyhre (Contributor, Author) commented:

Dr. CI confirmed the test failures are unrelated to this change (I manually confirmed as well; they are QAT-related).

@danielvegamyhre merged commit b8206d7 into main on Apr 22, 2025
12 of 18 checks passed
Labels: CLA Signed, topic: improvement, topic: performance
3 participants