## Context
Introduce an optimized tiled implementation for computing the weight-int8-quantized linear operation.
The implementation takes advantage of the following principles to squeeze out performance:
* Compute an output tile with each thread, rather than a single output element. This allows for better memory re-use of loaded input tensor data.
* Compute the output tile by iteratively loading tiles of the input matrices into registers, then performing `fma` accumulations to obtain a partial output. By splitting data loading and computation into distinct steps, the GPU can hide latency more effectively: while one warp waits on a data load, the scheduler can switch to a warp that is ready to compute.
* Use a work group size of `{N, 1, 1}`. With this layout, all threads in a work group load the same row of the input matrix and consecutive columns of the weight matrix. The input row is kept hot in the cache, and accesses to the weight matrix can be coalesced, since the previous diff un-transposed the weight matrix. A sketch of the overall scheme follows this list.
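
As a rough illustration of how these pieces fit together, here is a minimal CUDA sketch of the same tiling scheme. The actual kernel is a Vulkan compute shader; the kernel name `q8w_linear_tiled`, the tile sizes, and the per-output-channel scale layout below are assumptions for illustration, not taken from the diff:

```cuda
// Hypothetical CUDA analogue of the tiled weight-int8 linear kernel.
// Assumes out[M, N] = in[M, K] @ weight[K, N], with weight stored as int8
// plus per-output-channel float scales, and K a multiple of TILE_K.
#include <cuda_runtime.h>
#include <cstdint>

constexpr int TILE_M = 4;  // output rows computed per thread (illustrative)
constexpr int TILE_K = 4;  // reduction elements loaded per iteration (illustrative)

// Launch with block = {N_threads, 1, 1} so all threads in a block share the
// same input rows and read consecutive (coalesced) weight columns.
__global__ void q8w_linear_tiled(const float* __restrict__ in,       // [M, K]
                                 const int8_t* __restrict__ weight,  // [K, N], un-transposed
                                 const float* __restrict__ scales,   // [N]
                                 float* __restrict__ out,            // [M, N]
                                 int M, int N, int K) {
  const int n  = blockIdx.x * blockDim.x + threadIdx.x;  // output column
  const int m0 = blockIdx.y * TILE_M;                    // first row of this thread's tile
  if (n >= N) return;

  float acc[TILE_M] = {0.f};

  for (int k0 = 0; k0 < K; k0 += TILE_K) {
    // Step 1: load tiles of both operands into registers. Every thread in the
    // block reads the same input rows (kept hot in cache), while consecutive
    // threads read consecutive weight columns (coalesced).
    float in_tile[TILE_M][TILE_K];
    float w_tile[TILE_K];
    for (int i = 0; i < TILE_M; ++i)
      for (int j = 0; j < TILE_K; ++j)
        in_tile[i][j] = (m0 + i < M) ? in[(m0 + i) * K + (k0 + j)] : 0.f;
    for (int j = 0; j < TILE_K; ++j)
      w_tile[j] = static_cast<float>(weight[(k0 + j) * N + n]);

    // Step 2: accumulate. Keeping loads and math in distinct phases gives the
    // scheduler compute-ready warps to swap in while others wait on loads.
    for (int i = 0; i < TILE_M; ++i)
      for (int j = 0; j < TILE_K; ++j)
        acc[i] = fmaf(in_tile[i][j], w_tile[j], acc[i]);
  }

  // Apply the dequantization scale once per output element.
  for (int i = 0; i < TILE_M; ++i)
    if (m0 + i < M) out[(m0 + i) * N + n] = acc[i] * scales[n];
}
```

Because each thread produces `TILE_M` outputs from a single pass over its registers, each loaded weight value is reused `TILE_M` times, which is the memory-reuse benefit described in the first bullet.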
Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/)
[ghstack-poisoned]