Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass"]
path = deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git
33 changes: 33 additions & 0 deletions deep-gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT
cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_VERBOSE_MAKEFILE ON)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")

find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# The main Python API entrance
pybind11_add_module(_C csrc/python_api.cpp)
target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES} torch_python)

# Enable kernel code indexing with CMake-based IDEs
cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu)
21 changes: 21 additions & 0 deletions deep-gemm/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
179 changes: 179 additions & 0 deletions deep-gemm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# DeepGEMM

DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (working in progress) for both normal and Mix-of-Experts (MoE) grouped scenarios. Written in CUDA, the library has no kernel compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module.

DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.

Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.

## News

- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2.
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
- NVRTC and post-compilation SASS optimization are all disabled.
- NVRTC will be supported later.
- As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported.
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.

## Roadmap

- [x] More correctness tests for grouped-contiguous layout
- [x] Shared memory swizzling for output
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [x] Skip useless computation on M
- [x] NVRTC as a faster compiler
- [x] Sanitizer for testing
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Better `get_best_configs` modeling
- [ ] CUDA PDL support
- [ ] Larger TMA multicast size for some shapes
- [x] MMA template refactor with CUTLASS
- [x] Remove shape limitations on N and K
- [x] BF16 kernels
- [ ] Split/stream-k optimizations
- [ ] Ampere kernels
- [ ] Polish docs

## Quick start

### Requirements

- NVIDIA SM90 or SM100 architecture GPU
- Python 3.8 or higher
- Compilers with C++20 support
- CUDA Toolkit:
- CUDA 12.3 or higher for SM90
- **We highly recommend 12.9 or higher for the best performance**
- CUDA 12.9 or higher for SM100
- PyTorch 2.1 or higher
- CUTLASS 4.0 or higher (could be cloned by Git submodule)
- `{fmt}` library (could be cloned by Git submodule)

### Development

```bash
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
cd DeepGEMM

# Link some essential includes and build the CPP JIT module
cat develop.sh
./develop.sh

# Test all GEMM implements
python tests/test_layout.py
python tests/test_attention.py
python tests/test_core.py
```

### Installation

```bash
cat install.sh
./install.sh
```

Then, import `deep_gemm` in your Python project, and enjoy!

## Interfaces

#### Notices

This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input shape layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` will do a `D = C + A @ B.T`

For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. And the data format for the scaling factor of SM90 and SM100 is different:

- SM90 requires scaling factors in FP32 format.
- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 into a single `torch.int`.

Please note that operations like input transposition or FP8 casting must be handled separately by the user, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves.

#### Normal dense GEMMs (non-grouped)

To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation.

#### Grouped GEMMs (contiguous layout)

Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation.

We also provide a K-axis-grouped API for MoE weight backward (with M and N must remain fixed), please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information.

#### Grouped GEMMs (masked layout)

During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions.

Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.

#### V3.2 MQA kernels for the indexer

The kernel family has two versions, non-paged (for prefilling) and paged (for decoding).
Take the non-paged version `fp8_mqa_logits` as an example. It has 6 inputs:

- `q`, E4M3 tensor with shape `[seq_len, num_heads, head_dim]`
- `kv`, E4M3 tensor (shaped as `[seq_len_kv, head_dim]`) with float SF (shaped as `[seq_len_kv]`)
- `weights`, float tensor with shape `[seq_len, num_heads]`
- `cu_seq_len_k_start` and `cu_seq_len_k_end`, int tensor with shape `[seq_len]`
- `clean_logits`, whether to clean the unfilled logits into `-inf`

The output tensor is shaped as `[seq_len, seq_len_kv]`, indicating token-to-token logits.
For each token `i` in `q`, it will iterate all tokens `j` from `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])`,
and calculate the logit `out[i, j]` as:

```python
kv_j = kv[0][j, :] * kv[1][j].unsqueeze(1) # [head_dim]
out_ij = q[i, :, :] @ kv_j # [num_heads]
out_ij = out_ij.relu() * weights[i, :] # [num_heads]
out_ij = out_ij.sum() # Scalar
```

For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`.

#### Utilities

The library provides some utility functions besides the above kernels:

- `deep_gemm.set_num_sms`: set the maximum SM count to use
- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set)
- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio
- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0)
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel

The library also provides some environment variables, which may be useful:

- General
- `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
- JIT cache related
- `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- NVCC/NVRTC selections
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default
- `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default
- Compiler options
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default
- Heuristic selection
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default

For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.

## Acknowledgement

DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers!

## License

This code repository is released under [the MIT License](LICENSE).

--

vendored at commit 477618cd51baffca09c4b0b87e97c03fe827ef03
12 changes: 12 additions & 0 deletions deep-gemm/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Change current directory into project root
original_dir=$(pwd)
script_dir=$(realpath "$(dirname "$0")")
cd "$script_dir"

# Remove old dist file, build files, and install
rm -rf build dist
rm -rf *.egg-info
python setup.py bdist_wheel

# Open users' original directory
cd "$original_dir"
104 changes: 104 additions & 0 deletions deep-gemm/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
[general]
name = "deep-gemm"
backends = ["cuda"]
version = 1

[general.cuda]
minver = "12.3"

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]
pyext = ["py", "cuh", "hpp", "h"]

[kernel.deep_gemm]
backend = "cuda"
cuda-minver = "12.3"
cuda-capabilities = ["9.0a"]
depends = ["torch", "cutlass_4_0"]
include = [
"csrc",
"deep_gemm/include",
]
cuda-flags = [
"-std=c++17",
"-O3",
"-DNDEBUG",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"-Wno-deprecated-declarations",
]
src = [
# Compiled source
"csrc/impl.cu",
# API headers
"csrc/apis/attention.hpp",
"csrc/apis/einsum.hpp",
"csrc/apis/gemm.hpp",
"csrc/apis/hyperconnection.hpp",
"csrc/apis/layout.hpp",
"csrc/apis/runtime.hpp",
# JIT infrastructure headers
"csrc/jit/cache.hpp",
"csrc/jit/compiler.hpp",
"csrc/jit/device_runtime.hpp",
"csrc/jit/handle.hpp",
"csrc/jit/kernel_runtime.hpp",
# JIT kernel heuristics
"csrc/jit_kernels/heuristics/common.hpp",
"csrc/jit_kernels/heuristics/sm100.hpp",
"csrc/jit_kernels/heuristics/sm90.hpp",
# JIT kernel implementations
"csrc/jit_kernels/impls/epilogue.hpp",
"csrc/jit_kernels/impls/runtime_utils.hpp",
"csrc/jit_kernels/impls/sm100_bf16_gemm.hpp",
"csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp",
"csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp",
"csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp",
"csrc/jit_kernels/impls/sm90_bf16_gemm.hpp",
"csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp",
"csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp",
"csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp",
"csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp",
"csrc/jit_kernels/impls/smxx_clean_logits.hpp",
"csrc/jit_kernels/impls/smxx_cublaslt.hpp",
"csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp",
"csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp",
"csrc/jit_kernels/impls/smxx_layout.hpp",
# Utility headers
"csrc/utils/compatibility.hpp",
"csrc/utils/exception.hpp",
"csrc/utils/format.hpp",
"csrc/utils/hash.hpp",
"csrc/utils/layout.hpp",
"csrc/utils/lazy_init.hpp",
"csrc/utils/math.hpp",
"csrc/utils/system.hpp",
# Runtime JIT headers (deep_gemm/include)
"deep_gemm/include/deep_gemm/common/cute_tie.cuh",
"deep_gemm/include/deep_gemm/common/epilogue_utils.cuh",
"deep_gemm/include/deep_gemm/common/types.hpp",
"deep_gemm/include/deep_gemm/common/sm90_utils.cuh",
"deep_gemm/include/deep_gemm/common/reduction.cuh",
"deep_gemm/include/deep_gemm/common/utils.cuh",
"deep_gemm/include/deep_gemm/common/tma_utils.cuh",
"deep_gemm/include/deep_gemm/common/sm100_utils.cuh",
"deep_gemm/include/deep_gemm/common/scheduler.cuh",
"deep_gemm/include/deep_gemm/impls/smxx_layout.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh",
"deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh",
]
Loading
Loading