Improve setup script & Add a guard for bfloat16 kernels (#130)
WoosukKwon authored May 27, 2023
1 parent 4a151dd commit d721168
Showing 4 changed files with 90 additions and 16 deletions.
3 changes: 0 additions & 3 deletions csrc/attention/attention_dtypes.h
@@ -3,7 +3,4 @@
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"

#ifdef ENABLE_BF16
#include "dtype_bfloat16.cuh"
#endif // ENABLE_BF16
2 changes: 0 additions & 2 deletions csrc/attention/attention_kernels.cu
@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
  // TODO(woosuk): Support FP32.
  if (query.dtype() == at::ScalarType::Half) {
    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
#ifdef ENABLE_BF16
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
#endif
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
44 changes: 44 additions & 0 deletions csrc/attention/dtype_bfloat16.cuh
@@ -78,20 +78,36 @@ struct FloatVec<bf16_8_t> {

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return a + b;
#endif
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
@@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
}

template<>
@@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
@@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

} // namespace cacheflow
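
With the ENABLE_BF16 define gone, the bfloat16 kernels are always compiled; the new `__CUDA_ARCH__ < 800` guards simply assert if the code ever runs on a pre-Ampere GPU. As a rough illustration only (not part of this commit; the helper name is made up), a caller could mirror that runtime requirement from the Python side before dispatching a bfloat16 query:

```python
import torch

def device_supports_bf16_kernels(device: int = 0) -> bool:
    # The bf16 intrinsics guarded above require compute capability >= 8.0 (Ampere).
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 8

# Hypothetical pre-dispatch check mirroring the kernel-side assert.
query = torch.randn(1, 32, 128, dtype=torch.bfloat16, device="cuda")
if query.dtype == torch.bfloat16 and not device_supports_bf16_kernels():
    raise RuntimeError(
        "bfloat16 attention kernels require a GPU with compute capability >= 8.0")
```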
57 changes: 46 additions & 11 deletions setup.py
@@ -1,28 +1,63 @@
from typing import List
import subprocess
from typing import List, Set

from packaging.version import parse, Version
import setuptools
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME


# Build custom operators.
CXX_FLAGS = ["-g"]
# Compiler flags.
CXX_FLAGS = ["-g", "-O2"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2"]


if not torch.cuda.is_available():
    raise RuntimeError(
        f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
        "CUDA must be available in order to build the package.")

# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
# different compute capabilities.
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
# Enable bfloat16 support if the compute capability is >= 8.0.
if major >= 8:
    NVCC_FLAGS.append("-DENABLE_BF16")

def get_nvcc_cuda_version(cuda_dir: str) -> Version:
    """Get the CUDA version from nvcc.
    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
    """
    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                          universal_newlines=True)
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
    return nvcc_cuda_version


# Collect the compute capabilities of all available GPUs.
device_count = torch.cuda.device_count()
compute_capabilities: Set[int] = set()
for i in range(device_count):
    major, minor = torch.cuda.get_device_capability(i)
    if major < 7:
        raise RuntimeError(
            "GPUs with compute capability less than 7.0 are not supported.")
    compute_capabilities.add(major * 10 + minor)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
    compute_capabilities = {70, 75, 80, 86, 90}
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]

# Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if nvcc_cuda_version < Version("11.0"):
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")

ext_modules = []

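For readers skimming the new setup.py logic, here is a small, self-contained sketch (not part of the diff; the sample nvcc output and the GPU mix are assumptions) of how the CUDA version is parsed from `nvcc -V` and how the per-architecture `-gencode` flags are assembled:

```python
from packaging.version import Version, parse

# Assumed `nvcc -V` output for a CUDA 11.8 toolkit (for illustration only).
sample_nvcc_output = """nvcc: NVIDIA (R) Cuda compiler driver
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0"""

# Same parsing strategy as get_nvcc_cuda_version: take the token after "release".
tokens = sample_nvcc_output.split()
release_idx = tokens.index("release") + 1
nvcc_cuda_version = parse(tokens[release_idx].split(",")[0])
assert nvcc_cuda_version >= Version("11.0")  # minimum enforced by setup.py

# Flag assembly for an assumed machine with an A100 (8.0) and an RTX 3090 (8.6).
compute_capabilities = {80, 86}
nvcc_flags = ["-O2"]
for capability in sorted(compute_capabilities):
    nvcc_flags += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
print(nvcc_flags)
# ['-O2', '-gencode', 'arch=compute_80,code=sm_80',
#  '-gencode', 'arch=compute_86,code=sm_86']
```

Because a capability 8.6 target is present in this assumed configuration, the version checks in the diff would additionally require nvcc 11.1 or newer.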
