A lightweight library exposing grouped GEMM kernels in PyTorch.
Run `pip install grouped_gemm` to install the package.
By default, the installed package runs in conservative (cuBLAS) mode: it launches one GEMM kernel per batch element instead of using a single grouped GEMM kernel for the whole batch.
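To make the semantics concrete, here is a minimal sketch of what a grouped GEMM computes, assuming the package exposes the op as `grouped_gemm.ops.gmm` taking bf16 CUDA tensors and a CPU tensor of per-group row counts; names, shapes, and the exact interface here are illustrative, so check the package for details.

```python
import torch
from grouped_gemm import ops

# Hypothetical shapes: 3 groups sharing inner dimension k=16, output dimension n=32.
# batch_sizes gives the number of rows routed to each group (int64, kept on CPU).
batch_sizes = torch.tensor([3, 5, 2])
a = torch.randn(10, 16, device="cuda", dtype=torch.bfloat16)     # (sum(batch_sizes), k)
b = torch.randn(3, 16, 32, device="cuda", dtype=torch.bfloat16)  # (num_groups, k, n)

# One grouped GEMM call covers every per-group matmul.
out = ops.gmm(a, b, batch_sizes)

# Reference semantics: an ordinary GEMM per group, concatenated back together.
ref = torch.cat([x @ b[i] for i, x in enumerate(torch.split(a, batch_sizes.tolist()))])
torch.testing.assert_close(out, ref)
```

In conservative (cuBLAS) mode the library effectively executes the per-group loop above as separate kernel launches, whereas CUTLASS mode fuses the whole batch into a single grouped kernel.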
To enable grouped GEMM kernels, switch to CUTLASS mode by setting the `GROUPED_GEMM_CUTLASS` environment variable to `1` when building the library. For example, to build the library in CUTLASS mode for Ampere (SM 8.0), clone the repository and run the following:

```
$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
```
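For a fresh checkout, the full sequence might look like the sketch below; the repository URL is an assumption here, so substitute the location you actually clone from.

```
$ git clone https://github.com/tgale96/grouped_gemm.git   # assumed upstream location
$ cd grouped_gemm
$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
```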
See this comment for some performance measurements on A100 and H100.
Upcoming features:

- Running grouped GEMM kernels without GPU<->CPU synchronization points.
- Hopper-optimized grouped GEMM kernels.