Skip to content

Commit

Permalink
Fix progress bar and allow HTTPS in benchmark_serving.py (vllm-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
hmellor authored Jan 22, 2024
1 parent 10c2c70 commit 1a6a22c
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ async def send_request(
output_len: int,
best_of: int,
use_beam_search: bool,
+ pbar: tqdm
) -> None:
request_start_time = time.perf_counter()

Expand Down Expand Up @@ -151,6 +152,8 @@ async def send_request(
request_end_time = time.perf_counter()
request_latency = request_end_time - request_start_time
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
+ pbar.update(1)
+


async def benchmark(
Expand All @@ -163,21 +166,23 @@ async def benchmark(
request_rate: float,
) -> None:
tasks: List[asyncio.Task] = []
+ pbar = tqdm(total=len(input_requests))
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(
send_request(backend, model, api_url, prompt, prompt_len,
- output_len, best_of, use_beam_search))
+ output_len, best_of, use_beam_search, pbar))
tasks.append(task)
- await tqdm.gather(*tasks)
+ await asyncio.gather(*tasks)
+ pbar.close()


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)

- api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+ api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
Expand Down Expand Up @@ -212,6 +217,7 @@ def main(args: argparse.Namespace):
type=str,
default="vllm",
choices=["vllm", "tgi"])
+ parser.add_argument("--protocol", type=str, default="http", choices=["http", "https"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--endpoint", type=str, default="/generate")
Expand Down

0 comments on commit 1a6a22c

Please sign in to comment.