[ET-VK] Efficient tiled int8 matmul #9804

Merged
merged 2 commits into main from gh/SS-JIA/205/orig on Apr 1, 2025

Conversation

pytorchbot
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #9766 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/205/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/205/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/204/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/205/orig
@diff-train-skip-merge

SS-JIA added 2 commits March 31, 2025 12:18
Pull Request resolved: #9765

## Context

The weight tensor of a linear layer is usually stored in a transposed manner, such that when computing the matrix multiplication, the reduction traverses along the rows of the weight tensor as opposed to the columns. This results in a better memory access pattern for CPUs.

However, for GPUs, I have found that "un-transposing" the weight tensor results in better performance. This is likely because GPUs compute multiple output elements in parallel, so reading along the columns allows memory loads to be coalesced among the threads in a work group.
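
The ET-VK kernels are Vulkan GLSL compute shaders; the CUDA-style sketch below is only meant to illustrate the access-pattern difference between the two layouts (the kernel names and signatures are illustrative assumptions, not code from this PR):

```cuda
#include <cstdint>

// out[n] = sum_k in[k] * W[n][k] for a linear layer with K inputs and N outputs.

// Layout A: "transposed" weight of shape [N, K], row-major.
// Each thread reduces along a row: thread n reads W[n*K + k].
// Good for a CPU core streaming one contiguous row, but at a fixed k,
// neighboring GPU threads n and n+1 touch addresses K elements apart,
// so their loads cannot be coalesced.
__global__ void matvec_transposed(const int8_t* W, const int8_t* in,
                                  int32_t* out, int K, int N) {
  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n >= N) return;
  int32_t acc = 0;
  for (int k = 0; k < K; ++k) {
    acc += int32_t(W[n * K + k]) * int32_t(in[k]);
  }
  out[n] = acc;
}

// Layout B: "un-transposed" weight of shape [K, N], row-major.
// At a fixed k, neighboring threads read W[k*N + n] and W[k*N + n + 1],
// which are adjacent in memory, so the loads coalesce into wide transactions.
__global__ void matvec_untransposed(const int8_t* W, const int8_t* in,
                                    int32_t* out, int K, int N) {
  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n >= N) return;
  int32_t acc = 0;
  for (int k = 0; k < K; ++k) {
    acc += int32_t(W[k * N + n]) * int32_t(in[k]);
  }
  out[n] = acc;
}
```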

## Changes

* Introduce the ability to transpose the height and width dims when transferring tensor data to the GPU.
* Prepack the weight tensor "un-transposed" for the int8 quantized linear operator (a host-side sketch follows this list).
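
As a rough host-side illustration of what prepacking "un-transposed" means, here is a minimal sketch assuming a weight that arrives as `[N, K]` and is written out as `[K, N]` (the helper name is hypothetical; the actual prepacking is done by the ET-VK backend when staging tensor data):

```cuda
#include <cstdint>
#include <vector>

// Hypothetical illustration only: the linear weight arrives "transposed" as
// [N, K] and is repacked as [K, N] so the GPU kernel reads consecutive
// output columns from consecutive addresses.
std::vector<int8_t> prepack_untransposed(const std::vector<int8_t>& w_nk,
                                         int N, int K) {
  std::vector<int8_t> w_kn(static_cast<size_t>(K) * N);
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      w_kn[static_cast<size_t>(k) * N + n] =
          w_nk[static_cast<size_t>(n) * K + k];
    }
  }
  return w_kn;
}
```
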
ghstack-source-id: 275180033
@exported-using-ghexport

Differential Revision: [D72066588](https://our.internmc.facebook.com/intern/diff/D72066588/)
Pull Request resolved: #9766

## Context

Introduce an optimized tiled implementation for computing the weight-int8-quantized linear operation.

This implementation takes advantage of the following principles to squeeze out performance:

* Compute an output tile with each thread, rather than a single output element. This allows for better memory re-use of loaded input tensor data.
* Compute the output tile by iteratively loading tiles of the input matrices, caching them in registers, and then performing the `fma` accumulations to obtain a partial output. By splitting data loading and computation into distinct steps, the GPU can hide latency more effectively, i.e. switch to a warp that is ready to compute while the current warp waits on its data loads.
* Use a work group size of `{N, 1, 1}`. This makes all the threads in a work group load the same row of the input matrix and consecutive columns of the weight matrix. This way, the row of the input is kept hot in the cache, and accesses to the weight matrix can be coalesced thanks to the previous diff un-transposing the weight matrix. An illustrative sketch of this tiling scheme is given below.
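
The actual shader lives in the ET-VK backend as Vulkan GLSL; the CUDA-style sketch below only illustrates the tiling scheme described above. Tile sizes and names are assumptions, and quantization scales/zero points as well as boundary handling are omitted for brevity:

```cuda
#include <cstdint>

// Illustrative tiled int8 matmul: C[M, N] = A[M, K] x W[K, N], with W already
// prepacked "un-transposed" as [K, N]. Assumes M, N, K are divisible by the
// tile sizes.
constexpr int TILE_M = 4;  // output rows computed per thread
constexpr int TILE_N = 4;  // output cols computed per thread
constexpr int TILE_K = 4;  // reduction elements cached per iteration

__global__ void linear_q8_tiled(const int8_t* A, const int8_t* W,
                                int32_t* C, int M, int N, int K) {
  // Launched so that a thread block spans consecutive column tiles of one
  // row tile: every thread shares the same rows of A (kept hot in cache),
  // and consecutive threads read consecutive columns of W (coalesced loads,
  // thanks to the un-transposed layout).
  int n0 = (blockIdx.x * blockDim.x + threadIdx.x) * TILE_N;
  int m0 = blockIdx.y * TILE_M;
  if (n0 >= N || m0 >= M) return;

  int32_t acc[TILE_M][TILE_N] = {};

  for (int k0 = 0; k0 < K; k0 += TILE_K) {
    // Step 1: load a tile of A and a tile of W into registers.
    int8_t a_tile[TILE_M][TILE_K];
    int8_t w_tile[TILE_K][TILE_N];
    for (int i = 0; i < TILE_M; ++i)
      for (int k = 0; k < TILE_K; ++k)
        a_tile[i][k] = A[(m0 + i) * K + (k0 + k)];
    for (int k = 0; k < TILE_K; ++k)
      for (int j = 0; j < TILE_N; ++j)
        w_tile[k][j] = W[(k0 + k) * N + (n0 + j)];

    // Step 2: accumulate the partial output tile. Keeping the load and
    // compute phases distinct gives the scheduler room to hide memory
    // latency by switching to another warp while this one waits on loads.
    for (int i = 0; i < TILE_M; ++i)
      for (int j = 0; j < TILE_N; ++j)
        for (int k = 0; k < TILE_K; ++k)
          acc[i][j] += int32_t(a_tile[i][k]) * int32_t(w_tile[k][j]);
  }

  for (int i = 0; i < TILE_M; ++i)
    for (int j = 0; j < TILE_N; ++j)
      C[(m0 + i) * N + (n0 + j)] = acc[i][j];
}
```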

Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/)
ghstack-source-id: 275180032
pytorchbot requested a review from SS-JIA as a code owner April 1, 2025 16:15

pytorch-bot bot commented Apr 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9804

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 17 Pending

As of commit 1907ae2 with merge base 2aa7748:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label Apr 1, 2025
Base automatically changed from gh/SS-JIA/204/orig to main April 1, 2025 16:41
SS-JIA merged commit 97bb055 into main Apr 1, 2025
79 of 80 checks passed
SS-JIA deleted the gh/SS-JIA/205/orig branch April 1, 2025 16:43
kirklandsign pushed a commit that referenced this pull request Apr 11, 2025