[BACKEND] Fix a special case where elements along the k dimension are repeated within each thread #5121
  if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-   bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
+   bool legacyLoweringIsBuggy =
+       kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
Let's enable this path by default soon for anything other than ldmatrix
SGTM provided the tests exercise this case.
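For readers following the thread, the updated condition from the diff above can be read as a standalone predicate. This is only a minimal sketch: the free function and its `int` parameters are hypothetical and are not part of the PR, which computes the flag inline from the dot operand encoding.

```cpp
#include <cassert>

// Hypothetical standalone form of the condition in the diff above.
// Returns true when, per the updated check, the legacy lowering is
// considered buggy and should not be used.
bool legacyLoweringIsBuggy(int kWidth, int bitwidth) {
  // Sanity check added for this sketch only (an assumption, not in the PR).
  assert(kWidth > 0 && bitwidth > 0);
  // Old condition: kWidth >= 8.
  // New condition additionally covers kWidth == 4 with 32-bit elements.
  return kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
}
```

Under this reading, `legacyLoweringIsBuggy(4, 32)` is the newly covered case, while `legacyLoweringIsBuggy(4, 16)` still evaluates to false.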
for (size_t e = 0; e < numElemsPerVec; ++e) {
  si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
}
if (kIters <= repK) {
nb. A way to simplify this logic is to invert the LL and then look at the register that holds the value for every top-left element of every tile.
Right. That's something I'll try out.
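To make the index arithmetic in the loop above easier to check by hand, here is a small sketch that reproduces only the visible expression. The helper name `vecIndices` and its signature are hypothetical, and the surrounding loops over `kRep` and `tile` in the PR are not shown in this excerpt, so their bounds are not modeled here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: collects the indices pushed into `si` by the
// inner loop of the diff for one fixed (kRep, tile) pair.
std::vector<std::size_t> vecIndices(std::size_t kRep, std::size_t tile,
                                    std::size_t kWidth,
                                    std::size_t numElemsPerVec) {
  // Sanity check added for this sketch only (an assumption, not in the PR).
  assert(numElemsPerVec > 0);
  std::vector<std::size_t> si;
  si.reserve(numElemsPerVec);
  for (std::size_t e = 0; e < numElemsPerVec; ++e) {
    // Same expression as in the diff: an offset per k repetition, an
    // offset per tile, plus the position within the vector.
    si.push_back(kRep * numElemsPerVec + tile * kWidth + e);
  }
  return si;
}
```

For example, with `kRep = 1`, `tile = 1`, `kWidth = 4`, and `numElemsPerVec = 2`, the helper yields the indices 6 and 7.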
…o keren/large-kwidth-fix
LGTM
Thank you for updating the tests to exercise this new path!
Could this possibly improve performance for this use case? #4906 (comment)
I'm not sure. Feel free to try it out.
Currently, Triton built from the master branch crashes with torch.compile, which is why I asked. Will definitely try it once this is resolved.
This PR includes the following changes: