[2025.08.23] FlashOverlap has been accepted by EuroSys'26 🎉 The tech report will be updated soon.
- demo for GEMM+AllReduce
- predictive search for wave grouping
- multi-node example
- demo for GEMM+ReduceScatter
- demo for GEMM+AlltoAll
- code branch for AE
- more platforms (e.g., hopper GPU)
- end2end example
The figure shows a typical timeline of computation-communication overlap in FlashOverlap. Two CUDA streams are used, one for computation and one for communication. In the computation stream, the CUTLASS GEMM kernel sends signals during computation; in the communication stream, a counting kernel stalls NCCL communication until it has received a preset number of signals.
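To make the signaling scheme concrete, here is a minimal CPU-side sketch that uses Python threads in place of CUDA streams. It models only the control flow, under stated assumptions: `SignalCounter` and `simulate_overlap` are hypothetical names, one "tile" stands in for whatever unit the CUTLASS kernel signals on, and `group_sizes` stands in for the preset signal counts. It is not how the CUDA kernels are implemented.

```python
import threading


class SignalCounter:
    """Hypothetical stand-in for the counter the GEMM kernel increments."""

    def __init__(self):
        self._count = 0
        self._cond = threading.Condition()

    def signal(self):
        # Called by the "computation stream" each time a tile is finished.
        with self._cond:
            self._count += 1
            self._cond.notify_all()

    def wait_for(self, target):
        # Called by the "communication stream": block until `target` signals arrived.
        with self._cond:
            self._cond.wait_for(lambda: self._count >= target)
            return self._count


def simulate_overlap(num_tiles, group_sizes):
    """Release one communication group each time its preset signal count is reached."""
    assert sum(group_sizes) == num_tiles, "every tile must belong to a group"
    counter = SignalCounter()
    finished = []  # tiles "computed" so far, in order
    comm_log = []  # (group index, required signals, tiles finished at release)

    def compute_stream():
        for tile in range(num_tiles):
            finished.append(tile)  # stand-in for computing one GEMM tile
            counter.signal()

    def comm_stream():
        target = 0
        for group, size in enumerate(group_sizes):
            target += size
            counter.wait_for(target)  # the "counting kernel" stall
            # Stand-in for launching the NCCL collective on this group.
            comm_log.append((group, target, len(finished)))

    threads = [threading.Thread(target=comm_stream), threading.Thread(target=compute_stream)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return comm_log


if __name__ == "__main__":
    for group, needed, done in simulate_overlap(8, [2, 3, 3]):
        print(f"group {group}: released after {needed} signals ({done} tiles done)")
```

Because the communication side waits on a signal count rather than on the end of the whole GEMM, each group can be released while later tiles are still being computed.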
The main dependency is NCCL, which FlashOverlap uses for communication. It can be downloaded from the official NVIDIA website. The code has been tested with v2.18.3 and v2.19.3.
Another dependency is CUTLASS, which is included as a submodule. The code has been tested with v3.6.0 and v3.9.0 but fails with v3.4.0; we expect CUTLASS >= v3.6.0 to work.
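If you swap the CUTLASS submodule for another release, a small guard like the one below can flag known-bad versions early. It is a hypothetical helper, not part of the repo; it only encodes the statement above (tested: v3.6.0 and v3.9.0; fails: v3.4.0; assumed to work: >= v3.6.0). The tag could come from, e.g., `git describe --tags` run inside the submodule directory.

```python
TESTED = {(3, 6, 0), (3, 9, 0)}  # versions reported as tested
MINIMUM = (3, 6, 0)  # assumed lower bound (v3.4.0 is known to fail)


def parse_version(tag):
    """Turn 'v3.6.0' (or 'v3.6.0-5-gabc123' from git describe) into (3, 6, 0)."""
    core = tag.lstrip("v").split("-")[0]
    return tuple(int(part) for part in core.split(".")[:3])


def check_cutlass(tag):
    """Return 'tested', 'assumed-ok', or 'unsupported' for a CUTLASS tag."""
    version = parse_version(tag)
    if version in TESTED:
        return "tested"
    if version >= MINIMUM:
        return "assumed-ok"
    return "unsupported"
```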
The code currently supports only sm_80, sm_86, and sm_89. It has been evaluated on NVIDIA RTX 3090, RTX 4090, A800, and A100 GPUs with CUDA Toolkit 12.1 and 12.2.
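A quick way to check whether a GPU falls into the supported set is to map its compute capability to an `sm_XY` name. The helpers below are hypothetical and not part of the repo. For reference, A100/A800 report (8, 0), RTX 3090 reports (8, 6), and RTX 4090 reports (8, 9); Hopper GPUs report (9, 0) and are therefore not yet covered.

```python
SUPPORTED_ARCHS = {"sm_80", "sm_86", "sm_89"}


def sm_name(capability):
    """Map a (major, minor) compute capability to an 'sm_XY' string."""
    major, minor = capability
    return f"sm_{major}{minor}"


def is_supported(capability):
    """True if the architecture is one FlashOverlap currently supports."""
    return sm_name(capability) in SUPPORTED_ARCHS
```

On a machine with PyTorch and a GPU, pass `torch.cuda.get_device_capability()` (which returns a `(major, minor)` tuple) to `is_supported`.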
First, clone the repo and initialize its submodules:
$ git clone https://github.com/infinigence/FlashOverlap.git
$ cd FlashOverlap
$ git submodule update --init --recursive
Install PyTorch and the other required packages through pip or conda:
$ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
$ pip install numpy==2.1.2 pandas==2.2.3 setuptools==75.8.0
Before compiling, generate the GEMM instances:
$ mkdir ./configs
$ cd ./tool
$ python generate_instances.py
From the repo root, build with CMake (>= 3.18):
$ cmake -B build
$ cmake --build build -j
The operators are then registered as torch classes, so the resulting .so must be loaded in any Python code that uses them:
torch.ops.load_library("../build/lib/libst_pybinding.so")
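The relative path above only resolves when Python is started from a directory one level below the repo root (such as example/ or test/). A hedged alternative is to resolve the path explicitly. The helper names and the `repo_root` argument are hypothetical; only `torch.ops.load_library` and the library location come from the snippet above.

```python
from pathlib import Path


def binding_path(repo_root):
    """Location of the pybinding library produced by `cmake --build build`."""
    return Path(repo_root).resolve() / "build" / "lib" / "libst_pybinding.so"


def load_flashoverlap(repo_root):
    """Load the compiled operators so they are registered with PyTorch."""
    lib = binding_path(repo_root)
    if not lib.is_file():
        raise FileNotFoundError(f"{lib} not found; build the repo with CMake first")
    import torch  # imported lazily so the path helper works without torch installed

    torch.ops.load_library(str(lib))
    return lib
```

For a script inside example/, `load_flashoverlap(Path(__file__).resolve().parent.parent)` loads the same library regardless of the current working directory.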
Note that the GEMM dimensions must satisfy M, N % 128 == 0, i.e., both M and N must be multiples of 128.
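A minimal pre-flight check for this constraint, assuming M and N are the values later passed via `--m` and `--n` to the scripts; `check_shape` is a hypothetical helper, not part of the repo.

```python
ALIGN = 128  # M and N must be multiples of this


def check_shape(m, n):
    """Raise ValueError if M or N is not a multiple of 128; return True otherwise."""
    bad = [name for name, value in (("M", m), ("N", n)) if value % ALIGN != 0]
    if bad:
        raise ValueError(
            f"{', '.join(bad)} must be a multiple of {ALIGN}: got M={m}, N={n}"
        )
    return True
```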
The repository is organized as follows:
.
├── cmake
│ └── Modules
│ └── FindNCCL.cmake
├── configs // To store GEMM and overlapping configs
├── example
│ ├── correctness_ar.py // Check correctness of GEMM+AllReduce+RMSNorm
│ ├── correctness_rs.py // Check correctness of GEMM+ReduceScatter+RMSNorm
├── src
│ ├── 3rdparty
│ ├── gemm // CUTLASS GEMM Wrappers
│ │ ├── gemm.cu
│ │ └── gemm.h
│ ├── inc // Instantiate templated GEMMs
│ ├── overlap // Source files for signal+reorder
│ ├── rmsnorm // Source files for reorder+RMSNorm
│ ├── tiling // Tiling definition
│ ├── baseline_impl.cu // Baseline implementation class
│ ├── baseline_impl.h
│ ├── CMakeLists.txt
│ ├── nccl_utils.cu // NCCL id generation function
│ ├── nccl_utils.h
│ ├── overlap_impl.cu // Overlap implementation class
│ ├── overlap_impl.h
│ ├── pybind.cpp
│ └── wait.cuh // Signal kernel
├── test
│ └── test.py
├── tool
│ └── generate_instances.py // Generate templated GEMMs
├── tune
│ ├── bandwidth.py // Bandwidth test for predictive search
│ ├── gen_config.py // Generate GEMM configs based on CUTLASS profiler
│ ├── profile_config.py // Customized profiler
│ └── search.py // Exhaustive search and predictive search
└── CMakeLists.txt
The repo currently supports two ways to generate tuned GEMM configs for better performance. This step needs only one GPU.
- Make sure the ./configs dir has been created.
$ cd tune
- Using the CUTLASS Profiler: follow the CUTLASS Profiler README, write the profiling results to $CSV_PATH/*.csv, and then generate the .json file in ./configs:
$ python gen_config.py --m $M --n $N --k $K --path $CSV_PATH
- Using the customized profiler for a specific shape. Profiling finishes within minutes. (This method has not yet been evaluated on RTX 4090 or RTX 3090; this will be updated soon.)
$ python profile_config.py --m $M --n $N --k $K
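Conceptually, tuning from profiler output comes down to picking the fastest candidate kernel for the target shape. The sketch below illustrates only that idea and is not a reimplementation of gen_config.py: the column names `Operation` and `Runtime` are assumptions about the CUTLASS Profiler CSV, and the .json layout FlashOverlap expects is not reproduced here.

```python
import csv
import glob
import os


def fastest_kernel(csv_dir, op_col="Operation", time_col="Runtime"):
    """Return (kernel name, runtime) of the fastest row across csv_dir/*.csv, or None.

    Column names are assumptions; adjust them to the actual profiler output.
    """
    best = None
    for path in sorted(glob.glob(os.path.join(csv_dir, "*.csv"))):
        with open(path, newline="") as f:
            for row in csv.DictReader(f):
                try:
                    runtime = float(row[time_col])
                except (KeyError, ValueError):
                    continue  # skip rows without a usable runtime
                if best is None or runtime < best[1]:
                    best = (row[op_col], runtime)
    return best
```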
Tune the wave group size. This step requires multiple GPUs, and the environment variable CUDA_VISIBLE_DEVICES must be set, because the scripts launch processes with the spawn method (torch.multiprocessing.spawn) and determine the rank and world size explicitly.
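The launch pattern described above can be sketched as follows, assuming (not confirmed by the repo) that the world size is the number of devices listed in CUDA_VISIBLE_DEVICES. `world_size_from_env`, `worker`, and `launch` are hypothetical names; `torch.multiprocessing.spawn` is the real PyTorch API and passes the process index as the first argument to the target function.

```python
import os


def world_size_from_env(var="CUDA_VISIBLE_DEVICES"):
    """Count the GPUs listed in CUDA_VISIBLE_DEVICES (e.g. '0,1' -> 2)."""
    value = os.environ.get(var, "").strip()
    if not value:
        raise RuntimeError(f"{var} must be set, e.g. {var}=0,1")
    return len([d for d in value.split(",") if d.strip()])


def worker(rank, world_size):
    # Hypothetical per-process body; spawn supplies `rank` as the first argument.
    print(f"rank {rank} of {world_size}")


def launch():
    import torch.multiprocessing as mp  # imported lazily; requires PyTorch

    world_size = world_size_from_env()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```

With the spawn start method, `worker` must be defined at module top level so that child processes can import it.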
- The repo provides both exhaustive and predictive search; predictive search is recommended when MxN > 4096x4096. If you choose predictive search, generate the bandwidth curve first. For a given GPU and communication primitive, the bandwidth curve only needs to be generated once.
$ CUDA_VISIBLE_DEVICES=0,1 python bandwidth.py --comm_op all_reduce
- Both search methods share the same script; specify --predictive_search to use predictive search.
$ CUDA_VISIBLE_DEVICES=0,1 python search.py --m $M --n $N --k $K --comm_op {all_reduce, reduce_scatter} --predictive_search True
- The generated solution is written into the corresponding .json file.
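For intuition about why a bandwidth curve makes predictive search possible, here is a toy cost model; it is not the algorithm in search.py. All assumptions are hypothetical: every wave takes the same compute time `wave_time`, every wave produces `bytes_per_wave` bytes to communicate, the curve is a sorted list of distinct (message bytes, bytes/s) points that is linearly interpolated, and candidates use one uniform group size with any remainder in a final group. Because achieved bandwidth depends on message size, such a model can rank wave group sizes without running every candidate on the GPUs. `use_predictive_search` just restates the MxN > 4096x4096 recommendation.

```python
def interp_bandwidth(curve, nbytes):
    """Linearly interpolate achieved bandwidth (bytes/s) from sorted (bytes, bytes/s) points."""
    if nbytes <= curve[0][0]:
        return curve[0][1]
    for (x0, y0), (x1, y1) in zip(curve, curve[1:]):
        if nbytes <= x1:
            return y0 + (y1 - y0) * (nbytes - x0) / (x1 - x0)
    return curve[-1][1]  # clamp beyond the last measured point


def predict_latency(groups, wave_time, bytes_per_wave, curve):
    """Toy pipeline: a group's transfer starts once its waves are computed
    and the previous group's transfer has finished."""
    compute_end = 0.0
    comm_end = 0.0
    for waves in groups:
        compute_end += waves * wave_time
        nbytes = waves * bytes_per_wave
        comm_end = max(compute_end, comm_end) + nbytes / interp_bandwidth(curve, nbytes)
    return comm_end


def predictive_search(num_waves, wave_time, bytes_per_wave, curve):
    """Evaluate uniform group sizes with the model; return (groups, predicted time)."""
    best = None
    for size in range(1, num_waves + 1):
        groups = [size] * (num_waves // size)
        if num_waves % size:
            groups.append(num_waves % size)
        predicted = predict_latency(groups, wave_time, bytes_per_wave, curve)
        if best is None or predicted < best[1]:
            best = (groups, predicted)
    return best


def use_predictive_search(m, n):
    """Prefer predictive search when MxN > 4096x4096."""
    return m * n > 4096 * 4096
```

With a flat curve the model favors the smallest groups, since finer groups start communicating sooner; a curve whose bandwidth drops for small messages penalizes tiny groups, which is the trade-off the search has to balance.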
Enter the test dir and run the test script:
$ cd ./test
$ CUDA_VISIBLE_DEVICES=0,1 python test.py --m $M --n $N --k $K --comm_op {all_reduce, reduce_scatter}
- Enter the example dir.
$ cd ./example
- Check the correctness of GEMM+AllReduce+RMSNorm (correctness_ar.py) or GEMM+ReduceScatter+RMSNorm (correctness_rs.py). RMSNorm must be included because the tile order is restored inside the RMSNorm kernel.
$ CUDA_VISIBLE_DEVICES=0,1 python correctness_{ar, rs}.py --m $M --n $N --k $K
- We define the ReorderRMSNorm class in RMSNorm.py and the OverlapRowParallelLayer class in RowParallelLayer.py, which can replace the RMSNorm and RowParallelLayer classes, respectively. Together they serve as a simple example of usage in end-to-end inference or training.
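Why RMSNorm is coupled to the reordering can be illustrated with NumPy: if output tiles reach the communication buffer in an order other than row-major, the normalization step can write each tile back to its original position and then normalize each row. This is a conceptual sketch, not the ReorderRMSNorm implementation. The tile size, tile order, and the names `pack_tiles` and `reorder_rmsnorm` are hypothetical, and RMSNorm uses the common definition x / sqrt(mean(x^2) + eps) * weight.

```python
import numpy as np


def pack_tiles(mat, tile_order, tm, tn):
    """Flatten (tm, tn) tiles of `mat` into a 1-D buffer following `tile_order`."""
    return np.concatenate(
        [mat[i * tm:(i + 1) * tm, j * tn:(j + 1) * tn].ravel() for i, j in tile_order]
    )


def rmsnorm(x, weight, eps=1e-6):
    """Row-wise RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def reorder_rmsnorm(buf, tile_order, m, n, tm, tn, weight, eps=1e-6):
    """Scatter tiles back to row-major (M, N) layout, then apply RMSNorm per row."""
    out = np.empty((m, n), dtype=buf.dtype)
    tiles = buf.reshape(len(tile_order), tm, tn)
    for tile, (i, j) in zip(tiles, tile_order):
        out[i * tm:(i + 1) * tm, j * tn:(j + 1) * tn] = tile
    return rmsnorm(out, weight, eps)
```

Normalizing the buffer directly, without the scatter, would mix values from different rows whenever the tile order is not row-major, which is why the two steps belong together.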
@misc{hong2025flashoverlap,
title={FlashOverlap: A Lightweight Design for Efficiently Overlapping Communication and Computation},
author={Ke Hong and Xiuhong Li and Minxu Liu and Qiuli Mao and Tianqi Wu and Zixiao Huang and Lufang Chen and Zhong Wang and Yichong Zhang and Zhenhua Zhu and Guohao Dai and Yu Wang},
year={2025},
eprint={2504.19519},
archivePrefix={arXiv},
primaryClass={cs.DC}
}