
[BACKEND] Fix a special case where elements along the k dimension are repeated within each thread #5121

Merged
merged 21 commits into main from keren/large-kwidth-fix on Nov 14, 2024

Conversation

@Jokeren Jokeren commented Nov 12, 2024

This PR includes the following changes:

  • Adds comprehensive tests for mixed-precision dot products, including configurations such as f8xf16, i8xf16, f8xf32, and i8xf32.
  • Fixes mmav2 when the k dimension contains duplicated elements. For example, with a 16x16 fp16 Triton tensor (opIdx=0, kWidth=4), a 16x32 tile is used, causing the first 16 elements along the k dimension to be repeated in the last 16 elements. During mmav2 computation, only the first half is required.
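To make the duplication concrete, here is a minimal Python sketch (hypothetical names and a simplified one-row model, not Triton's actual lowering): a 16-element k row broadcast into a 32-wide tile wraps around, so the second half of each thread's slice duplicates the first, and only the first half is needed for the mma computation.

```python
# Hypothetical illustration of the k-dimension duplication fixed in this PR.
# A 16-wide k dimension loaded into a 32-wide tile wraps around, so elements
# k and k+16 hold the same value; mmav2 only needs the first half.

TENSOR_K = 16   # k extent of the fp16 tensor
TILE_K = 32     # k extent of the mma tile

def load_tile_row(row_values):
    """Broadcast a 16-element k row into a 32-element tile row (wraps)."""
    assert len(row_values) == TENSOR_K
    return [row_values[k % TENSOR_K] for k in range(TILE_K)]

def dedup_k(tile_row):
    """Keep only the unique first half when k repeats within the tile."""
    half = tile_row[: TILE_K // 2]
    assert tile_row[TILE_K // 2 :] == half  # second half is a duplicate
    return half

row = list(range(TENSOR_K))
tile_row = load_tile_row(row)
assert tile_row[0:16] == tile_row[16:32]   # duplication along k
assert dedup_k(tile_row) == row            # only first half is needed
```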

@Jokeren Jokeren changed the title [BACKEND] Fix a special case where elements along the k dimension are repeated within each thread [DRAFT][BACKEND] Fix a special case where elements along the k dimension are repeated within each thread Nov 12, 2024
  if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-   bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
+   bool legacyLoweringIsBuggy =
+       kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
Jokeren (author) commented:
Let's enable this path by default soon for anything other than ldmatrix

@Jokeren Jokeren changed the title [DRAFT][BACKEND] Fix a special case where elements along the k dimension are repeated within each thread [BACKEND] Fix a special case where elements along the k dimension are repeated within each thread Nov 12, 2024
@Jokeren Jokeren marked this pull request as ready for review November 12, 2024 22:35
@Jokeren Jokeren requested a review from ptillet as a code owner November 12, 2024 22:35
@lezcano lezcano left a comment
SGTM provided the tests exercise this case.

  for (size_t e = 0; e < numElemsPerVec; ++e) {
    si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
  }
  if (kIters <= repK) {
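As a rough illustration of the index gathering in the hunk above, the inner loop can be sketched in Python. The surrounding loop structure over kRep and tile is an assumption (only the inner loop appears in the diff), and the parameter values below are hypothetical.

```python
# Hypothetical sketch of the shuffle-index gathering in the hunk above.
# Assumes the inner loop sits inside loops over k repetitions and tiles;
# that outer structure is an assumption, not shown in the diff.
def gather_indices(num_k_reps, num_tiles, k_width, elems_per_vec):
    si = []
    for k_rep in range(num_k_reps):
        for tile in range(num_tiles):
            for e in range(elems_per_vec):
                # Matches: si.push_back(kRep * numElemsPerVec + tile * kWidth + e)
                si.append(k_rep * elems_per_vec + tile * k_width + e)
    return si

# With 2 k-reps, 2 tiles, kWidth=4, and 2 elements per vector, each tile
# contributes a contiguous pair of indices per k repetition.
assert gather_indices(2, 2, 4, 2) == [0, 1, 4, 5, 2, 3, 6, 7]
```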
nb. A way to simplify this logic is to invert the LL and then look at the register that holds the value for every top-left element of every tile.

Jokeren (author) replied:
Right. That's something I'll try out.

python/test/regression/test_cast_matmul.py — review thread resolved
@ThomasRaoux ThomasRaoux (Collaborator) left a comment
LGTM

@lezcano lezcano left a comment
Thank you for updating the tests to exercise this new path!

@lezcano lezcano merged commit 7f06338 into main Nov 14, 2024
7 checks passed
@lezcano lezcano deleted the keren/large-kwidth-fix branch November 14, 2024 08:58
@mobicham

Could this possibly improve performance for this use-case? #4906 (comment)

Jokeren commented Nov 15, 2024

I'm not sure. Feel free to try it out.

@mobicham

Currently, Triton built from the master branch is crashing with torch.compile; that's why I asked. Will definitely try it once this is resolved.

4 participants