Skip to content

mx: triton kernel to cast to mx and write in col-major #1932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 26, 2025
Merged

Conversation

vkuzo
Copy link
Contributor

@vkuzo vkuzo commented Mar 21, 2025

Summary:

Implements a triton kernel for a cast to mxfp8 from a row-major input across dim1, which is 3.5x to 4.5x faster than what compile can generate today. Note that this is a prototype kernel, and I expect to (a) improve it in future PRs and (b) delete it in ~weeks when we have compile support for this.

An integration into MXLinear will follow in a separate PR.

Example of tiling (simplified for small example size):

        Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2,
        pid_row=0, pid_col=0:

        Input (row-major)

        cols      0  1  2  3  4  5  6  7
        --------------------------------
        rows 0 |  0  1  2  3
             1 |  8  9 10 11
             2 | 16 17 18 19
             3 | 24 25 26 27
             4 |
             5 |
             6 |
             7 |

        Output (row-major of transpose), ids are from input

        cols      0  1  2  3  4  5  6  7
        --------------------------------
        rows 0 |  0  8 16 24
             1 |  1  9 17 25
             2 |  2 10 18 26
             3 |  3 11 19 27
             4 |
             5 |
             6 |
             7 |

        Output (scales), s(0, 8) means the scale used to cast elements 0 and 8

        rows           0          1  ...      4  ...       31
        ------------------------------------------------------
                  s(0, 8)  s(16, 24) ... s(1, 9) ... s(19, 27)

Test Plan:

// tests pass
pytest test/prototype/mx_formats/test_custom_cast.py -s -x -k triton_mxfp8_dim1

// performance compile vs triton: https://www.internalfb.com/phabricator/paste/view/P1762691809
// * 4k by 4k tensor: about a 3.6x speedup
// * 16k by 16k tensor: about a 4.7x speedup

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@vkuzo
Copy link
Contributor Author

vkuzo commented Mar 21, 2025

Copy link

pytorch-bot bot commented Mar 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1932

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 7ecd79f with merge base 3fb1665 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4c77bd3
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 21, 2025
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 18073c5
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c2d2aed
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 91248bc
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3a53ae7
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 76fef6c
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 8486913
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c44a9db
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d590401
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 26e84f9
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@vkuzo vkuzo changed the title [wip] triton kernel to cast to mx and write in col-major mx: triton kernel to cast to mx and write in col-major Mar 21, 2025
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 21, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 95105a6
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@vkuzo vkuzo added the topic: performance Use this tag if this PR improves the performance of a feature label Mar 21, 2025
@eellison eellison self-requested a review March 21, 2025 16:47
# example transformation (specifics depend on tile sizes):
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
col_scale_indices = col_scale_indices + (
tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should just be doing integer division instead of floor + /

Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good ! next step: compile to generate this

).to(tl.int32)

# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the launcher, should we assert divisibility of block sizes, so we hard error for this case ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for now I hackly assert that on L1319:L1322

)

return (
output_col_major.t(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since only the data_ptr of output_col_major is used when you pass it into triton, you could initialize it with the correct strides


return (
output_col_major.t(),
col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing here..

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 24, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d1c1db7
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 24, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 6c795ce
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
@vkuzo vkuzo merged commit d32afef into main Mar 26, 2025
49 of 50 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: performance Use this tag if this PR improves the performance of a feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants