-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sm80 Allgather-gemm && Gemm-Reducescatter
The current implementation supports two operations on the SM80 architecture: (#3) 1. Allgather followed by GEMM (General Matrix-Matrix Multiplication) 2. GEMM followed by Reduce-Scatter The fused operations demonstrate improved performance compared to invoking GEMM and communication operations separately. This optimization is crucial for high-performance computing tasks, especially for LLM training or inference. Co-authored-by: Chengquan Jiang <imjcqt@gmail.com> Co-authored-by: Wenlei Bao <wenlei.bao@bytedance.com> Co-authored-by: Ningxin Zheng <zhengningxin@bytedance.com> Co-authored-by: Qi Hou <houqi1993@gmail.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com> Co-authored-by: Xin Liu <liuxin.ai@bytedance.com> Co-authored-by: Liwen Chang <liwen.chang@bytedance.com> Co-authored-by: Haibin Lin <haibin.lin@bytedance.com>
- Loading branch information
1 parent
de92387
commit 98fbde4
Showing
141 changed files
with
41,716 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
[submodule "3rdparty/cutlass"] | ||
path = 3rdparty/cutlass | ||
url = https://github.com/NVIDIA/cutlass | ||
[submodule "3rdparty/nccl"] | ||
path = 3rdparty/nccl | ||
url = https://github.com/NVIDIA/nccl |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Top-level build script for the FLUX project (C++ and CUDA sources).
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
project(FLUX LANGUAGES CXX CUDA)

# Let include()/find_package() pick up modules shipped under cmake/modules/.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/modules/")

# cmake global settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE INTERNAL "")

# User-facing feature switches.
# These were previously declared with `set(... CACHE INTERNAL ...)`, which
# implies FORCE and therefore silently overwrote any value passed on the
# command line (e.g. -DENABLE_NVSHMEM=OFF). option() keeps the same defaults
# while honouring user-provided values.
option(BUILD_THS "Build Torch op" ON)
option(BUILD_TEST "Build unit tests" ON)
option(ENABLE_NVSHMEM "Use NVSHMEM to transfer data" ON)
option(CUTLASS_TRACE "Print CUTLASS Host Trace info" OFF)
option(FLUX_DEBUG "Define FLUX_DEBUG" OFF)
message("PYTHONPATH: ${PYTHONPATH}")
message("NVShmem Support: ${ENABLE_NVSHMEM}")
|
||
# Locate the CUDA toolkit and derive the default set of GPU architectures.
find_package(CUDAToolkit REQUIRED)

message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}")

# Toolkits older than 11.0 are rejected outright.
if(CUDAToolkit_VERSION VERSION_LESS "11.0")
  message(FATAL_ERROR "requires cuda to be >= 11.0")
endif()

# CUDA 11.x defaults to arch 80 only; 12.0 and newer additionally default to 90.
# CUDAARCHS is a cache entry, so a value supplied by the user takes precedence.
if(CUDAToolkit_VERSION VERSION_LESS "12.0")
  set(CUDAARCHS "80" CACHE STRING "CUDA Architectures")
else()
  set(CUDAARCHS "80;90" CACHE STRING "CUDA Architectures")
endif()

set(CMAKE_CUDA_ARCHITECTURES ${CUDAARCHS})
message(STATUS "CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}")
|
||
# Emit one -gencode option per requested architecture. Each option carries both
# the sm_XX and the compute_XX code targets for that architecture.
set(CUDA_ARCH_FLAGS)
foreach(ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
  list(APPEND CUDA_ARCH_FLAGS "-gencode=arch=compute_${ARCH},code=\\\"sm_${ARCH},compute_${ARCH}\\\"")
endforeach()

# Flatten the list into a space-separated string and add it to the CUDA flags.
string(JOIN " " JOINED_CUDA_ARCH_FLAGS ${CUDA_ARCH_FLAGS})
string(APPEND CMAKE_CUDA_FLAGS " ${JOINED_CUDA_ARCH_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")

# Baseline CUDA flags: assertions off, psabi warning silenced in the host
# compiler, strict aliasing disabled, CUTLASS host tracing off by default.
string(APPEND CMAKE_CUDA_FLAGS " -DNDEBUG")
string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler -Wno-psabi")
# set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")
string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler=-fno-strict-aliasing")
string(APPEND CMAKE_CUDA_FLAGS " -DCUTLASS_DEBUG_TRACE_LEVEL=0")
|
||
# Debug-configuration extras: warnings on, optimisation off; CUDA additionally
# gets -G and forwards -Wall to the host compiler.
string(APPEND CMAKE_C_FLAGS_DEBUG " -Wall -O0")
string(APPEND CMAKE_CXX_FLAGS_DEBUG " -Wall -O0")
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -O0 -G -Xcompiler -Wall")

# C++17 is mandatory for host code; nvcc is given the matching --std flag and
# the extended-lambda / relaxed-constexpr extensions.
set(CMAKE_CXX_STANDARD "17")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
string(APPEND CMAKE_CUDA_FLAGS " --expt-extended-lambda")
string(APPEND CMAKE_CUDA_FLAGS " --expt-relaxed-constexpr")
string(APPEND CMAKE_CUDA_FLAGS " --std=c++17")
|
||
# Optional CUTLASS host tracing: adds a level-1 define after the default level-0 one.
if(CUTLASS_TRACE)
  string(APPEND CMAKE_CUDA_FLAGS " -DCUTLASS_DEBUG_TRACE_LEVEL=1")
endif()

# Optional FLUX_DEBUG define, applied to both host C++ and CUDA compilation.
if(FLUX_DEBUG)
  foreach(flux_flags_var CMAKE_CXX_FLAGS CMAKE_CUDA_FLAGS)
    string(APPEND ${flux_flags_var} " -DFLUX_DEBUG")
  endforeach()
endif()

# Optimisation level for host code (directly, and through nvcc's host compiler).
string(APPEND CMAKE_CXX_FLAGS " -O3")
string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler -O3")

# Collect build artefacts under <build>/lib and <build>/bin.
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||
# Torch op support: locate Python and PyTorch, then query the interpreter for
# the Torch install directory, the Python include directory / extension suffix,
# and the linker flags PyTorch uses for C++ extensions.
if(BUILD_THS)
  message("Build THS on")
  find_package(Python 3 REQUIRED)
  find_program(PYTHON_EXECUTABLE NAMES python3 python)
  if(NOT PYTHON_EXECUTABLE)
    message(FATAL_ERROR "No python3/python interpreter found on PATH.")
  endif()

  # Translate the numeric CUDA architectures into the dotted form stored in
  # TORCH_CUDA_ARCH_LIST (presumably consumed by Torch's CMake package below).
  set(TORCH_CUDA_ARCH_LIST)
  foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
    if("${ARCH}" STREQUAL "80")
      list(APPEND TORCH_CUDA_ARCH_LIST "8.0")
    elseif("${ARCH}" STREQUAL "89")
      list(APPEND TORCH_CUDA_ARCH_LIST "8.9")
    elseif("${ARCH}" STREQUAL "90")
      list(APPEND TORCH_CUDA_ARCH_LIST "9.0")
    else()
      message(WARNING "Unsupported CUDA arch [${ARCH}] for TORCH_CUDA_ARCH")
    endif()
  endforeach()

  # Ask the interpreter where the torch package lives.
  execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "from __future__ import print_function; import os; import torch; print(os.path.dirname(torch.__file__),end='');"
                  RESULT_VARIABLE _PYTHON_SUCCESS
                  OUTPUT_VARIABLE TORCH_DIR)
  # EQUAL (numeric) rather than MATCHES (regex): MATCHES 0 would also accept
  # non-zero exit codes that merely contain the digit 0, such as 10 or 120.
  if(NOT _PYTHON_SUCCESS EQUAL 0)
    message("PY:${PYTHONPATH}")
    message(FATAL_ERROR "Torch config Error.")
  endif()
  list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}")
  find_package(Torch REQUIRED)
  # PATHS is the find_library keyword for extra search directories.
  find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_DIR}/lib")

  # Query the Python include directory and the extension-module suffix.
  # The standard-library sysconfig module is used because distutils was
  # removed in Python 3.12.
  execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "import sysconfig; print(sysconfig.get_paths()['include']); print(sysconfig.get_config_var('EXT_SUFFIX'));"
                  RESULT_VARIABLE _PYTHON_SUCCESS
                  OUTPUT_VARIABLE _PYTHON_VALUES)
  if(NOT _PYTHON_SUCCESS EQUAL 0)
    message("PY:${PYTHON_EXECUTABLE}")
    message(FATAL_ERROR "Python config Error.")
  endif()
  # Turn the newline-separated output into a CMake list. Literal semicolons are
  # escaped first; the input is quoted so they are not split away beforehand.
  string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES "${_PYTHON_VALUES}")
  string(REGEX REPLACE "\n" ";" _PYTHON_VALUES "${_PYTHON_VALUES}")
  list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR)
  list(GET _PYTHON_VALUES 1 PY_SUFFIX)
  list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR})

  # Ask torch.utils.cpp_extension for its linker flags. The helper takes a
  # fourth argument from torch 1.8.0 onwards, hence the version switch.
  execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c"
                  "from torch.utils import cpp_extension; import re; import torch; \
version = tuple(int(i) for i in re.match('(\\d+)\\.(\\d+)\\.(\\d+)', torch.__version__).groups()); \
args = ([],True,False,False) if version >= (1, 8, 0) else ([],True,False); \
print(' '.join(cpp_extension._prepare_ldflags(*args)),end='');"
                  RESULT_VARIABLE _PYTHON_SUCCESS
                  OUTPUT_VARIABLE TORCH_LINK)
  message("-- TORCH_LINK ${TORCH_LINK}")
  if(NOT _PYTHON_SUCCESS EQUAL 0)
    message(FATAL_ERROR "PyTorch link config Error.")
  endif()
endif()
|
||
# force use sm90a for cutlass
# Rewrites every sm_90 / compute_90 occurrence in the CUDA flags to its "a"
# variant. The optional trailing "a" in the pattern keeps the rewrite
# idempotent (an existing sm_90a is not turned into sm_90aa), and the input is
# quoted so an empty or semicolon-containing flag string is handled intact.
string(REGEX REPLACE "(sm|compute)_90a?" "\\1_90a" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||
# Include directories shared by the whole build: project headers, the CUDA
# toolkit, the build tree and the bundled CUTLASS sources.
# Entries appended earlier in this file (the Python include directory when
# BUILD_THS is on) are kept at the end instead of being overwritten.
set(COMMON_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include
  ${CUDAToolkit_INCLUDE_DIRS}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include
  ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/util/include
  ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/library/include
  ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/profiler/include
  ${COMMON_HEADER_DIRS}
)

# Library search directories shared by the whole build.
set(COMMON_LIB_DIRS "")
list(APPEND COMMON_LIB_DIRS "${CUDAToolkit_LIBRARY_DIR}")
|
||
# NVSHMEM support: define FLUX_USE_NVSHMEM and expose the headers and libraries
# of the NVSHMEM build tree under 3rdparty/nvshmem/build.
if(ENABLE_NVSHMEM)
  add_definitions(-DFLUX_USE_NVSHMEM)
  set(NVSHMEM_BUILD_DIR ${PROJECT_SOURCE_DIR}/3rdparty/nvshmem/build)
  message(STATUS "NVSHMEM build dir: ${NVSHMEM_BUILD_DIR}")

  # Configuration stops here when the NVSHMEM build tree does not exist.
  if(EXISTS ${NVSHMEM_BUILD_DIR})
    list(APPEND COMMON_HEADER_DIRS "${NVSHMEM_BUILD_DIR}/src/include")
    list(APPEND COMMON_LIB_DIRS "${NVSHMEM_BUILD_DIR}/src/lib")
  else()
    message(FATAL_ERROR "NVSHMEM not found. Please run ./build_nvshmem.sh first.")
  endif()
endif()
|
||
# Pass every shared include directory to nvcc as an explicit -I flag as well,
# so that vscode clangd intellisense works for .cu files.
foreach(header_dir IN LISTS COMMON_HEADER_DIRS)
  string(APPEND CMAKE_CUDA_FLAGS " -I${header_dir}")
endforeach()
message("final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

# Directory-scoped include and library search paths for everything under src/.
include_directories(${COMMON_HEADER_DIRS})
link_directories(${COMMON_LIB_DIRS})

add_subdirectory(src)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
<!--- Licensed to the Apache Software Foundation (ASF) under one --> | ||
<!--- or more contributor license agreements. See the NOTICE file --> | ||
<!--- distributed with this work for additional information --> | ||
<!--- regarding copyright ownership. The ASF licenses this file --> | ||
<!--- to you under the Apache License, Version 2.0 (the --> | ||
<!--- "License"); you may not use this file except in compliance --> | ||
<!--- with the License. You may obtain a copy of the License at --> | ||
|
||
<!--- http://www.apache.org/licenses/LICENSE-2.0 --> | ||
|
||
<!--- Unless required by applicable law or agreed to in writing, --> | ||
<!--- software distributed under the License is distributed on an --> | ||
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> | ||
<!--- KIND, either express or implied. See the License for the --> | ||
<!--- specific language governing permissions and limitations --> | ||
<!--- under the License. --> | ||
|
||
# How to Contribute | ||
|
||
We'd love to accept your patches and contributions to this project. There are | ||
just a few small guidelines you need to follow. | ||
|
||
## Contributor License Agreement | ||
|
||
Contributions to this project must be accompanied by a Contributor License | ||
Agreement. You (or your employer) retain the copyright to your contribution; | ||
this simply gives us permission to use and redistribute your contributions as | ||
part of the project. | ||
|
||
You generally only need to submit a CLA once, so if you've already submitted one | ||
(even if it was for a different project), you probably don't need to do it | ||
again. | ||
|
||
## Changes Accepted | ||
|
||
Please file issues before doing substantial work; this will ensure that others | ||
don't duplicate the work and that there's a chance to discuss any design issues. | ||
|
||
Changes only tweaking style are unlikely to be accepted unless they are applied | ||
consistently across the project. Most of the code style is derived from the | ||
[Google Style Guides](http://google.github.io/styleguide/) for the appropriate | ||
language and is generally not something we accept changes on (as clang-format | ||
and clang-tidy handle that for us). The compiler portion of the project follows | ||
[MLIR style](https://mlir.llvm.org/getting_started/DeveloperGuide/#style-guide). | ||
Improvements to code structure and clarity are welcome but please file issues to | ||
track such work first. | ||
|
||
## AUTHORS file | ||
|
||
If you would like to receive additional recognition for your contribution, you | ||
may add yourself (or your organization) to the AUTHORS file. This keeps track of | ||
those who have made significant contributions to the project. Please add the | ||
entity who owns the copyright for your contribution. The source control history | ||
remains the most accurate source for individual contributions. | ||
|
||
## Pull Requests | ||
We actively welcome your pull requests. | ||
|
||
1. Fork the repo and create your branch from `main`. | ||
2. If you've added code that should be tested, add tests. | ||
3. If you've changed APIs, update the documentation. | ||
4. Ensure the test suite passes. | ||
5. Make sure your code lints. | ||
6. If you haven't already, complete the Contributor License Agreement ("CLA"). | ||
|
||
## Issues | ||
|
||
We use GitHub issues to track public bugs. Please ensure your description is | ||
clear and has sufficient instructions to be able to reproduce the issue. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Files shipped in the source distribution.
include src/ths_op/*.cc.inc
# `prune` removes a whole directory tree; `exclude` only matches file patterns.
prune pynvshmem
recursive-include src *
recursive-include include *
# recursive-include needs a directory AND at least one pattern.
recursive-include python/flux_ths_pybind *
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.