Skip to content

Commit 390b495

Browse files
authored
Don't build punica kernels by default (#2605)
1 parent 3a0e1fc commit 390b495

File tree

4 files changed

+11
-4
lines changed

4 files changed

+11
-4
lines changed

.github/workflows/scripts/build.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt
1313

1414
# Limit the number of parallel jobs to avoid OOM
1515
export MAX_JOBS=1
16+
# Make sure punica is built for the release (for LoRA)
17+
export VLLM_INSTALL_PUNICA_KERNELS=1
1618

1719
# Build
1820
$python_executable setup.py bdist_wheel --dist-dir=dist

Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ ENV MAX_JOBS=${max_jobs}
4545
# number of threads used by nvcc
4646
ARG nvcc_threads=8
4747
ENV NVCC_THREADS=$nvcc_threads
48+
# make sure punica kernels are built (for LoRA)
49+
ENV VLLM_INSTALL_PUNICA_KERNELS=1
4850

4951
RUN python3 setup.py build_ext --inplace
5052
#################### EXTENSION Build IMAGE ####################

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_torch_arch_list() -> Set[str]:
265265
with contextlib.suppress(ValueError):
266266
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
267267

268-
install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1")))
268+
install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
269269
device_count = torch.cuda.device_count()
270270
for i in range(device_count):
271271
major, minor = torch.cuda.get_device_capability(i)

vllm/lora/punica.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,13 @@ def _raise_exc(
157157
**kwargs # pylint: disable=unused-argument
158158
):
159159
if torch.cuda.get_device_capability() < (8, 0):
160-
raise ImportError(
161-
"LoRA kernels require compute capability>=8.0") from import_exc
160+
raise ImportError("punica LoRA kernels require compute "
161+
"capability>=8.0") from import_exc
162162
else:
163-
raise import_exc
163+
raise ImportError(
164+
"punica LoRA kernels could not be imported. If you built vLLM "
165+
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
166+
"was set.") from import_exc
164167

165168
bgmv = _raise_exc
166169
add_lora = _raise_exc

0 commit comments

Comments
 (0)