[Neuron] Add an option to build with neuron (vllm-project#2065)
liangfu authored Jan 18, 2024
1 parent 4df417d commit 18473cf
Showing 3 changed files with 64 additions and 13 deletions.
9 changes: 9 additions & 0 deletions requirements-neuron.txt
@@ -0,0 +1,9 @@
sentencepiece # Required for LLaMA tokenizer.
numpy
transformers-neuronx >= 0.9.0
torch-neuronx >= 2.1.0
neuronx-cc
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
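
As a quick sanity check after installing these requirements, one might verify that the Neuron toolchain is importable. The snippet below is an illustration, not part of the commit:

# Hypothetical post-install check; module names follow the requirements above.
import transformers_neuronx  # noqa: F401 -- transformer inference on Neuron
import torch_neuronx  # noqa: F401 -- PyTorch integration for Neuron devices
import neuronxcc  # noqa: F401 -- the Neuron compiler; setup.py reads its version
print("Neuron toolchain is importable")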
62 changes: 51 additions & 11 deletions setup.py
@@ -24,8 +24,17 @@ def _is_hip() -> bool:
return torch.version.hip is not None


def _is_neuron() -> bool:
torch_neuronx_installed = True
try:
subprocess.run(["neuron-ls"], capture_output=True, check=True)
except FileNotFoundError:
torch_neuronx_installed = False
return torch_neuronx_installed


def _is_cuda() -> bool:
return torch.version.cuda is not None
return (torch.version.cuda is not None) and not _is_neuron()


# Compiler flags.
@@ -87,6 +96,24 @@ def get_hipcc_rocm_version():
return None


def get_neuronxcc_version():
import sysconfig
site_dir = sysconfig.get_paths()["purelib"]
version_file = os.path.join(site_dir, "neuronxcc", "version", "__init__.py")

# Read the version file bundled with the installed neuronxcc package
with open(version_file, "rt") as fp:
content = fp.read()

# Extract the version using a regular expression
match = re.search(r"__version__ = '(\S+)'", content)
if match:
# Return the version string
return match.group(1)
else:
raise RuntimeError("Could not find neuronxcc version in the version file")


def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"""Get the CUDA version from nvcc.
@@ -210,6 +237,9 @@ def get_torch_arch_list() -> Set[str]:
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {amd_arch}")

elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()

ext_modules = []

vllm_extension_sources = [
@@ -227,15 +257,16 @@ def get_torch_arch_list() -> Set[str]:
if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(vllm_extension)
if not _is_neuron():
vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(vllm_extension)


def get_path(*filepath) -> str:
@@ -264,6 +295,12 @@ def get_vllm_version() -> str:
if hipcc_version != MAIN_CUDA_VERSION:
rocm_version_str = hipcc_version.replace(".", "")[:3]
version += f"+rocm{rocm_version_str}"
elif _is_neuron():
# Get the Neuron version
neuron_version = str(neuronxcc_version)
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
else:
cuda_version = str(nvcc_cuda_version)
if cuda_version != MAIN_CUDA_VERSION:
@@ -287,6 +324,9 @@ def get_requirements() -> List[str]:
if _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
elif _is_neuron():
with open(get_path("requirements-neuron.txt")) as f:
requirements = f.read().strip().split("\n")
else:
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
@@ -325,6 +365,6 @@ def get_requirements() -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
package_data=package_data,
)
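
Taken together, the setup.py changes give backend detection a simple precedence: ROCm first, then Neuron (a host where the neuron-ls tool runs successfully), then CUDA, and the CUDAExtension build plus the BuildExtension cmdclass are skipped entirely on Neuron hosts. The condensed sketch below restates that precedence outside setup.py; it mirrors the functions in the diff above but is not itself part of the commit:

# Condensed, hypothetical illustration of the backend precedence in this commit.
import subprocess

import torch


def _is_neuron() -> bool:
    # A host counts as a Neuron machine if the neuron-ls device-listing
    # tool is on PATH and exits successfully (mirrors the function above).
    try:
        subprocess.run(["neuron-ls"], capture_output=True, check=True)
        return True
    except FileNotFoundError:
        return False


def detect_backend() -> str:
    # Same ordering as the version-suffix logic in get_vllm_version():
    # ROCm, then Neuron, then CUDA.
    if torch.version.hip is not None:
        return "rocm"
    if _is_neuron():
        return "neuron"
    if torch.version.cuda is not None:
        return "cuda"
    return "cpu"


print(detect_backend())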
6 changes: 4 additions & 2 deletions vllm/utils.py
@@ -8,8 +8,6 @@
import psutil
import torch

from vllm._C import cuda_utils


class Device(enum.Enum):
GPU = enum.auto()
@@ -36,6 +34,10 @@ def is_hip() -> bool:

def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils

# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
max_shared_mem = cuda_utils.get_device_attribute(
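
The vllm/utils.py change shows a common pattern for optional native dependencies: moving the import from module scope into the one function that needs it means `import vllm.utils` succeeds even on hosts, such as Neuron machines, where the compiled `vllm._C` extension was never built. A generic sketch of the pattern; the error handling here is an illustrative addition, not taken from the commit:

# Minimal sketch of the lazy-import pattern used in the change above.
def get_cuda_utils():
    try:
        # Resolved only when the function is called, so importing this
        # module never requires the native extension to exist.
        from vllm._C import cuda_utils
    except ImportError as err:
        raise RuntimeError(
            "vllm was built without the CUDA extension "
            "(e.g. a Neuron build), so cuda_utils is unavailable") from err
    return cuda_utils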
