Update setup.py (vllm-project#1006)
WoosukKwon authored Sep 11, 2023
1 parent b9cecc2 commit d6770d1
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions setup.py
@@ -22,7 +22,7 @@
 
 if CUDA_HOME is None:
     raise RuntimeError(
-        f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
+        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
 
 
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -54,7 +54,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
 if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
     raise RuntimeError(
-        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
+    )
 if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
     # However, GPUs with compute capability 8.9 can also run the code generated by
@@ -65,7 +66,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     compute_capabilities.add(80)
 if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     raise RuntimeError(
-        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
+        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
+    )
 
 # If no GPU is available, add all supported compute capabilities.
 if not compute_capabilities:
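
Note: the compute_capabilities set validated by the checks above is typically populated earlier in setup.py from the GPUs PyTorch can see. A minimal sketch of that pattern, for orientation only (not part of this diff; the 7.0 cutoff is an assumption):

import torch

compute_capabilities = set()
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    if major < 7:  # assumed minimum supported architecture
        raise RuntimeError("GPUs with compute capability below 7.0 are not supported.")
    compute_capabilities.add(major * 10 + minor)  # e.g. (8, 6) -> 86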
@@ -78,7 +80,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+    NVCC_FLAGS += [
+        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
+    ]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
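
For illustration, the reflowed loop appends one -gencode pair per target capability; a standalone sketch with example values (not part of the diff):

compute_capabilities = {80, 90}
NVCC_FLAGS = []
for capability in compute_capabilities:
    NVCC_FLAGS += [
        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
    ]
# NVCC_FLAGS now contains, e.g.:
#   ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90,code=sm_90']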
@@ -91,39 +95,54 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 cache_extension = CUDAExtension(
     name="vllm.cache_ops",
     sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(cache_extension)
 
 # Attention kernels.
 attention_extension = CUDAExtension(
     name="vllm.attention_ops",
     sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(attention_extension)
 
 # Positional encoding kernels.
 positional_encoding_extension = CUDAExtension(
     name="vllm.pos_encoding_ops",
     sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(positional_encoding_extension)
 
 # Layer normalization kernels.
 layernorm_extension = CUDAExtension(
     name="vllm.layernorm_ops",
     sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(layernorm_extension)
 
 # Activation kernels.
 activation_extension = CUDAExtension(
     name="vllm.activation_ops",
     sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(activation_extension)
 
@@ -138,8 +157,8 @@ def find_version(filepath: str):
     Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
     """
     with open(filepath) as fp:
-        version_match = re.search(
-            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+        version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+                                  fp.read(), re.M)
         if version_match:
             return version_match.group(1)
         raise RuntimeError("Unable to find version string.")
@@ -162,7 +181,8 @@ def get_requirements() -> List[str]:
     version=find_version(get_path("vllm", "__init__.py")),
     author="vLLM Team",
     license="Apache 2.0",
-    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+    description=("A high-throughput and memory-efficient inference and "
+                 "serving engine for LLMs"),
     long_description=read_readme(),
     long_description_content_type="text/markdown",
     url="https://github.com/vllm-project/vllm",
@@ -174,11 +194,12 @@ def get_requirements() -> List[str]:
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "License :: OSI Approved :: Apache Software License",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
-    packages=setuptools.find_packages(
-        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+    packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
+                                               "examples", "tests")),
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
