Sm80 Allgather-gemm && Gemm-Reducescatter (#3)
The current implementation supports two operations on the SM80 architecture:
1. AllGather followed by GEMM (general matrix-matrix multiplication)
2. GEMM followed by Reduce-Scatter

The fused operations demonstrate improved performance compared to invoking GEMM and the communication operation separately. This optimization is crucial for high-performance computing tasks, especially LLM training and inference.
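For reference, the unfused baselines that these kernels replace can be sketched with plain PyTorch collectives. This is a minimal sketch, assuming an initialized `torch.distributed` process group and a recent PyTorch (for `reduce_scatter_tensor`); the function names and shapes are illustrative, not the library's API:

```python
import torch
import torch.distributed as dist

# Unfused baseline 1: AllGather followed by GEMM.
# Each rank holds a local input shard [M / world_size, K]; the full [M, K]
# operand is gathered first, then multiplied by the weight [K, N].
def allgather_then_gemm(local_input, weight):
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_input) for _ in range(world_size)]
    dist.all_gather(gathered, local_input)
    full_input = torch.cat(gathered, dim=0)   # [M, K]
    return full_input @ weight                # [M, N]

# Unfused baseline 2: GEMM followed by Reduce-Scatter.
# Every rank produces a partial [M, N] result; the partials are summed and
# scattered so each rank keeps an [M / world_size, N] slice.
def gemm_then_reduce_scatter(local_input, weight):
    world_size = dist.get_world_size()
    partial = local_input @ weight            # [M, N], partial sums
    out = torch.empty(partial.shape[0] // world_size, partial.shape[1],
                      dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(out, partial)
    return out
```

The fused kernels are expected to produce the same results while overlapping the collective with the GEMM tiles, instead of running the two stages back-to-back as in the sketch above.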

Co-authored-by: Chengquan Jiang <imjcqt@gmail.com>
Co-authored-by: Wenlei Bao <wenlei.bao@bytedance.com>
Co-authored-by: Ningxin Zheng <zhengningxin@bytedance.com>
Co-authored-by: Qi Hou <houqi1993@gmail.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Co-authored-by: Xin Liu <liuxin.ai@bytedance.com>
Co-authored-by: Liwen Chang <liwen.chang@bytedance.com>
Co-authored-by: Haibin Lin <haibin.lin@bytedance.com>
9 people authored Jun 6, 2024
1 parent de92387 commit 98fbde4
Showing 141 changed files with 41,716 additions and 15 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass
[submodule "3rdparty/nccl"]
path = 3rdparty/nccl
url = https://github.com/NVIDIA/nccl
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 548 files
1 change: 1 addition & 0 deletions 3rdparty/nccl
Submodule nccl added at 8c6c59
13 changes: 0 additions & 13 deletions 3rdparty/patches/cutlass/copy_sm90_desc.patch

This file was deleted.

172 changes: 172 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,172 @@
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
project(FLUX LANGUAGES CXX CUDA)

list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/modules/")

# cmake global settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE INTERNAL "")
set(BUILD_THS ON CACHE INTERNAL "Build Torch op")
set(BUILD_TEST ON CACHE INTERNAL "Build unit tests")
set(ENABLE_NVSHMEM ON CACHE INTERNAL "Use NVSHMEM to transfer data")
set(CUTLASS_TRACE OFF CACHE INTERNAL "Print CUTLASS Host Trace info")
set(FLUX_DEBUG OFF CACHE INTERNAL "Define FLUX_DEBUG")
message("PYTHONPATH: ${PYTHONPATH}")
message("NVShmem Support: ${ENABLE_NVSHMEM}")

# find cuda
find_package(CUDAToolkit REQUIRED)

message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")
if(CUDAToolkit_VERSION VERSION_LESS "11.0")
message(FATAL_ERROR "requires cuda to be >= 11.0")
elseif(CUDAToolkit_VERSION VERSION_LESS "12.0")
set(CUDAARCHS "80" CACHE STRING "CUDA Architectures")
else()
set(CUDAARCHS "80;90" CACHE STRING "CUDA Architectures")
endif()

set(CMAKE_CUDA_ARCHITECTURES ${CUDAARCHS})
message(STATUS "CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}")

set(CUDA_ARCH_FLAGS)
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
list(APPEND CUDA_ARCH_FLAGS "-gencode=arch=compute_${ARCH},code=\\\"sm_${ARCH},compute_${ARCH}\\\"")
endforeach()

string(JOIN " " JOINED_CUDA_ARCH_FLAGS ${CUDA_ARCH_FLAGS})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${JOINED_CUDA_ARCH_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DNDEBUG")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wno-psabi")
# set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_DEBUG_TRACE_LEVEL=0")

set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")

set(CMAKE_CXX_STANDARD "17")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++17")

if(CUTLASS_TRACE)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_DEBUG_TRACE_LEVEL=1")
endif()

if(FLUX_DEBUG)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DFLUX_DEBUG")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DFLUX_DEBUG")
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -O3")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

if(BUILD_THS)
message("Build THS on")
find_package(Python 3 REQUIRED)
find_program(PYTHON_EXECUTABLE NAMES python3 python)

set(TORCH_CUDA_ARCH_LIST)
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
if(ARCH STREQUAL "80")
list(APPEND TORCH_CUDA_ARCH_LIST "8.0")
elseif(ARCH STREQUAL "89")
list(APPEND TORCH_CUDA_ARCH_LIST "8.9")
elseif(ARCH STREQUAL "90")
list(APPEND TORCH_CUDA_ARCH_LIST "9.0")
else()
message(WARNING "Unsupported CUDA arch [${ARCH}] for TORCH_CUDA_ARCH")
endif()
endforeach()

execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "from __future__ import print_function; import os; import torch;
print(os.path.dirname(torch.__file__),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE TORCH_DIR)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message("PY:${PYTHONPATH}")
message(FATAL_ERROR "Torch config Error.")
endif()
list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR})
find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_DIR}/lib")

execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "from __future__ import print_function; from distutils import sysconfig;
print(sysconfig.get_python_inc());
print(sysconfig.get_config_var('EXT_SUFFIX'));"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE _PYTHON_VALUES)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message("PY:${PYTHON_EXECUTABLE}")
message(FATAL_ERROR "Python config Error.")
endif()
string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR)
list(GET _PYTHON_VALUES 1 PY_SUFFIX)
list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR})

execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c"
"from torch.utils import cpp_extension; import re; import torch; \
version = tuple(int(i) for i in re.match('(\\d+)\\.(\\d+)\\.(\\d+)', torch.__version__).groups()); \
args = ([],True,False,False) if version >= (1, 8, 0) else ([],True,False); \
print(' '.join(cpp_extension._prepare_ldflags(*args)),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE TORCH_LINK)
message("-- TORCH_LINK ${TORCH_LINK}")
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "PyTorch link config Error.")
endif()
endif()

# force use sm90a for cutlass
string(REGEX REPLACE "sm_90" "sm_90a" CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS})
string(REGEX REPLACE "compute_90" "compute_90a" CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS})

set(COMMON_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include
${CUDAToolkit_INCLUDE_DIRS}
${CMAKE_CURRENT_BINARY_DIR}
${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include
${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/util/include
${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/library/include
${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/profiler/include
)

set(COMMON_LIB_DIRS "")
list(APPEND COMMON_LIB_DIRS "${CUDAToolkit_LIBRARY_DIR}")

if(ENABLE_NVSHMEM)
add_definitions(-DFLUX_USE_NVSHMEM)
set(NVSHMEM_BUILD_DIR ${PROJECT_SOURCE_DIR}/3rdparty/nvshmem/build)
message(STATUS "NVSHMEM build dir: ${NVSHMEM_BUILD_DIR}")
if(NOT EXISTS ${NVSHMEM_BUILD_DIR})
message(FATAL_ERROR "NVSHMEM not found. Please run ./build_nvshmem.sh first.")
endif()
list(APPEND COMMON_HEADER_DIRS "${NVSHMEM_BUILD_DIR}/src/include")
list(APPEND COMMON_LIB_DIRS "${NVSHMEM_BUILD_DIR}/src/lib")
endif()

# append headers explicitly for .cu files, in order to enable vscode clangd intellisense
foreach(inc_dir ${COMMON_HEADER_DIRS})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -I${inc_dir}")
endforeach()
message("final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

include_directories(
${COMMON_HEADER_DIRS}
)

link_directories(
${COMMON_LIB_DIRS}
)

add_subdirectory(src)
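The `execute_process` probes in the CMakeLists.txt above query the active Python environment for Torch's install directory, the Python include directory and extension suffix, and the Torch extension link flags. Run outside CMake, they amount to roughly the following sketch (for inspection only; it assumes `torch` is importable in the build environment, and uses the stdlib `sysconfig` in place of the `distutils.sysconfig` call):

```python
# Roughly what the CMakeLists.txt probes above collect from the Python env.
import os
import re
import sysconfig

import torch
from torch.utils import cpp_extension

# Torch install directory, appended to CMAKE_PREFIX_PATH so find_package(Torch) works.
torch_dir = os.path.dirname(torch.__file__)

# Python include dir and extension suffix, used for the Torch op / pybind target.
py_include = sysconfig.get_paths()["include"]
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")

# Linker flags for a Torch extension; the ([], True, False, False) arguments
# mirror the newer-PyTorch branch of the CMake probe.
version = tuple(int(v) for v in
                re.match(r"(\d+)\.(\d+)\.(\d+)", torch.__version__).groups())
args = ([], True, False, False) if version >= (1, 8, 0) else ([], True, False)
torch_link_flags = " ".join(cpp_extension._prepare_ldflags(*args))

print(torch_dir, py_include, ext_suffix, torch_link_flags, sep="\n")
```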
70 changes: 70 additions & 0 deletions CONTRIBUTORS.md
@@ -0,0 +1,70 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Changes Accepted

Please file issues before doing substantial work; this will ensure that others
don't duplicate the work and that there's a chance to discuss any design issues.

Changes only tweaking style are unlikely to be accepted unless they are applied
consistently across the project. Most of the code style is derived from the
[Google Style Guides](http://google.github.io/styleguide/) for the appropriate
language and is generally not something we accept changes on (as clang-format
and clang-tidy handle that for us). The compiler portion of the project follows
[MLIR style](https://mlir.llvm.org/getting_started/DeveloperGuide/#style-guide).
Improvements to code structure and clarity are welcome but please file issues to
track such work first.

## AUTHORS file

If you would like to receive additional recognition for your contribution, you
may add yourself (or your organization) to the AUTHORS file. This keeps track of
those who have made significant contributions to the project. Please add the
entity who owns the copyright for your contribution. The source control history
remains the most accurate source for individual contributions.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

5 changes: 5 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,5 @@
include src/ths_op/*.cc.inc
exclude pynvshmem/
recursive-include src *
recursive-include include *
recursive-include python/flux_ths_pybind
47 changes: 46 additions & 1 deletion README.md
@@ -2,7 +2,52 @@

Flux is a fast communication-overlapping library for tensor parallelism on GPUs.

## Coming Soon

## Why Flux

Flux can significantly reduce latency and increase throughput of tensor parallelism for both inference and training.

## Build
```bash
# Download nvshmem-2.11 (https://developer.nvidia.com/nvshmem) and extract it into flux/3rdparty/nvshmem
# Flux currently depends on this specific NVSHMEM version (2.11).
tar Jxvf nvshmem_src_2.11.0-5.txz
mv nvshmem_src_2.11.0-5 ${YOUR_PATH}/flux/3rdparty/nvshmem
git submodule update --init --recursive
# Ampere
./build.sh --arch 80

```

If you are tired of the cmake process, you can set the environment variable `FLUX_BUILD_SKIP_CMAKE` to 1 to skip the cmake step when `build/CMakeCache.txt` already exists.

If you want to build a wheel package, add `--package` to the build command. The output wheel file is placed under `dist/`.

```bash
# Ampere
./build.sh --arch 80 --package
```

For development release, run build script with `FLUX_FINAL_RELEASE=0`.

```bash
# Ampere
FLUX_FINAL_RELEASE=0 ./build.sh --arch 80 --package
```

## Run Demo
```bash
# gemm only
PYTHONPATH=./python:$PYTHONPATH python3 test/test_gemm_only.py 4096 12288 6144 --dtype=float16
# gemm fused with reduce-scatter
./launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10
# all-gather fused with gemm
./launch.sh test/test_ag_kernel.py 4096 49152 12288 --dtype=float16 --iters=10
```
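The three positional numbers in each command are the GEMM problem size. As a rough guide to how the work is partitioned across ranks, here is a small sketch under the assumptions that the arguments are M, N, K and the job runs on an 8-GPU tensor-parallel group; the exact sharding is defined by the test scripts themselves:

```python
# Illustrative shape bookkeeping for the demo sizes above, assuming the
# positional arguments are M, N, K and an 8-GPU tensor-parallel group.
tp = 8

m, n, k = 4096, 12288, 49152            # test_gemm_rs.py example
# GEMM fused with reduce-scatter: the [m, n] GEMM result is summed across
# ranks and scattered along M, so each rank ends up with an [m // tp, n] slice.
print("gemm_rs per-rank output:", (m // tp, n))   # (512, 12288)

m, n, k = 4096, 49152, 12288            # test_ag_kernel.py example
# All-gather fused with GEMM: each rank contributes an [m // tp, k] input
# shard that is gathered into the full [m, k] operand before the GEMM.
print("ag_gemm per-rank input:", (m // tp, k))    # (512, 12288)
```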



## [License](./LICENSE)

