CPU only build (vllm-project#9)
maktukmak authored and bigPYJ1151 committed Jan 31, 2024
1 parent e811c70 commit ad25d65
Showing 17 changed files with 1,143 additions and 55 deletions.
6 changes: 6 additions & 0 deletions Makefile
@@ -28,6 +28,12 @@ sanitizer:
py_install:
VLLM_BUILD_CPU_OPS=1 MAX_JOBS=$(JOBS) pip install --no-build-isolation -v -e .

py_install_cpu:
VLLM_BUILD_CPU_ONLY=1 MAX_JOBS=$(JOBS) pip install --no-build-isolation -v -e .

install_vllm:
MAX_JOBS=$(JOBS) pip install -v git+https://github.com/intel-sandbox/vllm-xpu.git@dev -f https://download.pytorch.org/whl/torch_stable.html

package:
VLLM_BUILD_CPU_OPS=1 MAX_JOBS=$(JOBS) python setup.py bdist_wheel
echo "Wheel package is saved in ./dist/"
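
For illustration, a rough Python equivalent of the new py_install_cpu target (a sketch, not part of this commit) makes explicit the two environment variables the build reads:

# Hypothetical helper mirroring `make py_install_cpu`; the CPU-only build is
# driven entirely by environment variables read in setup.py.
import os
import subprocess

env = dict(os.environ)
env["VLLM_BUILD_CPU_ONLY"] = "1"  # setup.py then builds a CppExtension without CUDA
env["MAX_JOBS"] = "8"             # parallel Ninja jobs; tune for the build machine

subprocess.check_call(
    ["pip", "install", "--no-build-isolation", "-v", "-e", "."],
    env=env,
)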
77 changes: 77 additions & 0 deletions cpu.Dockerfile
@@ -0,0 +1,77 @@
FROM python:3.10 AS dev

RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-cpu.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt

# image to build pytorch extensions
FROM dev AS build

# install build dependencies
COPY requirements-build-cpu.txt requirements-build-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-build-cpu.txt

# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY requirements-cpu.txt requirements-cpu.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

# max jobs used by Ninja to build extensions
ARG max_jobs
ENV MAX_JOBS=${max_jobs}
RUN python3 setup.py build_ext --inplace

# image to run unit testing suite
FROM dev AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY tests tests
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "pytest", "tests"]

# runtime base image; a plain Python base is enough since all runtime
# dependencies are installed via pip
FROM python:3.10 AS vllm-base

RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace
COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-cpu.txt

FROM vllm-base AS vllm
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

EXPOSE 8000
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]

# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate fschat

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

4 changes: 4 additions & 0 deletions csrc/dispatch_utils.h
@@ -14,10 +14,14 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifdef VLLM_BUILD_CPU_ONLY
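// In CPU-only builds this macro expands to nothing, so no CUDA branch is
// compiled into the device-dispatch switch.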
#define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...)
#else
#define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...) \
case c10::DeviceType::CUDA: { \
return BASENAME(__VA_ARGS__); \
}
#endif

#ifdef VLLM_BUILD_CPU_OPS
#define VLLM_DISPATCH_TO_CPU_CASE(BASENAME, ...) \
6 changes: 6 additions & 0 deletions csrc/pybind.cpp
@@ -87,6 +87,12 @@ void gptq_shuffle_dispatch(
VLLM_DISPATCH_DEVICES(q_weight.device(), gptq_shuffle, q_weight, q_perm);
}

#ifdef VLLM_BUILD_CPU_ONLY
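// CPU-only stub: there is no CUDA device to query, so return a fixed
// placeholder value instead of the real device attribute.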
int get_device_attribute(
int attribute,
int device_id) { return 94387; }
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
14 changes: 7 additions & 7 deletions Dockerfile → gpu.Dockerfile
@@ -10,9 +10,9 @@ RUN apt-get update -y \
WORKDIR /workspace

# install build and runtime dependencies
COPY requirements.txt requirements.txt
COPY requirements-gpu.txt requirements-gpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt
pip install -r requirements-gpu.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
@@ -25,14 +25,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
FROM dev AS build

# install build dependencies
COPY requirements-build.txt requirements-build.txt
COPY requirements-build-gpu.txt requirements-build-gpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-build.txt
pip install -r requirements-build-gpu.txt

# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY requirements.txt requirements.txt
COPY requirements-gpu.txt requirements-gpu.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

@@ -75,9 +75,9 @@ RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace
COPY requirements.txt requirements.txt
COPY requirements-gpu.txt requirements-gpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt
pip install -r requirements-gpu.txt
#################### RUNTIME BASE IMAGE ####################


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ requires = [
"ninja",
"packaging",
"setuptools >= 49.4.0",
"torch == 2.1.2",
"torch == 2.1.2+cpu",
"wheel",
]
build-backend = "setuptools.build_meta"
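
Since the build-time torch pin is now the CPU wheel, a quick runtime check (a sketch, not from this commit) can confirm the environment really resolved 2.1.2+cpu rather than a CUDA build:

# Sanity check that the CPU wheel of torch is the one installed.
import torch

assert "+cpu" in torch.__version__, torch.__version__
assert torch.version.cuda is None
assert not torch.cuda.is_available()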
6 changes: 6 additions & 0 deletions requirements-build-cpu.txt
@@ -0,0 +1,6 @@
# Should be mirrored in pyproject.toml
ninja
packaging
setuptools>=49.4.0
torch==2.1.2+cpu
wheel
File renamed without changes.
15 changes: 15 additions & 0 deletions requirements-cpu.txt
@@ -0,0 +1,15 @@
ninja # For faster builds.
psutil
ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
pybind11
sentencepiece # Required for LLaMA tokenizer.
numpy
einops # Required for phi-1_5
torch == 2.1.2+cpu
transformers >= 4.34.0 # Required for Mistral.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
File renamed without changes.
95 changes: 66 additions & 29 deletions setup.py
@@ -8,7 +8,13 @@
from packaging.version import parse, Version
import setuptools
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

BUILD_CPU_ONLY = os.getenv('VLLM_BUILD_CPU_ONLY', "1") == "1"

if not BUILD_CPU_ONLY:
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
else:
from torch.utils.cpp_extension import BuildExtension, CppExtension

ROOT_DIR = os.path.dirname(__file__)

@@ -21,11 +27,11 @@


def _is_hip() -> bool:
return torch.version.hip is not None
return torch.version.hip is not None and not BUILD_CPU_ONLY


def _is_cuda() -> bool:
return torch.version.cuda is not None
return torch.version.cuda is not None and not BUILD_CPU_ONLY


# Compiler flags.
@@ -86,7 +92,6 @@ def get_hipcc_rocm_version():
print("Could not find HIP version in the output")
return None


def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"""Get the CUDA version from nvcc.
@@ -137,6 +142,19 @@ def get_torch_arch_list() -> Set[str]:
stacklevel=2)
return arch_list

if not BUILD_CPU_ONLY:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 7:
raise RuntimeError(
"GPUs with compute capability below 7.0 are not supported.")
compute_capabilities.add(f"{major}.{minor}")

# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
@@ -211,9 +229,11 @@ def get_torch_arch_list() -> Set[str]:
f"amdgpu_arch_found: {amd_arch}")

# Setup CPU Operations
BUILD_CPU_OPS = os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1"
BUILD_CPU_OPS = (os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1" or BUILD_CPU_ONLY)
CPU_OPS_SOURCES = []
if BUILD_CPU_OPS:
if BUILD_CPU_ONLY:
CXX_FLAGS += ["-DVLLM_BUILD_CPU_ONLY"]
CXX_FLAGS += [
"-DVLLM_BUILD_CPU_OPS", "-fopenmp", "-mavx512f", "-mavx512bf16",
"-mavx512vl"
@@ -228,29 +248,42 @@ def get_torch_arch_list() -> Set[str]:

ext_modules = []

vllm_extension_sources = [
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp",
] + CPU_OPS_SOURCES
if not BUILD_CPU_ONLY:
vllm_extension_sources = [
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/pybind.cpp",
] + CPU_OPS_SOURCES

if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
else:
vllm_extension_sources = [
"csrc/pybind.cpp",
] + CPU_OPS_SOURCES
vllm_extension = CppExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
},
)

if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(vllm_extension)
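
As a post-install check, a short sketch (not part of this commit) that the CPU-only CppExtension built and imports on a machine without any CUDA toolkit:

# Assumes `VLLM_BUILD_CPU_ONLY=1 pip install -e .` has already been run.
import importlib

import torch

assert torch.version.cuda is None          # torch 2.1.2+cpu, no CUDA runtime
_C = importlib.import_module("vllm._C")
print(_C.__file__)                         # the locally built extension .so
print(hasattr(_C, "ops"), hasattr(_C, "cache_ops"))  # submodules registered in pybind.cpp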


@@ -280,7 +313,7 @@ def get_vllm_version() -> str:
if hipcc_version != MAIN_CUDA_VERSION:
rocm_version_str = hipcc_version.replace(".", "")[:3]
version += f"+rocm{rocm_version_str}"
else:
elif _is_cuda():
cuda_version = str(nvcc_cuda_version)
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
@@ -303,9 +336,13 @@ def get_requirements() -> List[str]:
if _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
elif _is_cuda():
with open(get_path("requirements-gpu.txt")) as f:
requirements = f.read().strip().split("\n")
else:
with open(get_path("requirements.txt")) as f:
with open(get_path("requirements-cpu.txt")) as f:
requirements = f.read().strip().split("\n")

return requirements


2 changes: 1 addition & 1 deletion vllm/model_executor/layers/activation.py
@@ -27,7 +27,7 @@ class SiluAndMul(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
return (F.silu(x[..., :d].float()) * x[..., d:].float()).to(x)

def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
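
The float() upcast matters mainly for low-precision CPU runs; a standalone sketch (assuming bfloat16 inputs, which is where the difference shows up):

# Compare the native-precision product with the fp32-accumulated version.
import torch
import torch.nn.functional as F

x = torch.randn(4, 16, dtype=torch.bfloat16)
d = x.shape[-1] // 2

native = F.silu(x[..., :d]) * x[..., d:]
upcast = (F.silu(x[..., :d].float()) * x[..., d:].float()).to(x)

print(native.dtype, upcast.dtype)                     # both stay bfloat16
print((native.float() - upcast.float()).abs().max())  # any gap is rounding in the low-precision path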
23 changes: 14 additions & 9 deletions vllm/model_executor/layers/attention.py
@@ -3,9 +3,17 @@

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
try:
from xformers import ops as xops
except ImportError:
pass

try:
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
except ImportError:
from vllm.model_executor.layers.xformers_cpu.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
@@ -160,12 +168,9 @@ def forward(
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
) if not self.cpu_only else torch.nn.functional.scaled_dot_product_attention(
query.movedim(1,
query.dim() -
2), key.movedim(1,
query.dim() - 2),
value.movedim(1,
value.dim() - 2), input_metadata.attn_bias,
query.movedim(1, query.dim() - 2), key.movedim(1, query.dim() - 2),
value.movedim(1, value.dim() - 2),
input_metadata.attn_bias,
0.0).movedim(query.dim() - 2, 1).contiguous()
output = out.view_as(query)
else:
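
For reference, the scaled_dot_product_attention fallback above in isolation (a sketch; it assumes the [batch, seq_len, num_heads, head_size] layout used here and an additive bias broadcastable to [batch, num_heads, seq_len, seq_len]):

# The movedim() calls swap the sequence and head dimensions, because
# scaled_dot_product_attention expects [..., heads, seq_len, head_size].
import torch

batch, seq_len, num_heads, head_size = 2, 8, 4, 32
q = torch.randn(batch, seq_len, num_heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)
attn_bias = torch.zeros(batch, num_heads, seq_len, seq_len)

out = torch.nn.functional.scaled_dot_product_attention(
    q.movedim(1, q.dim() - 2),
    k.movedim(1, k.dim() - 2),
    v.movedim(1, v.dim() - 2),
    attn_bias,
    0.0,
).movedim(q.dim() - 2, 1).contiguous()

assert out.shape == q.shape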
