Skip to content

Commit

Permalink
Speed up Punica compilation (vllm-project#2632)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Jan 28, 2024
1 parent 9297675 commit 2580ea5
Show file tree
Hide file tree
Showing 21 changed files with 100 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-template.j2
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
steps:
- label: ":docker: build image"
commands:
- "docker build --tag {{ docker_image }} --target test --progress plain ."
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"
Expand Down
21 changes: 0 additions & 21 deletions csrc/punica/bgmv/bgmv_all.cu

This file was deleted.

4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
27 changes: 27 additions & 0 deletions csrc/punica/bgmv/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
"fp16": "nv_half",
"bf16": "nv_bfloat16",
"fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
for output_dtype in DTYPES:
for weight_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
weight_dtype=DTYPE_MAP[weight_dtype])
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
with open(filename, "w") as f:
f.write(kernel_definition)

0 comments on commit 2580ea5

Please sign in to comment.