triton kernel to cast to mx across dim0 and dim1 #1869

Closed · wants to merge 18 commits

Conversation
@vkuzo (Contributor) commented Mar 11, 2025

Summary:

This is a start of the short term workaround for pytorch/pytorch#148682 . tl;dr; is that today torch.compile is not yet generating good fusions for the cast to mx across dim0 and dim1, so we can add custom kernels to do this until the torch.compile support is better. We take a simple strategy to minimize overall reads and writes: read the input in 32x32 tiles, and write out the row-major and col-major casted values and scales directly from each tile.

The kernel added in this PR is OK-ish: it's ~2x faster than torch.compile across a range of shapes with M, K, N from 1024 to 16384, but it only hits around (redacted) of the 8 TB/s peak memory bandwidth. Impact on e2e fwd+bwd benchmarks of MXLinear:

```
# K = 1024: speedup over bf16 0.19 -> 0.28
# K = 16384: speedup over bf16 0.91 -> 1.07
```

Some future improvements we can make:

  1. add an outer_tile and set it to 128, keeping the inner tile at 32 (see below; will likely get us to (redacted) of peak bandwidth). We pick 128 because an outer_tile of 256+ seems to run out of shared memory on B200.
  2. write out swizzled scales (not currently in this benchmark, but will help with e2e)
  3. see if we can make the row-major and col-major output writes coalesced (haven't researched how yet)

For now I want to check this in as just a kernel + benchmark without an integration to MXLinear. If we make this kernel significantly faster, we can integrate.
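The tiled strategy described above can be sketched in plain Python. This is a minimal sketch of the per-tile scale computation only, not the actual Triton kernel; `target_max_pow2 = 8` assumes an fp8 e4m3 element dtype, and all names here are illustrative:

```python
import math

def tile_scales(tile, target_max_pow2=8, eps=1e-10):
    """For one 32x32 tile, compute an e8m0-style unbiased scale per row
    (the dim0 / row-major pass) and per column (the dim1 / col-major pass),
    mirroring floor(log2(max_abs + eps)) - target_max_pow2 from the kernel.
    Each input element is read once but feeds both outputs."""
    n_rows, n_cols = len(tile), len(tile[0])
    row_scales = [
        math.floor(math.log2(max(abs(v) for v in row) + eps)) - target_max_pow2
        for row in tile
    ]
    col_scales = [
        math.floor(math.log2(max(abs(tile[r][c]) for r in range(n_rows)) + eps))
        - target_max_pow2
        for c in range(n_cols)
    ]
    return row_scales, col_scales

# a tile whose max abs value is 2**8 = 256 gets an unbiased scale of 0
tile = [[256.0] * 32 for _ in range(32)]
row_scales, col_scales = tile_scales(tile)
```

The real kernel would then scale each element by the corresponding power of two and cast to the element dtype before writing both the row-major and col-major outputs.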

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@vkuzo (Contributor, Author) commented Mar 11, 2025

Stack from ghstack (oldest at bottom):

pytorch-bot commented Mar 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1869

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3f33752 with merge base ddb7f83:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Mar 11, 2025
Summary:

Test Plan:

```
python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1c84b7e
ghstack-comment-id: 2714865161
Pull Request resolved: #1869
@facebook-github-bot added the "CLA Signed" label Mar 11, 2025
[ghstack-poisoned]

(Seventeen further commit references with the same message, ghstack-comment-id 2714865161, followed between Mar 11 and Mar 21, 2025.)
@vkuzo changed the title from "[wip] triton kernel to cast to mx across dim0 and dim1" to "triton kernel to cast to mx across dim0 and dim1" Mar 12, 2025
@vkuzo added the "topic: not user facing" label Mar 12, 2025
```
grid_cols = triton.cdiv(n_cols, col_tile_size)

# inner_block_size = 32
rename_me_tile_size = row_tile_size * col_tile_size // inner_block_size
```
Contributor comment:

lol, I wasn't sure what to call this either..

@eellison (Contributor) left a comment:

Looks good! Did not review the indexing math too closely, but hopefully it will be replaced with torch.compile soon.

Comment on lines +1530 to +1531:

```
output_col_major.t(),
col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
```

Contributor comment:

Because the Triton kernel is just accessing data_ptr, you could initialize these as the correct layout if you wanted.


```
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
col_scale_indices = col_scale_indices + (
    tl.floor(col_scale_indices / factor) * jump_vals_per_col
)
```

Contributor comment:

Any reason not to just use integer division here instead of / and floor?
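For non-negative indices the two forms are equivalent; a small Python check reproduces the commented mapping (factor and jump_vals_per_col values are chosen here to match the example, not taken from the kernel):

```python
# Reproduce [0..7] -> [0, 1, 4, 5, 8, 9, 12, 13] both ways.
# factor=2 and jump_vals_per_col=2 are illustrative values.
factor, jump_vals_per_col = 2, 2
indices = list(range(8))
via_floor = [i + int(i / factor) * jump_vals_per_col for i in indices]  # floor-style
via_intdiv = [i + (i // factor) * jump_vals_per_col for i in indices]   # integer division
```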

Comment on lines +1465 to +1466:

```
# TODO(future): mask
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
```

Contributor comment:

Should we add some runtime asserts about block divisibility?
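Until the store is masked, a host-side check along the lines the reviewer suggests could fail fast on bad shapes. A hypothetical helper (not part of this PR):

```python
def check_tile_divisibility(n_rows, n_cols, row_tile_size=32, col_tile_size=32):
    # Without a mask on tl.store, a partial tile would write out of bounds,
    # so reject shapes that don't divide evenly into tiles.
    if n_rows % row_tile_size != 0 or n_cols % col_tile_size != 0:
        raise ValueError(
            f"shape ({n_rows}, {n_cols}) must be divisible by the "
            f"({row_tile_size}, {col_tile_size}) tile size"
        )

check_tile_divisibility(1024, 2048)  # passes silently
```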

```
# TODO 1: rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files
# before: 1.7 TB/s
# after: ?
scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2
```

Contributor comment:

I would suspect there's a better way of getting the scale that avoids doing a float log. I guess this is what the TODO is for.
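For a normal, positive float32, floor(log2(x)) is just the biased exponent field minus 127, so the float log can indeed be replaced by bit operations, as the TODO suggests. A pure-Python sketch of the idea, with struct-based reinterpretation standing in for the kernel's bitcast:

```python
import struct

def floor_log2_bits(x: float) -> int:
    # Reinterpret the float32 bit pattern and extract the 8-bit exponent
    # field (bits 23..30), then remove the IEEE 754 bias of 127.
    # Valid for normal, positive float32 values only.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return ((bits >> 23) & 0xFF) - 127
```

Subnormals and zero would need separate handling, which is likely part of why the TODO is still open.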

@vkuzo (Contributor, Author) commented Mar 21, 2025

Closing this in favor of #1932; I split out the dim1 logic so we can keep this PR as an easy historical reference.

@vkuzo closed this Mar 21, 2025