RPD Profiling (vllm-project#208)
* Enable RPD for single/multi gpu

Co-authored-by: AdrianAbeyta <adrian.abeyta@amd.com>

* Add rpd build instructions to Dockerfile.rocm

* Handle env path

* Fix code errors

* Move RPD based profiling over to profiling folder

* use envs vs os.getenv

---------

Co-authored-by: AdrianAbeyta <adrian.abeyta@amd.com>
dllehr-amd and AdrianAbeyta authored Sep 27, 2024
1 parent b79f9f4 commit a87da2b
Showing 9 changed files with 1,102 additions and 4 deletions.
6 changes: 6 additions & 0 deletions Dockerfile.rocm
@@ -243,6 +243,12 @@ RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \

RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

RUN git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \
    && cd rocmProfileData/rpd_tracer \
    && pip install -r requirements.txt && cd ../ \
    && make && make install \
    && cd hipMarker && python setup.py install

# Install vLLM (and gradlib)
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
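
As a quick sanity check that the rocmProfileData install above is usable, the tracer can be exercised directly from Python. The following is a minimal sketch based on the rpdTracerControl usage in the new benchmark script and the rocmProfileData README; the trace filename is a placeholder, and the exact setFilename/start/stop signatures should be treated as assumptions.

# Minimal sketch (assumptions noted above): record a trace around an arbitrary region.
from rpdTracerControl import rpdTracerControl

rpdTracerControl.setFilename(name="trace.rpd", append=False)  # placeholder output file
profiler = rpdTracerControl()
profiler.start()
# ... GPU work to be traced goes here ...
profiler.stop()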
330 changes: 330 additions & 0 deletions benchmarks/profiling/benchmark_latency.py
@@ -0,0 +1,330 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import json
import os
import time
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from rpdTracerControl import rpdTracerControl as rpd
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser


def main(args: argparse.Namespace):
print(args)

@contextmanager
def rpd_profiler_context():
llm.start_profile()
yield
llm.stop_profile()
rpd.top_totals()

@contextmanager
def torch_profiler_context(profile_dir: Optional[str] = None,
trace_file_name=None):
p = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)))
p.start()
try:
with torch.no_grad():
yield p
finally:
p.stop()
print(p.key_averages().table(sort_by="self_cuda_time_total",
row_limit=-1))

def get_profiling_context(profile_dir: Optional[str] = None,
trace_file_name=None):
if args.profile_torch:
return torch_profiler_context(profile_dir, trace_file_name)
elif args.profile_rpd:
return rpd_profiler_context()
else:
return nullcontext()

# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
speculative_draft_tensor_parallel_size=\
args.speculative_draft_tensor_parallel_size,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
max_model_len=args.max_model_len,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization,
load_format=args.load_format,
distributed_executor_backend=args.distributed_executor_backend,
otlp_traces_endpoint=args.otlp_traces_endpoint,
enable_prefix_caching=args.enable_prefix_caching,
num_scheduler_steps=args.num_scheduler_steps,
)

sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=args.output_len,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with get_profiling_context(profile_dir):
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
else:
start_time = time.perf_counter()
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
latency = end_time - start_time
return latency

print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None)

if args.profile_torch or args.profile_rpd:
profile_dir = args.profile_dir
if not profile_dir:
profile_dir = Path(".") / "vllm_benchmark_latency_result"
os.makedirs(profile_dir, exist_ok=True)
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return

# Benchmark.
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90, 99]
percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds')
for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds')

# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)


if __name__ == '__main__':
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
default=10,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
default=30,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--quantized-weights-path',
type=str,
default=None,
help='Path to the safetensor file containing the quantized weights '
'and scaling factors. This should generally be supplied, when '
'quantization is FP8.')
parser.add_argument(
'--profile-torch',
action='store_true',
help='profile the generation process of a single batch '
'with the PyTorch profiler')
parser.add_argument(
'--profile-rpd',
action='store_true',
help='profile the generation process of a single batch '
'with the RPD profiler')
parser.add_argument(
'--profile-dir',
type=str,
default=os.getenv('VLLM_RPD_PROFILER_DIR', default=None),
help=('path to save the profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument("--device",
type=str,
default="auto",
choices=DEVICE_OPTIONS,
help='device type for vLLM execution')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument("--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching")
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
help="If specified, use nsight to profile ray workers",
)
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the latency results in JSON format.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1. '
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--otlp-traces-endpoint',
type=str,
default=None,
help='Target URL to which OpenTelemetry traces will be sent.')
parser.add_argument(
"--num-scheduler-steps",
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
args = parser.parse_args()
main(args)
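
Condensed, the RPD path that this script adds boils down to the sketch below: build an LLM, wrap generate() in the RPD context manager, and print the kernel totals afterwards. The engine arguments are trimmed to a minimal assumed set and the model and prompt are placeholders; real runs go through the argument parser above (--profile-rpd, optionally with --profile-dir or VLLM_RPD_PROFILER_DIR).

# Minimal sketch of the RPD profiling flow used above (placeholder model and prompt).
from contextlib import contextmanager

from rpdTracerControl import rpdTracerControl as rpd
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

@contextmanager
def rpd_profiler_context():
    # start_profile()/stop_profile() are the vLLM hooks the benchmark uses;
    # top_totals() prints a summary of the recorded kernel totals.
    llm.start_profile()
    yield
    llm.stop_profile()
    rpd.top_totals()

with rpd_profiler_context():
    llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))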