mx: triton kernel to cast to mx and write in col-major #1932
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1932
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 7ecd79f with merge base 3fb1665. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# example transformation (specifics depend on tile sizes):
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
col_scale_indices = col_scale_indices + (
    tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col
i think we should just be doing integer division instead of floor + /
looks good ! next step: compile to generate this
).to(tl.int32)
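A minimal sketch of the integer-division suggestion (assuming col_scale_indices is an int32 tensor and BLOCKS_PER_ROW_TILE is a tl.constexpr): // performs floor division directly, so the float round-trip through tl.floor and the cast back to int32 go away.

    # integer floor division keeps everything in int32
    col_scale_indices = col_scale_indices + (
        (col_scale_indices // BLOCKS_PER_ROW_TILE) * jump_vals_per_col
    )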
# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
in the launcher, should we assert divisibility of block sizes, so we hard error for this case?
yes, for now I hackily assert that on L1319:L1322
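A sketch of such a hard error in the launcher (the size and tile-size names here are hypothetical, not the PR's actual variable names):

    # hard error until the store above is masked
    assert n_rows % ROW_TILE_SIZE == 0, "n_rows must be divisible by ROW_TILE_SIZE"
    assert n_cols % COL_TILE_SIZE == 0, "n_cols must be divisible by COL_TILE_SIZE"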
)

return (
    output_col_major.t(),
since only the data_ptr of output_col_major is used when you pass it into triton, you could initialize it with the correct strides
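A minimal sketch of that suggestion (shapes, dtype, and variable names are assumptions for illustration): allocate the output with column-major strides up front, so the kernel writes into it directly and no transpose is needed on return.

    # column-major layout for an (n_rows, n_cols) output; the Triton kernel only
    # consumes output_col_major.data_ptr(), the strides just describe the layout
    # to downstream PyTorch code
    output_col_major = torch.empty_strided(
        (n_rows, n_cols), (1, n_rows), dtype=torch.float8_e4m3fn, device=x.device
    )
    ...
    return output_col_major  # no .t() needed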
return (
    output_col_major.t(),
    col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
same thing here..
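Following the same idea (a sketch; the scale tensor's size name here is an assumption): allocate col_scale with its final shape and dtype so the reshape/view on return becomes unnecessary.

    # one e8m0 scale per block, already in the (-1, 1) shape the caller expects;
    # the Triton kernel only uses col_scale.data_ptr()
    col_scale = torch.empty(
        n_scale_blocks, 1, dtype=torch.float8_e8m0fnu, device=x.device
    )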
Summary:
Implements a triton kernel for a cast to mxfp8 from a row-major input across dim1, which is 3.5x to 4.5x faster than what compile can generate today. Note that this is a prototype kernel, and I expect to (a) improve it in future PRs and (b) delete it in ~weeks when we have compile support for this.
An integration into MXLinear will follow in a separate PR.

Example of tiling (simplified for small example size):
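For a concrete feel for the column-major scale indexing, here is a minimal torch sketch of the index transformation from the kernel snippet above, using illustrative values BLOCKS_PER_ROW_TILE = 2 and jump_vals_per_col = 2 (assumptions chosen to reproduce the example in the kernel comment, not the kernel's actual tile sizes):

    import torch

    col_scale_indices = torch.arange(8)
    BLOCKS_PER_ROW_TILE = 2
    jump_vals_per_col = 2
    # remap row-major scale indices to their column-major storage locations
    col_scale_indices = col_scale_indices + (
        (col_scale_indices // BLOCKS_PER_ROW_TILE) * jump_vals_per_col
    )
    print(col_scale_indices)  # tensor([ 0,  1,  4,  5,  8,  9, 12, 13])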
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: