
FlashOverlap

😊 A Lightweight Design for Computation-Communication Overlap

News

[2025.08.23] FlashOverlap has been accepted to EuroSys'26 🎉 The tech report will be updated soon.

Roadmap

  • demo for GEMM+AllReduce
  • predictive search for wave grouping
  • multi-node example
  • demo for GEMM+ReduceScatter
  • demo for GEMM+AlltoAll
  • code branch for AE
  • more platforms (e.g., Hopper GPUs)
  • end-to-end example

How FlashOverlap Works

The figure above shows a typical timeline of computation-communication overlap in FlashOverlap. Two CUDA streams are used for computation and communication, respectively. The CUTLASS kernel sends signals during GEMM computation on one stream, while a counting kernel on the other stream stalls NCCL communication until it has received a preset number of signals.
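As a coarse illustration of this two-stream pattern, the sketch below overlaps a GEMM with an AllReduce using a single end-of-GEMM event rather than FlashOverlap's per-wave-group in-kernel signals, and assumes torch.distributed is already initialized:

    import torch
    import torch.distributed as dist

    compute_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()

    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    done = torch.cuda.Event()

    with torch.cuda.stream(compute_stream):
        c = a @ b        # GEMM on the computation stream
        done.record()    # coarse "signal": the whole GEMM has finished

    with torch.cuda.stream(comm_stream):
        comm_stream.wait_event(done)  # stall communication until the signal
        dist.all_reduce(c)            # NCCL AllReduce on the communication stream

FlashOverlap replaces the single end-of-GEMM event with per-wave-group signals, so communication of finished tiles starts while the remaining tiles are still being computed.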

Build and Install

Dependency

The main dependency is NCCL, which FlashOverlap uses for communication; it can be downloaded from the official website. The code has been tested with v2.18.3 and v2.19.3.

Another dependency is CUTLASS, which is included as a submodule. Note that the code has been tested with v3.6.0 and v3.9.0 but fails with v3.4.0; we assume any CUTLASS >= v3.6.0 works fine.

The code currently supports only sm_80, sm_86, and sm_89; the evaluation environments include NVIDIA RTX 3090, RTX 4090, A800, and A100 GPUs. The tested CUDA Toolkit versions are 12.1 and 12.2.

Install

First, clone the repo:

    $ git clone https://github.com/infinigence/FlashOverlap.git
    $ cd FlashOverlap
    $ git submodule update --init --recursive

Install PyTorch and other required packages through pip or conda:

    $ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
    $ pip install numpy==2.1.2 pandas==2.2.3 setuptools==75.8.0

Before compiling, generate the GEMM instances:

    $ mkdir ./configs
    $ cd ./tool
    $ python generate_instances.py

This repo uses CMake (>= 3.18) for compiling. From the repository root:

    $ cmake -B build
    $ cmake --build build -j

The build registers the operators as TorchScript custom classes; in Python code, the .so must be loaded before the operators are used:

    torch.ops.load_library("../build/lib/libst_pybinding.so")
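Note that the relative path above only resolves when the script runs from a sibling directory such as test/ or example/; one robust pattern is to resolve the path relative to the calling file, as in this sketch:

    import os
    import torch

    # Resolve libst_pybinding.so relative to this file rather than the current
    # working directory, so the script works wherever it is launched from.
    _LIB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        "..", "build", "lib", "libst_pybinding.so")
    torch.ops.load_library(_LIB)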

Quick Start

⚠️ Notice: boundary handling is not implemented yet, so the repo currently supports only regular GEMM shapes (both M % 128 == 0 and N % 128 == 0).
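A quick guard for this constraint (a hypothetical helper, not part of the repo):

    def assert_supported_shape(m: int, n: int, tile: int = 128) -> None:
        # Boundary handling is not implemented, so M and N must be tile-aligned.
        if m % tile != 0 or n % tile != 0:
            raise ValueError(
                f"unsupported GEMM shape M={m}, N={n}: "
                f"both must be multiples of {tile}")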

File Structure

.
├── cmake
│   └── Modules
│       └── FindNCCL.cmake
├── configs                   // To store GEMM and overlapping configs
├── example
│   ├── correctness_ar.py        // Check correctness of GEMM+AllReduce+RMSNorm
│   ├── correctness_rs.py        // Check correctness of GEMM+ReduceScatter+RMSNorm
├── src
│   ├── 3rdparty
│   ├── gemm                  // CUTLASS GEMM Wrappers
│   │   ├── gemm.cu
│   │   └── gemm.h
│   ├── inc                   // Instantiate templated GEMMs
│   ├── overlap               // Source files for signal+reorder
│   ├── rmsnorm               // Source files for reorder+RMSNorm
│   ├── tiling                // Tiling definition  
│   ├── baseline_impl.cu      // Baseline implementation class
│   ├── baseline_impl.h
│   ├── CMakeLists.txt
│   ├── nccl_utils.cu         // NCCL id generation function
│   ├── nccl_utils.h
│   ├── overlap_impl.cu       // Overlap implementation class
│   ├── overlap_impl.h
│   ├── pybind.cpp
│   └── wait.cuh              // Signal kernel
├── test
│   └── test.py
├── tool
│   └── generate_instances.py // Generate templated GEMMs
├── tune
│   ├── bandwidth.py          // Bandwidth test for predictive search
│   ├── gen_config.py         // Generate GEMM configs based on CUTLASS profiler
│   ├── profile_config.py     // Customized profiler
│   └── search.py             // Exhaustive search and predictive search
└── CMakeLists.txt

Generate GEMM Configuration

Currently the repo supports two ways to generate proper GEMM configs for better performance. Only one GPU is needed for this step.

  1. Make sure the ./configs dir is created, then enter the tune dir.
    $ cd tune
  2. Option A: use the CUTLASS Profiler. Follow its README, write the profiling results to $CSV_PATH/*.csv, and then generate the .json file in configs.
    $ python gen_config.py --m $M --n $N --k $K --path $CSV_PATH
  3. Option B: use the customized profiler for a specific shape. The profiling process finishes within minutes. (This method has not been evaluated on RTX 4090 and RTX 3090 yet; results will be updated soon.)
    $ python profile_config.py --m $M --n $N --k $K

Tune

Tune the wave group size. Note that this step requires multiple GPUs and the environment variable CUDA_VISIBLE_DEVICES must be set, as the scripts use the spawn method (torch.multiprocessing.spawn) and determine the rank and world size explicitly.
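For reference, the spawn pattern looks roughly like the sketch below (the worker body is a placeholder, not the repo's actual tuning code):

    import os
    import torch.multiprocessing as mp

    def worker(rank: int, world_size: int):
        # torch.multiprocessing.spawn passes the process rank as the first
        # argument; each process binds one GPU from CUDA_VISIBLE_DEVICES.
        ...

    if __name__ == "__main__":
        world_size = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
        mp.spawn(worker, args=(world_size,), nprocs=world_size)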

  1. The repo provides both exhaustive and predictive search methods; the latter is recommended when MxN > 4096x4096. If the predictive method is chosen, generate the bandwidth curve first. For a given GPU and communication primitive, the bandwidth curve only needs to be generated once.
    $ CUDA_VISIBLE_DEVICES=0,1 python bandwidth.py --comm_op all_reduce
  2. The two search methods share the same script; specify --predictive_search to use the predictive one.
    $ CUDA_VISIBLE_DEVICES=0,1 python search.py --m $M --n $N --k $K --comm_op {all_reduce, reduce_scatter} --predictive_search True
  3. The generated solution is written into the corresponding .json file.

Speed Test

Enter the test dir and run the script:

    $ cd ./test
    $ CUDA_VISIBLE_DEVICES=0,1 python test.py --m $M --n $N --k $K --comm_op {all_reduce, reduce_scatter}

Correctness Test

  1. Enter the example dir.
    $ cd ./example
  2. Evaluate the correctness of GEMM+AllReduce+RMSNorm (or GEMM+ReduceScatter+RMSNorm). The RMSNorm must be included because the tile order is corrected inside its kernel.
    $ CUDA_VISIBLE_DEVICES=0,1 python correctness_{ar, rs}.py --m $M --n $N --k $K
  3. We define the ReorderRMSNorm class in RMSNorm.py and the OverlapRowParallelLayer class in RowParallelLayer.py, which can replace the RMSNorm and RowParallelLayer classes, respectively. This is a simple example of usage in end-to-end inference or training; a sketch of the substitution follows below.
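A minimal sketch of that substitution, assuming hypothetical constructor signatures (the actual arguments are defined in RMSNorm.py and RowParallelLayer.py):

    import torch
    from RMSNorm import ReorderRMSNorm                    # example/RMSNorm.py
    from RowParallelLayer import OverlapRowParallelLayer  # example/RowParallelLayer.py

    class Block(torch.nn.Module):
        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            # was: self.proj = RowParallelLayer(in_features, out_features)
            self.proj = OverlapRowParallelLayer(in_features, out_features)
            # was: self.norm = RMSNorm(out_features)
            self.norm = ReorderRMSNorm(out_features)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # ReorderRMSNorm restores the tile order produced by the
            # overlapped GEMM+AllReduce before normalizing.
            return self.norm(self.proj(x))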

Citation

    @misc{hong2025flashoverlap,
      title={FlashOverlap: A Lightweight Design for Efficiently Overlapping Communication and Computation},
      author={Ke Hong and Xiuhong Li and Minxu Liu and Qiuli Mao and Tianqi Wu and Zixiao Huang and Lufang Chen and Zhong Wang and Yichong Zhang and Zhenhua Zhu and Guohao Dai and Yu Wang},
      year={2025},
      eprint={2504.19519},
      archivePrefix={arXiv},
      primaryClass={cs.DC}
    }
