Merged
67 commits
f508c4a
KV cache memory pool on host
xiezhq-hermann Nov 27, 2024
0639ff5
hierarchical cache controller
xiezhq-hermann Nov 27, 2024
7de8fc5
radix tree for hierarchical cache
xiezhq-hermann Nov 27, 2024
217a9b2
minimal change to plug in hierarchical cache
xiezhq-hermann Nov 27, 2024
09caf1c
remove duplicated code
xiezhq-hermann Nov 27, 2024
2d57a87
hierarchical cache micro-benchmark
xiezhq-hermann Nov 27, 2024
2007d1d
global CUDA synchronization to prevent illegal memory access
xiezhq-hermann Dec 19, 2024
d8b6b64
write through and back policies, deprecate write through revokable
xiezhq-hermann Dec 19, 2024
8e68f71
minor change on scheduler for hierarchical caching
xiezhq-hermann Dec 19, 2024
ff19db0
fix rebase error
xiezhq-hermann Dec 27, 2024
1846672
fix counter reset issue
xiezhq-hermann Dec 27, 2024
2ff5297
bug fix for illegal memory access in pytorch indexing and reduce data…
xiezhq-hermann Dec 29, 2024
7bf5ff4
draft multi turn benchmark
xiezhq-hermann Dec 31, 2024
89fc497
reorg multiturn benchmark
xiezhq-hermann Jan 1, 2025
191a02d
bug fix for scheduler
xiezhq-hermann Jan 1, 2025
81e39dc
reorg test
xiezhq-hermann Jan 1, 2025
1853cf2
update format
xiezhq-hermann Jan 1, 2025
3a1e602
bug fix for device of value
xiezhq-hermann Jan 6, 2025
64ad5dc
introduce protected size for better memory check
xiezhq-hermann Jan 7, 2025
5e91291
style change for host memory pool
xiezhq-hermann Jan 7, 2025
80ffe49
move debug_timing to utils
xiezhq-hermann Jan 7, 2025
26be1b3
clean up the cache controller
xiezhq-hermann Jan 9, 2025
51c9537
port for bench order
xiezhq-hermann Jan 9, 2025
cba5e82
separated file for hiradix tree
xiezhq-hermann Jan 9, 2025
4dc4fcd
new workload generator for multi turn benchmark
xiezhq-hermann Jan 14, 2025
f8d0826
protected size for base radix
xiezhq-hermann Jan 16, 2025
d965e1b
make the memory quota of PrefillAdder dynamic
xiezhq-hermann Jan 16, 2025
de043d8
concurrency bug fix
xiezhq-hermann Jan 17, 2025
aaba20e
refactoring hiradix
xiezhq-hermann Jan 17, 2025
0f072c8
minor change on multi turn benchmark
xiezhq-hermann Jan 17, 2025
3c1219e
introduce operation split
xiezhq-hermann Jan 17, 2025
9c642e9
reduce the priority of loading data
xiezhq-hermann Jan 17, 2025
15f3337
Merge branch 'main_origin' into xiezhq-hierarchical
xiezhq-hermann Jan 17, 2025
3f451dc
nit formatting
xiezhq-hermann Jan 17, 2025
17ce0d2
Removing device sync overhead (#3011)
Edenzzzz Jan 20, 2025
c25e0a0
bug fix for loading
xiezhq-hermann Jan 22, 2025
d39160c
bug fix for loading cache in scheduling
xiezhq-hermann Jan 24, 2025
ebbed14
dedup code
xiezhq-hermann Jan 24, 2025
47ad482
mark staging requests
xiezhq-hermann Jan 24, 2025
89b4db8
multi-turn benchmark refinement
xiezhq-hermann Jan 26, 2025
b546141
new overlapping for write through and graceful reset for cache controller
xiezhq-hermann Jan 27, 2025
ff328fc
Merge branch 'main_origin' into xiezhq-hierarchical
xiezhq-hermann Jan 27, 2025
97a3c18
sanity check to prevent performance regression
xiezhq-hermann Jan 27, 2025
349a982
Merge branch 'xiezhq-check' into xiezhq-hierarchical
xiezhq-hermann Jan 27, 2025
6c39cb7
clean up and brief doc
xiezhq-hermann Jan 27, 2025
727f779
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Jan 27, 2025
691b7e0
format
xiezhq-hermann Jan 27, 2025
64fad0f
add log file name
xiezhq-hermann Jan 28, 2025
fd47928
Merge branch 'main_origin' into xiezhq-hierarchical
xiezhq-hermann Jan 28, 2025
e6d8ec8
nit cleaning
xiezhq-hermann Jan 28, 2025
fcf2e8d
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Jan 28, 2025
fe550c6
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Jan 28, 2025
c31fdc1
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Jan 30, 2025
ba7e737
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 1, 2025
9c2d0a9
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 3, 2025
c7fc0fd
style change
xiezhq-hermann Feb 3, 2025
512d8af
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 9, 2025
3299dcb
Merge branch 'main' into xiezhq-hierarchical
zhyncs Feb 14, 2025
5dc677f
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 15, 2025
16de196
Merge branch 'main' into xiezhq-hierarchical
zhyncs Feb 15, 2025
481c912
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 15, 2025
46dea61
Merge branch 'main' into xiezhq-hierarchical
zhyncs Feb 17, 2025
257e256
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 17, 2025
af62c95
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 19, 2025
996413c
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 20, 2025
5d714a8
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 22, 2025
a37244a
Merge branch 'main' into xiezhq-hierarchical
xiezhq-hermann Feb 24, 2025
25 changes: 25 additions & 0 deletions benchmark/hicache/README.md
@@ -0,0 +1,25 @@
## Run synthetic multi-turn benchmark

```
# SGLang server with radix cache disabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache

# SGLang server with radix cache on and first-come-first-serve policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs

# The default SGLang server with radix cache on and long-prefix-match policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000

# SGLang server with hierarchical radix cache enabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache

```

```
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
```

Note: The performance gain from hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens there are to reuse, the larger the model, and the tighter the GPU memory budget, the greater the benefit you can expect from hierarchical caching.
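
For reference, each benchmark run appends one JSON line to the file given by `--log-file` (default `performance_metrics.jsonl`), containing a `timestamp` and a `summary` of TTFT, latency, and throughput statistics. A minimal sketch of reading that log back, assuming the default file name and the summary keys written by `bench_multiturn.py`:

```
# Sketch: print per-run TTFT and throughput from the JSONL log
# written by bench_multiturn.py (one line per swept request rate).
import json

with open("performance_metrics.jsonl") as f:
    for line in f:
        summary = json.loads(line)["summary"]
        print(
            f"rate={summary['request_rate']:>4} req/s  "
            f"avg_ttft={summary['average_ttft']:.2f}  "
            f"throughput={summary['throughput']:.2f} req/s"
        )
```

Since the driver sweeps a list of request rates and flushes the cache between runs, each swept rate contributes its own line, so the printout gives one row per rate.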


## More benchmarks to be added
105 changes: 80 additions & 25 deletions benchmark/hicache/bench_multiturn.py
@@ -5,6 +5,7 @@
import random
import threading
import time
from datetime import datetime
from typing import Optional

import aiohttp
@@ -26,9 +27,15 @@ def parse_args():
parser.add_argument(
"--num-clients",
type=int,
default=200,
default=256,
help="Number of concurrent clients",
)
parser.add_argument(
"--max-parallel",
type=int,
default=128,
help="Maximum number of parallel requests",
)
parser.add_argument(
"--request-length",
type=int,
@@ -73,11 +80,17 @@ def parse_args():
help="Server port (default: 30000)",
)
parser.add_argument(
"--model",
"--model-path",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="model path compatible with Hugging Face Transformers",
)
parser.add_argument(
"--log-file",
type=str,
default="performance_metrics.jsonl",
help="File to log performance metrics",
)
return parser.parse_args()


@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
return payload


def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
"""Append the data with a timestamp to the specified JSONL file."""
timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
try:
with open(file_path, "a") as file:
file.write(
json.dumps(timestamped_data) + "\n"
) # Write as a single line in JSONL format
except IOError as e:
print(f"Error writing to JSONL file: {e}")


class ReadyQueue:
"""
Thread-safe queue that can pop requests in different orders based on given policy.
@@ -191,12 +216,15 @@ def __init__(self, args):
# Construct the base URL for requests
self.url = f"http://{args.host}:{args.port}/generate"

self.tokenizer = get_tokenizer(args.model)
self.tokenizer = get_tokenizer(args.model_path)
self.distribution = args.distribution
self.request_rate = args.request_rate
self.start_time = None
self.finished_time = None

self.sent_requests = 0
self.completed_requests = 0

self.candidate_inputs = sample_random_requests(
input_len=args.request_length,
output_len=args.output_length,
@@ -235,6 +263,18 @@ async def handle_request(self, item):
def request_sender(self):
async def request_loop():
while True:
if self.sent_requests - self.completed_requests < args.max_parallel:
new_request = self.ready_queue.pop()
if new_request:
asyncio.create_task(self.handle_request(new_request))
self.sent_requests += 1
else:
await asyncio.sleep(0.05)
continue

if self.pbar.n == self.pbar.total:
break

# Calculate Poisson-distributed wait time
if self.distribution == "poisson":
sleep_time = random.expovariate(self.request_rate)
@@ -247,14 +287,6 @@ async def request_loop():
raise ValueError("Invalid distribution type")
await asyncio.sleep(sleep_time) # Wait before sending the next request

new_request = self.ready_queue.pop()
# Submit async request
if new_request:
asyncio.create_task(self.handle_request(new_request))
else:
if self.pbar.n == self.pbar.total:
break

# Create and run the event loop for asynchronous requests
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -273,6 +305,7 @@ def response_handler(self):
self.client_records[client_id]["round"] += 1
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.completed_requests += 1

if self.client_records[client_id]["round"] < args.num_rounds:
self.client_records[client_id][
@@ -301,34 +334,56 @@ def run(self):

request_thread.join()
response_thread.join()

self.pbar.close()
print("All requests completed.")

performance_data = {
"summary": {
"total_requests": len(self.performance_metrics["ttft"]),
"request_rate": self.request_rate,
"average_ttft": sum(self.performance_metrics["ttft"])
/ len(self.performance_metrics["ttft"]),
"p90_ttft": sorted(self.performance_metrics["ttft"])[
int(0.9 * len(self.performance_metrics["ttft"]))
],
"median_ttft": sorted(self.performance_metrics["ttft"])[
len(self.performance_metrics["ttft"]) // 2
],
"average_latency": sum(self.performance_metrics["latency"])
/ len(self.performance_metrics["latency"]),
"p90_latency": sorted(self.performance_metrics["latency"])[
int(0.9 * len(self.performance_metrics["latency"]))
],
"median_latency": sorted(self.performance_metrics["latency"])[
len(self.performance_metrics["latency"]) // 2
],
"throughput": self.pbar.total / (self.finished_time - self.start_time),
},
}
print("All requests completed")
print("Performance metrics summary:")
print(
f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
)
print(
f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
)
print(
f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
)
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
print(
f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
)
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
print(
f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
throughput = self.pbar.total / (self.finished_time - self.start_time)
print(f"Throughput: {throughput:.2f} requests per second")
log_to_jsonl_file(performance_data, args.log_file)


if __name__ == "__main__":
args = parse_args()
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

for request_rate in range(1, 41, 2):
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
args.request_rate = request_rate
requests.post(flush_cache_url)
time.sleep(1)
WorkloadGenerator(args).run()