forked from vllm-project/vllm
[TPU] Add example for profiling TPU inference (vllm-project#12531)
Signed-off-by: mgoin <mgoin@redhat.com>
Showing 2 changed files with 168 additions and 0 deletions.
@@ -0,0 +1,67 @@
# vLLM TPU Profiling

This script profiles the TPU performance of vLLM for specific prefill or decode token shapes.

Note: an actual running server handles a mix of prefills and decodes, each across many different shapes.

We assume you are already on a TPU (this was tested on TPU v6e) and have installed vLLM according to the [installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/ai_accelerator/index.html).
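For reference, a from-source TPU install at the time of writing looked roughly like the sketch below. Treat it as an illustration only: the requirements file name and build flag are assumptions that may have changed, and the linked guide is authoritative.

```bash
# Illustrative sketch only -- follow the linked installation guide for the
# authoritative, up-to-date steps.
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install -r requirements-tpu.txt              # TPU-specific dependencies (assumed file name)
VLLM_TARGET_DEVICE="tpu" python setup.py develop  # build vLLM against the TPU backend
```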
> In all examples below, we run several warmup iterations first (so `--enforce-eager` is okay).

## Profile Examples

### Generate Prefill Trace

This example runs Qwen/Qwen2.5-7B-Instruct with a single request of 1024 input tokens. It is set up in an attempt to profile just the prefill time and operations.

```bash
export XLA_HLO_DEBUG=1
export MODEL=Qwen/Qwen2.5-7B-Instruct
export VLLM_TPU_PROFILE_DURATION_MS=3000
export VLLM_TPU_PROFILE_DELAY_MS=0

python3 profiling.py \
    --model $MODEL \
    --input-len 1024 --output-len 1 \
    --batch-size 1 --enforce-eager \
    --max-model-len 2048 \
    --tensor-parallel-size 1 \
    --profile-result-dir profiles
```
### Generate Decode Trace

This example runs Llama 3.1 70B with a batch of 32 requests, where each has 1 input token and 128 output tokens. It is set up in an attempt to profile just the 32 decodes running in parallel, by using an extremely small prefill of 1 token and setting `VLLM_TPU_PROFILE_DELAY_MS=1000` to skip the first second of inference (hopefully the prefill).

```bash
export XLA_HLO_DEBUG=1
export MODEL=meta-llama/Llama-3.1-70B-Instruct
export VLLM_TPU_PROFILE_DURATION_MS=2000
export VLLM_TPU_PROFILE_DELAY_MS=1000

rm -rf ~/.cache/vllm/xla_cache
python3 profiling.py \
    --model $MODEL \
    --input-len 1 \
    --output-len 128 \
    --batch-size 32 \
    --enforce-eager \
    --profile-result-dir profiles \
    --max-model-len 2048 --tensor-parallel-size 8
```
## Visualizing the profiles

Once you have collected your profiles with this script, you can visualize them using [TensorBoard](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).

These are most likely the dependencies you need to install:

```bash
pip install tensorflow-cpu tensorboard-plugin-profile etils importlib_resources
```

Then you just need to point TensorBoard to the directory where you saved the profiles and visit `http://localhost:6006/` in your browser:

```bash
tensorboard --logdir profiles/ --port 6006
```
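If TensorBoard is running on the TPU VM itself, you will likely need to forward port 6006 to your local machine before opening that URL. One way to do this for a Cloud TPU VM is sketched below; `<tpu-name>` and `<zone>` are placeholders for your own setup:

```bash
# Forward the remote TensorBoard port (6006) to localhost:6006.
# <tpu-name> and <zone> are placeholders -- substitute your TPU VM's name and zone.
gcloud compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --ssh-flag="-L 6006:localhost:6006"
```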
@@ -0,0 +1,101 @@
import argparse
import dataclasses
import os
import time
from typing import List

import numpy as np
import torch_xla.debug.profiler as xp
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser

# Profiling window, configured via environment variables (see the README examples).
DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000))
DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0))


def main(args: argparse.Namespace):
    print(args)

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))
    # Start the torch_xla profiler server; the trace is requested from it below.
    _ = xp.start_server(9012)

    sampling_params = SamplingParams(
        temperature=0.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    # Build dummy prompts of shape (batch_size, input_len) from random token ids.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def run_to_completion():
        start_time = time.perf_counter()
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        end_time = time.perf_counter()
        latency = end_time - start_time
        return latency

    # Warmup
    print("Warming up...")
    warmup_latencies = []
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        warmup_latencies.append(run_to_completion())
    print(f"Average warmup latency: {np.mean(warmup_latencies):.4f}s")

    # Profile
    profile_dir = args.profile_result_dir
    print(f"Profiling (results will be saved to '{profile_dir}')...")
    # Enable tracing on server
    xp.trace_detached("localhost:9012",
                      profile_dir,
                      delay_ms=DELAY_MS,
                      duration_ms=DURATION_MS)
    if DELAY_MS == 0:
        time.sleep(1.0)
    profile_latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profile iterations"):
        profile_latencies.append(run_to_completion())
    print(f"Average profile latency: {np.mean(profile_latencies):.4f}s")

    return


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=5,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=1,
                        help='Number of iterations to run for profiling.')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default="profiles",
        help=('path to save the profiler output. Can be visualized '
              'with ui.perfetto.dev or TensorBoard '
              '(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).'
              ))

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)