From f325236ebb616733c27473512b8078e84a155866 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 12 Feb 2024 22:53:00 -0800 Subject: [PATCH] Serving Benchmark Refactoring (#2433) --- .buildkite/run-benchmarks.sh | 14 +- benchmarks/backend_request_func.py | 284 ++++++++++++++++++++++ benchmarks/benchmark_serving.py | 378 ++++++++++++++++++++--------- benchmarks/launch_tgi_server.sh | 2 +- 4 files changed, 553 insertions(+), 125 deletions(-) create mode 100644 benchmarks/backend_request_func.py diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index dde28cb55605e..865068628f1d0 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -6,15 +6,16 @@ set -o pipefail # cd into parent directory of this file cd "$(dirname "${BASH_SOURCE[0]}")/.." -(wget && curl) || (apt-get update && apt-get install -y wget curl) +(which wget && which curl) || (apt-get update && apt-get install -y wget curl) -# run benchmarks and upload the result to buildkite +# run python-based benchmarks and upload the result to buildkite python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt bench_latency_exit_code=$? python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt bench_throughput_exit_code=$? +# run server-based benchmarks and upload the result to buildkite python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf & server_pid=$! wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json @@ -22,11 +23,14 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r # wait for server to start, timeout after 600 seconds timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 python3 benchmarks/benchmark_serving.py \ + --backend openai \ --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \ --model meta-llama/Llama-2-7b-chat-hf \ --num-prompts 20 \ --endpoint /v1/completions \ - --tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt + --tokenizer meta-llama/Llama-2-7b-chat-hf \ + --save-result \ + 2>&1 | tee benchmark_serving.txt bench_serving_exit_code=$? 
kill $server_pid @@ -44,7 +48,7 @@ sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line echo "### Serving Benchmarks" >> benchmark_results.md sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line echo "" >> benchmark_results.md -tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines +tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines # upload the results to buildkite /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md @@ -61,3 +65,5 @@ fi if [ $bench_serving_exit_code -ne 0 ]; then exit $bench_serving_exit_code fi + +/workspace/buildkite-agent artifact upload openai-*.json diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py new file mode 100644 index 0000000000000..e7f74e2feaf86 --- /dev/null +++ b/benchmarks/backend_request_func.py @@ -0,0 +1,284 @@ +import json +import os +import time +from dataclasses import dataclass +from typing import Optional + +import aiohttp +from tqdm.asyncio import tqdm + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + best_of: int = 1 + use_beam_search: bool = False + + +@dataclass +class RequestFuncOutput: + generated_text: str = "" + success: bool = False + latency: float = 0 + ttft: float = 0 + prompt_len: int = 0 + + +async def async_request_tgi( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + assert not request_func_input.use_beam_search + params = { + "best_of": request_func_input.best_of, + "max_new_tokens": request_func_input.output_len, + "do_sample": True, + "temperature": 0.01, # TGI does not accept 0.0 temperature. + "top_p": 0.99, # TGI does not accept 1.0 top_p. 
+ } + payload = { + "inputs": request_func_input.prompt, + "parameters": params, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0 + st = time.perf_counter() + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for data in response.content.iter_any(): + if ttft == 0: + ttft = time.perf_counter() - st + output.ttft = ttft + output.latency = time.perf_counter() - st + + body = data.decode("utf-8").lstrip("data:") + output.generated_text = json.loads(body)["generated_text"] + output.success = True + else: + output.success = False + except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + output.success = False + + if pbar: + pbar.update(1) + return output + + +async def async_request_vllm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "prompt": request_func_input.prompt, + "n": 1, + "best_of": request_func_input.best_of, + "use_beam_search": request_func_input.use_beam_search, + "temperature": 0.0 if request_func_input.use_beam_search else 1.0, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "ignore_eos": True, + "stream": True, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0 + st = time.perf_counter() + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for data in response.content.iter_any(): + if ttft == 0: + ttft = time.perf_counter() - st + output.ttft = ttft + output.latency = time.perf_counter() - st + + # When streaming, '\0' is appended to the end of the response. 
+ body = data.decode("utf-8").strip("\0") + output.generated_text = json.loads( + body)["text"][0][len(request_func_input.prompt):] + output.success = True + + else: + output.success = False + except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + output.success = False + + if pbar: + pbar.update(1) + return output + + +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + assert not request_func_input.use_beam_search + assert request_func_input.best_of == 1 + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + ttft = 0 + + st = time.perf_counter() + try: + async with session.post(url=api_url, json=payload) as resp: + if resp.status == 200: + async for data in resp.content.iter_any(): + if ttft == 0: + ttft = time.perf_counter() - st + output.ttft = ttft + output.latency = time.perf_counter() - st + + body = data.decode("utf-8").lstrip("data:") + output.generated_text = json.loads(body)["text_output"] + output.success = True + + else: + output.success = False + except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + output.success = False + + if pbar: + pbar.update(1) + return output + + +async def async_request_deepspeed_mii( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + assert request_func_input.best_of == 1 + assert not request_func_input.use_beam_search + + payload = { + "prompts": request_func_input.prompt, + "max_new_tokens": request_func_input.output_len, + "ignore_eos": True, + "do_sample": True, + "temperature": + 0.01, # deepspeed-mii does not accept 0.0 temperature. + "top_p": 1.0, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder. 
+ # https://github.com/microsoft/DeepSpeed-MII/pull/311 + output.ttft = 0 + + st = time.perf_counter() + try: + async with session.post(url=request_func_input.api_url, + json=payload) as resp: + if resp.status == 200: + parsed_resp = await resp.json() + output.latency = time.perf_counter() - st + output.generated_text = parsed_resp[0]["generated_text"] + output.success = True + else: + output.success = False + except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + output.success = False + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith("v1/completions") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + assert not request_func_input.use_beam_search + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": request_func_input.best_of, + "max_tokens": request_func_input.output_len, + "stream": True, + } + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0 + st = time.perf_counter() + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk in response.content: + if ttft == 0: + ttft = time.perf_counter() - st + output.ttft = ttft + + chunk = chunk.strip() + if not chunk: + continue + + chunk = chunk.decode("utf-8").lstrip("data: ") + if chunk == "[DONE]": + latency = time.perf_counter() - st + else: + body = json.loads(chunk) + generated_text += body["choices"][0]["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + else: + output.success = False + except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError): + output.success = False + + if pbar: + pbar.update(1) + return output + + +ASYNC_REQUEST_FUNCS = { + "tgi": async_request_tgi, + "vllm": async_request_vllm, + "deepspeed-mii": async_request_deepspeed_mii, + "openai": async_request_openai_completions, + "tensorrt-llm": async_request_trt_llm, +} diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 1a36d9d6a5deb..cdcfb8582143c 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -20,16 +20,36 @@ import json import random import time +from dataclasses import dataclass +from datetime import datetime from typing import AsyncGenerator, List, Tuple -import aiohttp import numpy as np from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase from vllm.transformers_utils.tokenizer import get_tokenizer -# (prompt len, output len, latency) -REQUEST_LATENCY: List[Tuple[int, int, float]] = [] +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + RequestFuncInput, + RequestFuncOutput, +) + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + input_throughput: float + output_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + p99_ttft_ms: float + mean_tpot_ms: float + median_tpot_ms: float + p99_tpot_ms: float def sample_requests( @@ -46,6 +66,11 @@ def sample_requests( dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] + # some of these will be 
filtered out, so sample more than we need + sampled_indices = random.sample(range(len(dataset)), + int(num_requests * 1.2)) + dataset = [dataset[i] for i in sampled_indices] + # Tokenize the prompts and completions. prompts = [prompt for prompt, _ in dataset] prompt_token_ids = tokenizer(prompts).input_ids @@ -92,80 +117,125 @@ async def get_request( await asyncio.sleep(interval) -async def send_request(backend: str, model: str, api_url: str, prompt: str, - prompt_len: int, output_len: int, best_of: int, - use_beam_search: bool, pbar: tqdm) -> None: - request_start_time = time.perf_counter() - - headers = {"User-Agent": "Benchmark Client"} - if backend == "vllm": - pload = { - "prompt": prompt, - "n": 1, - "best_of": best_of, - "use_beam_search": use_beam_search, - "temperature": 0.0 if use_beam_search else 1.0, - "top_p": 1.0, - "max_tokens": output_len, - "ignore_eos": True, - "stream": False, - } - if model is not None: - pload["model"] = model - elif backend == "tgi": - assert not use_beam_search - params = { - "best_of": best_of, - "max_new_tokens": output_len, - "do_sample": True, - } - pload = { - "inputs": prompt, - "parameters": params, - } - else: - raise ValueError(f"Unknown backend: {backend}") - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout) as session: - while True: - async with session.post(api_url, headers=headers, - json=pload) as response: - chunks = [] - async for chunk, _ in response.content.iter_chunks(): - chunks.append(chunk) - output = b"".join(chunks).decode("utf-8") - output = json.loads(output) +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, +) -> BenchmarkMetrics: + total_output = 0 + total_input = 0 + completed = 0 + per_token_latencies = [] + ttfts = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = len(tokenizer.encode(outputs[i].generated_text)) + total_output += output_len + total_input += input_requests[i][1] + per_token_latencies.append(outputs[i].latency / output_len) + ttfts.append(outputs[i].ttft) + completed += 1 - # Re-send the request if it failed. 
-            if "error" not in output:
-                break
+    metrics = BenchmarkMetrics(
+        completed=completed,
+        total_input=total_input,
+        total_output=total_output,
+        request_throughput=completed / dur_s,
+        input_throughput=total_input / dur_s,
+        output_throughput=total_output / dur_s,
+        mean_ttft_ms=np.mean(ttfts) * 1000,
+        median_ttft_ms=np.median(ttfts) * 1000,
+        p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
+        mean_tpot_ms=np.mean(per_token_latencies) * 1000,
+        median_tpot_ms=np.median(per_token_latencies) * 1000,
+        p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
+    )
 
-    request_end_time = time.perf_counter()
-    request_latency = request_end_time - request_start_time
-    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
-    pbar.update(1)
+    return metrics
 
 
 async def benchmark(
     backend: str,
-    model: str,
     api_url: str,
+    model_id: str,
+    tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     best_of: int,
     use_beam_search: bool,
     request_rate: float,
-) -> None:
-    tasks: List[asyncio.Task] = []
-    pbar = tqdm(total=len(input_requests))
+    disable_tqdm: bool,
+):
+    if backend in ASYNC_REQUEST_FUNCS:
+        request_func = ASYNC_REQUEST_FUNCS.get(backend)
+    else:
+        raise ValueError(f"Unknown backend: {backend}")
+
+    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+    print(f"Traffic request rate: {request_rate}")
+
+    benchmark_start_time = time.perf_counter()
+    tasks = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        task = asyncio.create_task(
-            send_request(backend, model, api_url, prompt, prompt_len,
-                         output_len, best_of, use_beam_search, pbar))
-        tasks.append(task)
-    await asyncio.gather(*tasks)
-    pbar.close()
+        request_func_input = RequestFuncInput(
+            model=model_id,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        tasks.append(
+            asyncio.create_task(
+                request_func(request_func_input=request_func_input,
+                             pbar=pbar)))
+    outputs = await asyncio.gather(*tasks)
+
+    if not disable_tqdm:
+        pbar.close()
+
+    benchmark_duration = time.perf_counter() - benchmark_start_time
+
+    metrics = calculate_metrics(
+        input_requests=input_requests,
+        outputs=outputs,
+        dur_s=benchmark_duration,
+        tokenizer=tokenizer,
+    )
+
+    print(f"Successful requests: {metrics.completed}")
+    print(f"Benchmark duration: {benchmark_duration:.2f} s")
+    print(f"Total input tokens: {metrics.total_input}")
+    print(f"Total generated tokens: {metrics.total_output}")
+    print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
+    print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
+    print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
+    print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
+    print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
+    print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
+    print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
+    print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
+    print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")
+
+    result = {
+        "duration": benchmark_duration,
+        "completed": metrics.completed,
+        "total_input_tokens": metrics.total_input,
+        "total_output_tokens": metrics.total_output,
+        "request_throughput": metrics.request_throughput,
+        "input_throughput": metrics.input_throughput,
+        "output_throughput": metrics.output_throughput,
+        "mean_ttft_ms": metrics.mean_ttft_ms,
+        "median_ttft_ms": metrics.median_ttft_ms,
+        "p99_ttft_ms": 
metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms + } + return result def main(args: argparse.Namespace): @@ -173,77 +243,145 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) - api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}" - tokenizer = get_tokenizer(args.tokenizer, + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + + tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) - benchmark_start_time = time.perf_counter() - asyncio.run( - benchmark(args.backend, args.model, api_url, input_requests, - args.best_of, args.use_beam_search, args.request_rate)) - benchmark_end_time = time.perf_counter() - benchmark_time = benchmark_end_time - benchmark_start_time - print(f"Total time: {benchmark_time:.2f} s") - print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") - - # Compute the latency statistics. - avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) - print(f"Average latency: {avg_latency:.2f} s") - avg_per_token_latency = np.mean([ - latency / (prompt_len + output_len) - for prompt_len, output_len, latency in REQUEST_LATENCY - ]) - print(f"Average latency per token: {avg_per_token_latency:.2f} s") - avg_per_output_token_latency = np.mean( - [latency / output_len for _, output_len, latency in REQUEST_LATENCY]) - print("Average latency per output token: " - f"{avg_per_output_token_latency:.2f} s") + benchmark_result = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + best_of=args.best_of, + use_beam_search=args.use_beam_search, + request_rate=args.request_rate, + disable_tqdm=args.disable_tqdm, + )) + + # Save config and results to json + if args.save_result: + result_json = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["backend"] = backend + result_json["version"] = args.version + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["best_of"] = args.best_of + result_json["use_beam_search"] = args.use_beam_search + result_json["num_prompts"] = args.num_prompts + + # Traffic + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf") + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + # Save to file + base_model_id = model_id.split("/")[-1] + file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + with open(file_name, "w") as outfile: + json.dump(result_json, outfile) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Benchmark the online serving throughput.") - parser.add_argument("--backend", - type=str, - default="vllm", - choices=["vllm", "tgi"]) - parser.add_argument("--protocol", - type=str, - default="http", - choices=["http", "https"]) + parser.add_argument( + "--backend", + type=str, + default="vllm", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--version", + type=str, + default="N/A", + help="Version of the serving 
backend/engine.",
+    )
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--endpoint", type=str, default="/generate")
-    parser.add_argument("--model", type=str, default=None)
+    parser.add_argument(
+        "--endpoint",
+        type=str,
+        default="/generate",
+        help="API endpoint.",
+    )
     parser.add_argument("--dataset",
                         type=str,
                         required=True,
                         help="Path to the dataset.")
-    parser.add_argument("--tokenizer",
-                        type=str,
-                        required=True,
-                        help="Name or path of the tokenizer.")
-    parser.add_argument("--best-of",
-                        type=int,
-                        default=1,
-                        help="Generates `best_of` sequences per prompt and "
-                        "returns the best one.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Name of the model.",
+    )
+    parser.add_argument(
+        "--tokenizer",
+        type=str,
+        help=
+        "Name or path of the tokenizer, if not using the default model tokenizer.",
+    )
+    parser.add_argument(
+        "--best-of",
+        type=int,
+        default=1,
+        help="Generates `best_of` sequences per prompt and "
+        "returns the best one.",
+    )
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts",
-                        type=int,
-                        default=1000,
-                        help="Number of prompts to process.")
-    parser.add_argument("--request-rate",
-                        type=float,
-                        default=float("inf"),
-                        help="Number of requests per second. If this is inf, "
-                        "then all the requests are sent at time 0. "
-                        "Otherwise, we use Poisson process to synthesize "
-                        "the request arrival times.")
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=1000,
+        help="Number of prompts to process.",
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=float("inf"),
+        help="Number of requests per second. If this is inf, "
+        "then all the requests are sent at time 0. "
+        "Otherwise, we use Poisson process to synthesize "
+        "the request arrival times.",
+    )
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument('--trust-remote-code',
-                        action='store_true',
-                        help='trust remote code from huggingface')
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Trust remote code from huggingface",
+    )
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
+    parser.add_argument(
+        "--save-result",
+        action="store_true",
+        help="Specify to save benchmark results to a json file",
+    )
+
     args = parser.parse_args()
     main(args)
diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh
index bdb25b78d85b4..64d3c4f4b3889 100755
--- a/benchmarks/launch_tgi_server.sh
+++ b/benchmarks/launch_tgi_server.sh
@@ -6,7 +6,7 @@ TOKENS=$2
 
 docker run --gpus all --shm-size 1g -p $PORT:80 \
            -v $PWD/data:/data \
-           ghcr.io/huggingface/text-generation-inference:0.8 \
+           ghcr.io/huggingface/text-generation-inference:1.4.0 \
            --model-id $MODEL \
            --sharded false \
            --max-input-length 1024 \
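
Example usage: a minimal sketch of invoking the refactored benchmark locally against a vLLM OpenAI-compatible server, assuming a server like the one started in the Buildkite step above is already listening on port 8000 and the ShareGPT dataset has been downloaded to the working directory; the model name, prompt count, and request rate below are illustrative placeholders.

# Sketch: benchmark an OpenAI-compatible completions endpoint on localhost:8000
# (model, dataset path, prompt count, and request rate are placeholders).
python3 benchmarks/benchmark_serving.py \
    --backend openai \
    --base-url http://localhost:8000 \
    --endpoint /v1/completions \
    --model meta-llama/Llama-2-7b-chat-hf \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 200 \
    --request-rate 4 \
    --save-result

With a finite --request-rate, request arrival times are synthesized from a Poisson process, and --save-result writes a JSON file named after the backend, request rate, model, and timestamp (e.g. openai-4.0qps-...), which is the openai-*.json artifact uploaded by the Buildkite step above.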