triton kernel to cast to mx across dim0 and dim1 #1869
Conversation
Stack from ghstack (oldest at bottom):
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1869
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3f33752 with merge base ddb7f83
This comment was automatically generated by Dr. CI and updates every 15 minutes.
grid_cols = triton.cdiv(n_cols, col_tile_size)

# inner_block_size = 32
rename_me_tile_size = row_tile_size * col_tile_size // inner_block_size
lol I wasn't sure what to call this either..
looks good! did not review the indexing math too closely, but hopefully it will be replaced with torch.compile soon.
output_col_major.t(),
col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
Because the triton kernel is just accessing the data_ptr, you could initialize these in the correct layout if you wanted.
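For example, a minimal sketch of that idea (the shape and dtype here are assumptions, not the ones from the benchmark):

import torch

n_rows, n_cols = 4096, 4096  # hypothetical shape
# allocate storage as (n_cols, n_rows) and view it transposed, so the buffer is
# already column-major and the kernel can write through data_ptr without a later .t()
output_col_major = torch.empty(n_cols, n_rows, dtype=torch.float8_e4m3fn, device="cuda").t()
# output_col_major.shape == (n_rows, n_cols), output_col_major.stride() == (1, n_rows)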
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
col_scale_indices = col_scale_indices + (
    tl.floor(col_scale_indices / factor) * jump_vals_per_col
)
any reason to not just use integer division here instead of / and floor?
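For reference, a sketch of that suggestion (assuming col_scale_indices and factor are integer-typed in the kernel, which I have not verified):

# integer floor division avoids the float divide + tl.floor round trip
col_scale_indices = col_scale_indices + (col_scale_indices // factor) * jump_vals_per_col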
# TODO(future): mask
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
Should we add some runtime asserts about block divisibility?
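Something like the following host-side checks, as a sketch (n_rows is a hypothetical name; the tile-size names come from the kernel wrapper above):

# hypothetical pre-launch checks; the kernel assumes tiles evenly cover the input
assert n_rows % row_tile_size == 0, "n_rows must be divisible by row_tile_size"
assert n_cols % col_tile_size == 0, "n_cols must be divisible by col_tile_size"
assert (row_tile_size * col_tile_size) % inner_block_size == 0, "each tile must hold whole mx blocks"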
# TODO 1: rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files
# before: 1.7 TB/s
# after: ?
scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2
I would suspect there's a better way of getting the scale that avoids doing a float log. I guess this is what the TODO is for.
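One common trick, as a sketch of an alternative (this is not the code from #1908, just an assumption about the bit-shift direction): reinterpret the float32 bits and read the exponent field directly, which gives floor(log2(x)) for normal values without a float log.

# bitcast max_abs to int32 and extract the biased IEEE-754 exponent
max_abs_int = (max_abs + epsilon).to(tl.int32, bitcast=True)
extracted_pow2 = ((max_abs_int >> 23) & 0xFF) - 127  # remove the exponent bias
scale_e8m0_unbiased = extracted_pow2 - target_max_pow2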
Closing this in favor of #1932; I split out the dim1 logic so we can keep this PR as an easy historical reference.
Summary:
This is the start of a short-term workaround for pytorch/pytorch#148682. tl;dr: today, torch.compile does not yet generate good fusions for the cast to mx across dim0 and dim1, so we can add custom kernels to do this until torch.compile support is better. We take a simple strategy to minimize overall reads and writes: read the input in 32x32 tiles, and write out the row-major and col-major casted values and scales directly from each tile.
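For context, a rough eager-mode sketch of what the cast along one dim looks like per 32-element block (assumptions: power-of-two scales, e4m3 data, no special-value handling; to_mx_ref is a made-up helper name, not the one used in torchao):

import torch

def to_mx_ref(x: torch.Tensor, block_size: int = 32):
    # x: (rows, cols) float32, cols assumed divisible by block_size
    rows, cols = x.shape
    x_blocks = x.reshape(rows, cols // block_size, block_size)
    max_abs = x_blocks.abs().amax(dim=-1, keepdim=True)
    # power-of-two scale so the max of each block lands near the e4m3 max
    scale_pow2 = torch.floor(torch.log2(max_abs + 1e-10)) - 8  # 8 ~ floor(log2(e4m3 max)), an assumption
    scale = torch.exp2(scale_pow2)
    data = (x_blocks / scale).to(torch.float8_e4m3fn)
    return data.reshape(rows, cols), scale.squeeze(-1)

x = torch.randn(256, 256, device="cuda")
data_dim1, scale_dim1 = to_mx_ref(x)                    # scaled along dim1
data_dim0, scale_dim0 = to_mx_ref(x.t().contiguous())   # scaled along dim0 via a transpose

The kernel in this PR fuses both of these passes over a single read of the input by handling each 32x32 tile once.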
The kernel added in this PR is ok-ish: it's ~2x faster than torch.compile across a range of shapes with MKN from 1024 to 16384, but it only hits around (redacted) of the 8 TB/s peak memory bandwidth. Impact on e2e benchmarks of fwd+bwd of MXLinear:
Some future improvements we can make:
For now I want to check this in as just a kernel + benchmark, without an integration into MXLinear. If we make this kernel significantly faster, we can integrate.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: