- 
                Notifications
    You must be signed in to change notification settings 
- Fork 25.7k
Fix Triton GEMM templates with k=1 #158650
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
| 🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158650
 Note: Links to docs will display an error until the docs builds have been completed. ⏳ 79 Pending, 2 Unrelated FailuresAs of commit 1fdaf57 with merge base 036eb1f ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures 
 
 UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
 
 This comment was automatically generated by Dr. CI and updates every 15 minutes. | 
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Thanks to davidberard98 for much of the analysis here. For GEMMs of K=1, the hints, `tl.multiple_of` and `tl.max_contiguous` apply completely, as the indices to the loads are only dependent on `offs_m` and `offs_n`. For shapes like `(97x1), (1x97)`, this results in misaligned address errors, due to the fact that for all BLOCK_M and BLOCK_N sizes, the last tile is not a contiguous load. With K > 1 case, the hint is not as strict given the dependency on the k indices for the load as well. In the K=1 case, only `offs_m` and `offs_n` are used and broadcasted to the index shape. For nice shapes with K=1, where M, N are a multiple 8 to where these hints are fine and there is no misaligned address, there is no performance regression observed on H100: <img width="547" height="402" alt="Screenshot 2025-07-18 at 5 05 47 PM" src="https://github.com/user-attachments/assets/fee2bbaa-784c-422e-bb8c-43c6c2607ad2" /> cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Thanks to davidberard98 for much of the analysis here. For GEMMs of K=1, the hints, `tl.multiple_of` and `tl.max_contiguous` apply completely, as the indices to the loads are only dependent on `offs_m` and `offs_n`. For shapes like `(97x1), (1x97)`, this results in misaligned address errors, due to the fact that for all BLOCK_M and BLOCK_N sizes, the last tile is not a contiguous load. With K > 1 case, the hint is not as strict given the dependency on the k indices for the load as well. In the K=1 case, only `offs_m` and `offs_n` are used and broadcasted to the index shape. One can say these hints are "wrong", but in various cases in the hints being wrong, such as with the shape `9999x4, 4x9999`, there is a substantial performance improvement with the hint. For nice shapes with K=1, where M, N are a multiple 8 to where these hints are fine and there is no misaligned address, there is no performance regression observed on H100: <img width="547" height="402" alt="Screenshot 2025-07-18 at 5 05 47 PM" src="https://github.com/user-attachments/assets/fee2bbaa-784c-422e-bb8c-43c6c2607ad2" /> cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Thanks to davidberard98 for much of the analysis here. For GEMMs of K=1, the hints, `tl.multiple_of` and `tl.max_contiguous` apply completely, as the indices to the loads are only dependent on `offs_m` and `offs_n`. For shapes like `(97x1), (1x97)`, this results in misaligned address errors, due to the fact that for all BLOCK_M and BLOCK_N sizes, the last tile is not a contiguous load. With K > 1 case, the hint is not as strict given the dependency on the k indices for the load as well. In the K=1 case, only `offs_m` and `offs_n` are used and broadcasted to the index shape. One can say these hints are "wrong", but in various cases in the hints being wrong, such as with the shape `9999x4, 4x9999`, there is a substantial performance improvement with the hint. For nice shapes with K=1, where M, N are a multiple 8 to where these hints are fine and there is no misaligned address, there is no performance regression observed on H100: <img width="547" height="402" alt="Screenshot 2025-07-18 at 5 05 47 PM" src="https://github.com/user-attachments/assets/fee2bbaa-784c-422e-bb8c-43c6c2607ad2" /> cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
| rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | ||
| rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) | ||
| if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M: | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you mind clarifying For GEMMs of K=1, the hints, tl.multiple_of and tl.max_contiguous apply completely, as the indices to the loads are only dependent on offs_m and offs_n  - if they apply completely, why are we skipping the check? or maybe i'm misunderstanding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For K != 1, the inductor-generated code in the k loop looks like this:
        a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K)
        b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K)
        a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K)
        b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K)
        idx_m = offs_a_m[:, None]
        idx_n = a_k_idx_vals
        xindex = idx_n + 3*idx_m
        a = tl.load(A + (xindex), mask=a_mask, other=0.0)
        idx_m = b_k_idx_vals
        idx_n = offs_b_n[None, :]
        xindex = idx_n + 97*idx_m
        b = tl.load(B + (xindex), mask=b_mask, other=0.0)
So if we're loading a block of data [BLOCK_K, BLOCK_N], then although we assume contiguity/divisibility of indices[:, 0] (due to the hints), we can't assume divisibility of indices[:, 1] due to the addition of the idx_m * 97, where 97 has no divisibility properties - therefore no vectorization is performed.
But for K == 1, the inductor generated code looks like this
        xindex = idx_m + idx_n
        a = tl.load(A + ((tl.broadcast_to(idx_m, xindex.shape)).broadcast_to(xindex.shape)), mask=a_mask, other=0.0)
        idx_m = b_k_idx_vals
        idx_n = offs_b_n[None, :]
        xindex = idx_n + 97*idx_m
        b = tl.load(B + ((tl.broadcast_to(idx_n, xindex.shape)).broadcast_to(xindex.shape)), mask=b_mask, other=0.0)Here, we explicitly remove the indexing in the k dimension (as we should, since there's no point in loading a BLOCK_K x BLOCK_M and masking out (BLOCK_K-1) rows, if K = 1). As a result, there's nothing hinting that vectorization is not possible, so Triton vectorizes the loads.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: the real problem is that the divisibility and contiguity hints are wrong. But they do provide perf improvement (I don't fully understand why) and usually don't cause incorrect behavior or crashes.
So this PR is helpful by preventing many of the crashes while preserving perf on the cases where the hints help.
| @pytorchbot merge | 
| Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team | 
| @pytorchbot revert -c ghfirst -m "Sorry but this is breaking internally, see D78805560 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts" | 
| @pytorchbot successfully started a revert job. Check the current status here. | 
This reverts commit 9df0f56. Reverted #158650 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, see D78805560 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](#158650 (comment)))
| @PaulZhang12 your PR has been successfully reverted. | 
| @pytorchbot merge -f "Turns out the internal signal was flaky" | 
| Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).  Please use  Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team | 
Stack from ghstack (oldest at bottom):
Thanks to @davidberard98 for much of the analysis here. For GEMMs of K=1, the hints,
tl.multiple_ofandtl.max_contiguousapply completely, as the indices to the loads are only dependent onoffs_mandoffs_n. For shapes like(97x1), (1x97), this results in misaligned address errors, due to the fact that for all BLOCK_M and BLOCK_N sizes, the last tile is not a contiguous load. With K > 1 case, the hint is not as strict given the dependency on the k indices for the load as well. In the K=1 case, onlyoffs_mandoffs_nare used and broadcasted to the index shape.One can say these hints are "wrong", but in various cases in the hints being wrong, such as with the shape
9999x4, 4x9999, there is a substantial performance improvement with the hint.For nice shapes with K=1, where M, N are a multiple 8 to where these hints are fine and there is no misaligned address, there is no performance regression observed on H100:

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben