Support IPC && SM90 version of AG-GEMM, GEMM-RS (#9)
* Support IPC && SM90 version of AG-GEMM, GEMM-RS

Simultaneously supports IPC and NVSHMEM, allowing users to
choose whether to enable NVSHMEM, and also adds SM90 versions
of the two ops. The README is updated accordingly, with some
performance data added.
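To illustrate the choice (using the `--nvshmem` flag added to `build.sh` in this commit):

```bash
# IPC-based shared memory only (NVSHMEM stays disabled, the new default)
./build.sh --arch 90
# NVSHMEM enabled, required for cross-node tensor parallelism
./build.sh --arch 90 --nvshmem
```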

---------

Co-authored-by: Chengquan Jiang <imjcqt@gmail.com>
Co-authored-by: Wenlei Bao <wenlei.bao@bytedance.com>
Co-authored-by: Qi Hou <houqi1993@gmail.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Co-authored-by: Xin Liu <liuxin.ai@bytedance.com>
Co-authored-by: Liwen Chang <liwen.chang@bytedance.com>
Co-authored-by: Haibin Lin <haibin.lin@bytedance.com>
8 people authored Jun 27, 2024
1 parent 96b2e03 commit 775e061
Showing 119 changed files with 10,236 additions and 20,580 deletions.
7 changes: 5 additions & 2 deletions CMakeLists.txt
@@ -10,6 +10,7 @@ set(BUILD_TEST ON CACHE INTERNAL "Build unit tests")
set(ENABLE_NVSHMEM ON CACHE INTERNAL "Use NVSHMEM to transfer data")
set(CUTLASS_TRACE OFF CACHE INTERNAL "Print CUTLASS Host Trace info")
set(FLUX_DEBUG OFF CACHE INTERNAL "Define FLUX_DEBUG")
OPTION(WITH_PROTOBUF "build with protobuf" OFF)
message("PYTHONPATH: ${PYTHONPATH}")
message("NVShmem Support: ${ENABLE_NVSHMEM}")

@@ -21,6 +22,8 @@ if(CUDAToolkit_VERSION VERSION_LESS "11.0")
message(FATAL_ERROR "requires cuda to be >= 11.0")
elseif(CUDAToolkit_VERSION VERSION_LESS "12.0")
set(CUDAARCHS "80" CACHE STRING "CUDA Architectures")
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.4")
set(CUDAARCHS "80;89;90" CACHE STRING "CUDA Architectures")
else()
set(CUDAARCHS "80;90" CACHE STRING "CUDA Architectures")
endif()
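Since `CUDAARCHS` is declared as a plain `CACHE STRING` (not forced), a value passed at configure time takes precedence over the defaults above; a minimal sketch of a manual configure (normally `build.sh` passes `-DCUDAARCHS=${ARCH}` for you):

```bash
# Configure for Hopper only, overriding the version-derived default list
cmake -B build -DCUDAARCHS="90"
```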
@@ -143,9 +146,9 @@ set(COMMON_HEADER_DIRS

set(COMMON_LIB_DIRS "")
list(APPEND COMMON_LIB_DIRS "${CUDAToolkit_LIBRARY_DIR}")

message(STATUS "ENABLE_NVSHMEM is set to: ${ENABLE_NVSHMEM}")
if(ENABLE_NVSHMEM)
add_definitions(-DFLUX_USE_NVSHMEM)
add_definitions(-DFLUX_SHM_USE_NVSHMEM)
set(NVSHMEM_BUILD_DIR ${PROJECT_SOURCE_DIR}/3rdparty/nvshmem/build)
message(STATUS "NVSHMEM build dir: ${NVSHMEM_BUILD_DIR}")
if(NOT EXISTS ${NVSHMEM_BUILD_DIR})
45 changes: 35 additions & 10 deletions README.md
@@ -9,34 +9,45 @@ Flux can significantly reduce latency and increase throughput for tensor parallelism

## Build
```bash
git clone https://github.com/bytedance/flux.git
cd flux
git submodule update --init --recursive
# Ampere
./build.sh --arch 80
# Hopper
./build.sh --arch 90
```
## Build for cross-machine TP
FLUX relies on NVSHMEM for communication across nodes. Therefore, if you need support for cross-machine tensor parallelism (TP), you must download the NVSHMEM source code manually and enable the `--nvshmem` option during compilation.

```bash
git clone https://github.com/bytedance/flux.git
# Download nvshmem-2.11 (https://developer.nvidia.com/nvshmem) and place it in flux/3rdparty/nvshmem
# Flux currently depends on this specific NVSHMEM version (2.11).
tar Jxvf nvshmem_src_2.11.0-5.txz
mv nvshmem_src_2.11.0-5 ${YOUR_PATH}/flux/3rdparty/nvshmem
cd flux
git submodule update --init --recursive
# Ampere
./build.sh --arch 80

# Ampere
./build.sh --arch 80 --nvshmem
# Hopper
./build.sh --arch 90 --nvshmem
```

If you want to skip the CMake configuration step, set the environment variable `FLUX_BUILD_SKIP_CMAKE` to 1; CMake is then skipped whenever `build/CMakeCache.txt` already exists.
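For example (assuming an earlier build already populated `build/CMakeCache.txt`):

```bash
# Reuse the existing CMake cache and go straight to compilation
FLUX_BUILD_SKIP_CMAKE=1 ./build.sh --arch 80
```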

If you want to build a wheel package, add `--package` to the build command. The output wheel file is placed under `dist/`.

```
```bash
# Ampere
./build.sh --arch 80 --package
```

For a development release, run the build script with `FLUX_FINAL_RELEASE=0`.

```bash
# Ampere
FLUX_FINAL_RELEASE=0 ./build.sh --arch 80 --package
# Hopper
FLUX_FINAL_RELEASE=0 ./build.sh --arch 90 --package
```


## Run Demo
```
```bash
# gemm only
PYTHONPATH=./python:$PYTHONPATH python3 test/test_gemm_only.py 4096 12288 6144 --dtype=float16

@@ -47,6 +58,20 @@ PYTHONPATH=./python:$PYTHONPATH python3 test/test_gemm_only.py 4096 12288 6144 --dtype=float16
./scripts/launch.sh test/test_ag_kernel.py 4096 49152 12288 --dtype=float16 --iters=10
```

## Performance
We measured the examples from the demo above on both A800 and H800 machines. Each machine has 8 GPUs, with the TP size set to 8. The table below shows the performance comparison between Flux and torch+NCCL. By overlapping fine-grained computation and communication, Flux effectively hides a significant portion of the communication time.

| | M | K | N | Torch Gemm | Torch NCCL | Torch Total | Flux Gemm | Flux NCCL | Flux Total |
|----------|----------|----------|----------|----------|----------|----------|----------|----------|-----------|
| AG+Gemm(A800) | 4096 | 12288 | 49152 | 2.438ms | 0.662ms | 3.099ms | 2.378ms | 0.091ms | 2.469ms |
| Gemm+RS(A800) | 4096 | 49152 | 12288 | 2.453ms | 0.646ms | 3.100ms | 2.429ms | 0.080ms | 2.508ms |
| AG+Gemm(H800) | 4096 | 12288 | 49152 | 0.846ms | 0.583ms | 1.429ms | 0.814ms | 0.143ms | 0.957ms |
| Gemm+RS(H800) | 4096 | 49152 | 12288 | 0.818ms | 0.590ms | 1.408ms | 0.822ms | 0.111ms | 0.932ms |

AG refers to AllGather.
RS refers to ReduceScatter.
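As a rough sanity check, the fraction of communication time hidden by Flux can be derived from the table above; a minimal sketch (times in ms, taken directly from the table):

```python
# Hidden fraction = 1 - (Flux exposed comm / Torch comm), per row of the table.
rows = {
    "AG+Gemm(A800)": (0.662, 0.091),
    "Gemm+RS(A800)": (0.646, 0.080),
    "AG+Gemm(H800)": (0.583, 0.143),
    "Gemm+RS(H800)": (0.590, 0.111),
}
for name, (torch_comm, flux_comm) in rows.items():
    print(f"{name}: ~{1 - flux_comm / torch_comm:.0%} of communication time hidden")
```

By this estimate, roughly 75% to 88% of the communication time is overlapped with computation.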


## Citing

If you use Flux in a scientific publication, we encourage you to add the following reference:
106 changes: 75 additions & 31 deletions build.sh
@@ -2,10 +2,15 @@
set -x
set -e

export PATH=/usr/local/cuda/bin:$PATH
CMAKE=${CMAKE:-cmake}

ARCH=""
BUILD_TEST="ON"
BDIST_WHEEL="OFF"
WITH_PROTOBUF="OFF"
FLUX_DEBUG="OFF"
ENABLE_NVSHMEM="OFF"

function clean_py() {
rm -rf build/lib.*
@@ -52,11 +57,20 @@ while [[ $# -gt 0 ]]; do
;;
--debug)
FLUX_DEBUG="ON"
shift;;
shift
;;
--package)
BDIST_WHEEL="ON"
shift # Skip the argument key
;;
--protobuf)
WITH_PROTOBUF="ON"
shift
;;
--nvshmem)
ENABLE_NVSHMEM="ON"
shift
;;
*)
# Unknown argument
echo "Unknown argument: $1"
@@ -67,6 +81,7 @@ done

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT=${SCRIPT_DIR}
PROTOBUF_ROOT=$PROJECT_ROOT/3rdparty/protobuf

cd ${PROJECT_ROOT}

@@ -78,6 +93,20 @@ if [[ -z $JOBS ]]; then
JOBS=$(nproc --ignore 2)
fi

##### build protobuf #####
function build_protobuf() {
if [ $WITH_PROTOBUF == "ON" ]; then
pushd $PROTOBUF_ROOT
mkdir -p $PWD/build/local
pushd build
CFLAGS="-fPIC" CXXFLAGS="-fPIC" cmake ../cmake -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_INSTALL_PREFIX=$(realpath local)
make -j$(nproc)
make install
popd
popd
fi
}

function build_nccl() {
pushd $NCCL_ROOT
export BUILDDIR=${NCCL_ROOT}/build
@@ -108,6 +137,8 @@

##### build nvshmem_bootstrap_torch #####
function build_pynvshmem() {
PYNVSHMEM_DIR=$PROJECT_ROOT/pynvshmem
export NVSHMEM_HOME=$PROJECT_ROOT/3rdparty/nvshmem/build/src
mkdir -p ${PYNVSHMEM_DIR}/build

pushd ${PYNVSHMEM_DIR}/build
@@ -126,11 +157,18 @@ function build_flux_cuda() {
pushd build
if [ ! -f CMakeCache.txt ] || [ -z ${FLUX_BUILD_SKIP_CMAKE} ]; then
CMAKE_ARGS=(
-DENABLE_NVSHMEM=on
-DENABLE_NVSHMEM=${ENABLE_NVSHMEM}
-DCUDAARCHS=${ARCH}
-DCMAKE_EXPORT_COMPILE_COMMANDS=1
-DBUILD_TEST=${BUILD_TEST}
)
if [ $WITH_PROTOBUF == "ON" ]; then
CMAKE_ARGS+=(
-DWITH_PROTOBUF=ON
-DProtobuf_ROOT=${PROTOBUF_ROOT}/build/local
-DProtobuf_PROTOC_EXECUTABLE=${PROTOBUF_ROOT}/build/local/bin/protoc
)
fi
if [ $FLUX_DEBUG == "ON" ]; then
CMAKE_ARGS+=(
-DFLUX_DEBUG=ON
@@ -142,28 +180,6 @@ function build_flux_cuda() {
popd
}

function build_flux_py {
LIBDIR=${PROJECT_ROOT}/python/lib
mkdir -p ${LIBDIR}

rm -f ${LIBDIR}/libflux_cuda.so
rm -f ${LIBDIR}/nvshmem_bootstrap_torch.so
rm -f ${LIBDIR}/nvshmem_transport_ibrc.so.2
rm -f ${LIBDIR}/libnvshmem_host.so.2
pushd ${LIBDIR}
cp -s ../../build/lib/libflux_cuda.so .
cp -s ../../pynvshmem/build/nvshmem_bootstrap_torch.so .
cp -s ../../3rdparty/nvshmem/build/src/lib/nvshmem_transport_ibrc.so.2 .
cp -s ../../3rdparty/nvshmem/build/src/lib/libnvshmem_host.so.2 .
popd

##### build flux torch bindings #####
MAX_JOBS=${JOBS} python3 setup.py develop --user
if [ $BDIST_WHEEL == "ON" ]; then
MAX_JOBS=${JOBS} python3 setup.py bdist_wheel
fi
}

function merge_compile_commands() {
if command -v ninja >/dev/null 2>&1; then
# generate compile_commands.json
@@ -185,17 +201,45 @@
EOF
fi
}

function build_flux_py {
LIBDIR=${PROJECT_ROOT}/python/lib
rm -rf ${LIBDIR}
mkdir -p ${LIBDIR}

# rm -f ${LIBDIR}/libflux_cuda.so
# rm -f ${LIBDIR}/nvshmem_bootstrap_torch.so
# rm -f ${LIBDIR}/nvshmem_transport_ibrc.so.2
# rm -f ${LIBDIR}/libnvshmem_host.so.2
pushd ${LIBDIR}
cp -s ../../build/lib/libflux_cuda.so .
if [ $ENABLE_NVSHMEM == "ON" ]; then
cp -s ../../pynvshmem/build/nvshmem_bootstrap_torch.so .
cp -s ../../3rdparty/nvshmem/build/src/lib/nvshmem_transport_ibrc.so.2 .
cp -s ../../3rdparty/nvshmem/build/src/lib/libnvshmem_host.so.2 .
export FLUX_SHM_USE_NVSHMEM=1
fi
popd
##### build flux torch bindings #####
MAX_JOBS=${JOBS} python3 setup.py develop --user
if [ $BDIST_WHEEL == "ON" ]; then
MAX_JOBS=${JOBS} python3 setup.py bdist_wheel
fi
merge_compile_commands
}

NCCL_ROOT=$PROJECT_ROOT/3rdparty/nccl
build_nccl

./build_nvshmem.sh ${build_args} --jobs ${JOBS}

export PATH=/usr/local/cuda/bin:$PATH
CMAKE=${CMAKE:-cmake}
PYNVSHMEM_DIR=$PROJECT_ROOT/pynvshmem
export NVSHMEM_HOME=$PROJECT_ROOT/3rdparty/nvshmem/build/src
if [ $ENABLE_NVSHMEM == "ON" ]; then
./build_nvshmem.sh ${build_args} --jobs ${JOBS}
fi

build_protobuf

if [ $ENABLE_NVSHMEM == "ON" ]; then
build_pynvshmem
fi

build_pynvshmem
build_flux_cuda
build_flux_py
merge_compile_commands
114 changes: 114 additions & 0 deletions gen_version.py
@@ -0,0 +1,114 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import argparse
import os
import subprocess
from pathlib import Path
import shutil
import re
from typing import Optional, Tuple

CUR_DIR = os.path.dirname(os.path.realpath(__file__))


def _check_env_option(opt, default=""):
return os.getenv(opt, default).upper() in ["ON", "1", "YES", "TRUE"]


def check_final_release():
return _check_env_option("FLUX_FINAL_RELEASE", "1")


def get_git_commit(src_dir):
try:
return (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=src_dir)
.decode("ascii")
.strip()
)
except Exception:
return "unknown"


def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) by nvcc --version"""

# Try finding NVCC
nvcc_bin: Optional[Path] = None
if nvcc_bin is None and os.getenv("CUDA_HOME"):
# Check in CUDA_HOME
cuda_home = Path(os.getenv("CUDA_HOME"))
nvcc_bin = cuda_home / "bin" / "nvcc"
if nvcc_bin is None:
# Check if nvcc is in path
nvcc_bin = shutil.which("nvcc")
if nvcc_bin is not None:
nvcc_bin = Path(nvcc_bin)
if nvcc_bin is None:
# Last-ditch guess in /usr/local/cuda
cuda_home = Path("/usr/local/cuda")
nvcc_bin = cuda_home / "bin" / "nvcc"
if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")

# Query NVCC for version info
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
check=True,
universal_newlines=True,
)
    match = re.search(r"release\s*([\d.]+)", output.stdout)
    if match is None:
        raise RuntimeError("failed to parse CUDA version from nvcc output")
    version = match.group(1).split(".")
    return tuple(int(v) for v in version)


def get_flux_version(version_txt, *, dev=False):
with open(version_txt) as f:
version = f.readline()
version = version.strip()
cuda_version_major, cuda_version_minor = cuda_version()
version = version + f"+cu{cuda_version_major}{cuda_version_minor}"
if dev:
commit_id = get_git_commit(CUR_DIR)

version += ".dev{}".format(commit_id[:8])
# version = version + (f'.{os.getenv("ARCH")}' if os.getenv("ARCH") else "")
return version


def generate_version_file(version_txt, version_file, *, dev=False):
flux_ver = get_flux_version(version_txt, dev=dev)

with open(version_file, "w") as f:
f.write("__version__ = '{}'\n".format(flux_ver))
f.write("git_version = {}\n".format(repr(get_git_commit(CUR_DIR))))
cuda_version_major, cuda_version_minor = cuda_version()
f.write("cuda = {}.{}\n".format(cuda_version_major, cuda_version_minor))

return flux_ver


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="generate version.py")
parser.add_argument("--input", type=str, required=True)
parser.add_argument("--output", type=str, required=True)
parser.add_argument("--dev", action="store_true")
args = parser.parse_args()

    generate_version_file(args.input, args.output, dev=args.dev)
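A minimal usage sketch of the new script (the input/output paths here are illustrative, not mandated by the script):

```bash
# Reads the base version from the input file and writes __version__,
# git_version, and cuda into the generated Python file
python3 gen_version.py --input version.txt --output python/flux/version.py --dev
```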
