
Commit 996ddd9

Merge branch 'main' into mxfp6_mixed
2 parents: 8ef7e00 + fc67969

File tree

19 files changed: +1580 −209 lines


.buildkite/pyproject.toml

Lines changed: 0 additions & 46 deletions
This file was deleted.
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    function_under_test = lambda: ops.reshape_and_cache(
        key,  # noqa: F821
        value,  # noqa: F821
        key_cache,  # noqa: F821
        value_cache,  # noqa: F821
        slot_mapping,  # noqa: F821
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for exp in range(1, 17):
        n_tok = 2**exp
        lat = run_benchmark(
            num_tokens=n_tok,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        rows.append([n_tok, lat * 1e6])  # convert to microseconds

    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
    print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 128)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=200)

    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )

    args = parser.parse_args()

    main(args)
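In the default `--mode cudagraph`, the script captures the single `ops.reshape_and_cache` launch into a CUDA graph and then times graph replays, so the reported per-token latencies exclude Python and kernel-launch overhead; `--mode no_graph` times the eager launches instead. For reference, the snippet below is a minimal, self-contained sketch of that capture-and-replay timing pattern. It is not part of this commit: the helper name `time_with_cuda_graph` and the stand-in elementwise op are illustrative only, chosen so the example runs on any CUDA-capable PyTorch install.

import time
import torch


def time_with_cuda_graph(fn, iters: int = 200) -> float:
    """Capture `fn` into a CUDA graph and return the mean replay latency in seconds."""
    # Warm up eagerly so lazy CUDA initialization is not recorded during capture.
    for _ in range(3):
        fn()
    torch.cuda.synchronize()

    # Record the fixed-shape workload once.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()
    torch.cuda.synchronize()

    # Time replays of the captured graph (no Python or launch overhead per op).
    start = time.perf_counter()
    for _ in range(iters):
        g.replay()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    x = torch.randn(1 << 20, device="cuda")
    y = torch.empty_like(x)
    # Stand-in for the cache-write kernel: a fixed-shape scale into a preallocated buffer.
    lat = time_with_cuda_graph(lambda: torch.mul(x, 2.0, out=y))
    print(f"{lat * 1e6:.3f} µs per replay")

The benchmark above follows the same shape: capture once, replace the function under test with `g.replay()`, warm up, then average over `--iters` timed iterations.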

benchmarks/pyproject.toml

Lines changed: 0 additions & 49 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
@@ -0,0 +1,5 @@
# KAITO

[KAITO](https://kaito-project.github.io/kaito/docs/) is a Kubernetes operator for deploying and serving LLMs with vLLM. It manages large models via container images, provides built-in OpenAI-compatible inference, auto-provisions GPU nodes, and ships curated model presets.

Please refer to the [quick start](https://kaito-project.github.io/kaito/docs/quick-start) for more details.

examples/pyproject.toml

Lines changed: 0 additions & 54 deletions
This file was deleted.
