From 91afbebd3dcbe014503742e9ec884859a7c36359 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 28 Jan 2025 22:16:47 -0500 Subject: [PATCH] [TPU] Add example for profiling TPU inference (#12531) Signed-off-by: mgoin --- .../offline_inference/profiling_tpu/README.md | 67 ++++++++++++ .../profiling_tpu/profiling.py | 101 ++++++++++++++++++ 2 files changed, 168 insertions(+) create mode 100644 examples/offline_inference/profiling_tpu/README.md create mode 100644 examples/offline_inference/profiling_tpu/profiling.py diff --git a/examples/offline_inference/profiling_tpu/README.md b/examples/offline_inference/profiling_tpu/README.md new file mode 100644 index 0000000000000..08efa63dc1021 --- /dev/null +++ b/examples/offline_inference/profiling_tpu/README.md @@ -0,0 +1,67 @@ +# vLLM TPU Profiling + +This script is used to profile the TPU performance of vLLM for specific prefill or decode token shapes. + +Note: an actual running server is a mix of both prefill of many shapes and decode of many shapes. + +We assume you are on a TPU already (this was tested on TPU v6e) and have installed vLLM according to the [installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/ai_accelerator/index.html). + +> In all examples below, we run several warmups before (so `--enforce-eager` is okay) + +## Profile Examples + +### Generate Prefill Trace + +This example runs Qwen/Qwen2.5-7B-Instruct with a single request of 1024 input tokens. This is set up in attempt to profile just the prefill time and operations. + +```bash +export XLA_HLO_DEBUG=1 +export MODEL=Qwen/Qwen2.5-7B-Instruct +export VLLM_TPU_PROFILE_DURATION_MS=3000 +export VLLM_TPU_PROFILE_DELAY_MS=0 + +python3 profiling.py \ + --model $MODEL \ + --input-len 1024 --output-len 1 \ + --batch-size 1 --enforce-eager \ + --max-model-len 2048 \ + --tensor-parallel-size 1 \ + --profile-result-dir profiles +``` + + +### Generate Decode Trace + +This example runs Llama 3.1 70B with a batch of 32 requests where each has 1 input token and 128 output tokens. This is set up in attempt to profile just the 32 decodes running in parallel by having an extremely small prefill of 1 token and setting `VLLM_TPU_PROFILE_DELAY_MS=1000` to skip the first second of inference (hopefully prefill). + +```bash +export XLA_HLO_DEBUG=1 +export MODEL=meta-llama/Llama-3.1-70B-Instruct +export VLLM_TPU_PROFILE_DURATION_MS=2000 +export VLLM_TPU_PROFILE_DELAY_MS=1000 + +rm -rf ~/.cache/vllm/xla_cache +python3 profiling.py \ + --model $MODEL \ + --input-len 1 \ + --output-len 128 \ + --batch-size 32 \ + --enforce-eager \ + --profile-result-dir profiles \ + --max-model-len 2048 --tensor-parallel-size 8 +``` + + +## Visualizing the profiles + +Once you have collected your profiles with this script, you can visualize them using [TensorBoard](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). + +Here are most likely the dependencies you need to install: +```bash +pip install tensorflow-cpu tensorboard-plugin-profile etils importlib_resources +``` + +Then you just need to point TensorBoard to the directory where you saved the profiles and visit `http://localhost:6006/` in your browser: +```bash +tensorboard --logdir profiles/ --port 6006 +``` \ No newline at end of file diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py new file mode 100644 index 0000000000000..d7423e6c6da93 --- /dev/null +++ b/examples/offline_inference/profiling_tpu/profiling.py @@ -0,0 +1,101 @@ +import argparse +import dataclasses +import os +import time +from typing import List + +import numpy as np +import torch_xla.debug.profiler as xp +from tqdm import tqdm + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.utils import FlexibleArgumentParser + +DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000)) +DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0)) + + +def main(args: argparse.Namespace): + print(args) + + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + _ = xp.start_server(9012) + + sampling_params = SamplingParams( + temperature=0.0, + ignore_eos=True, + max_tokens=args.output_len, + ) + print(sampling_params) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompts: List[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def run_to_completion(): + start_time = time.perf_counter() + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + # Warmup + print("Warming up...") + warmup_latencies = [] + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + warmup_latencies.append(run_to_completion()) + print(f"Average warmup latency: {np.mean(warmup_latencies):.4f}s") + + # Profile + profile_dir = args.profile_result_dir + print(f"Profiling (results will be saved to '{profile_dir}')...") + # Enable tracing on server + xp.trace_detached("localhost:9012", + profile_dir, + delay_ms=DELAY_MS, + duration_ms=DURATION_MS) + if DELAY_MS == 0: + time.sleep(1.0) + profile_latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profile iterations"): + profile_latencies.append(run_to_completion()) + print(f"Average profile latency: {np.mean(profile_latencies):.4f}s") + + return + + +if __name__ == '__main__': + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('--input-len', type=int, default=32) + parser.add_argument('--output-len', type=int, default=128) + parser.add_argument('--batch-size', type=int, default=8) + parser.add_argument('--num-iters-warmup', + type=int, + default=5, + help='Number of iterations to run for warmup.') + parser.add_argument('--num-iters', + type=int, + default=1, + help='Number of iterations to run for profiling.') + parser.add_argument( + '--profile-result-dir', + type=str, + default="profiles", + help= + ('path to save the pytorch profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard ' + '(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).' + )) + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args)