Archived repository (read-only since Oct 11, 2024), forked from vllm-project/vllm.
### SUMMARY:
Initial layerwise profiler leveraging the [kineto](https://github.com/pytorch/kineto)-based [PyTorch profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html).

NOTE: we run in eager mode by default so that the stack-trace/event-tree contains `nn.Module` entries; otherwise, if cuda-graphs is used, all the kernels end up under the `CUDAGraphRunner` module (this is converted to an `nn.Module` in this PR so that the `_build_module_tree` code puts all those kernels under a `CUDAGraphRunner` node).

NOTE: vllm kernels like `vllm::reshape_and_cache_kernel` have no trace or shape information because they are not registered as a TorchOp (i.e. using `TORCH_LIBRARY`); they are instead exposed through a raw `PYBIND11_MODULE` module.

Example of how to use the visualization and table printing tools:

```
pip install -r requirements-dev.txt

python examples/offline_profile.py --model nm-testing/OpenHermes-2.5-Mistral-7B-pruned50 --batch-size 4 --prompt-len 512 --json openhermes7b-dense

### For Breakdown Graphs
# For module level breakdown
python neuralmagic/tools/profiler/visualize_trace.py --json-trace openhermes7b-dense.json --output profile_dense.pdf
# For kernel level breakdown
python neuralmagic/tools/profiler/visualize_trace.py --json-trace openhermes7b-dense.json --output profile_dense_kernels.pdf --level kernel

### For table printing
# Decode Summary Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase decode
# Prefill Summary Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase prefill
# Decode Model Layerwise Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase decode --table model
# Prefill Model Layerwise Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase prefill --table model
```

Example Output: [profile-example-output.txt](https://github.com/neuralmagic/nm-vllm/files/14596545/profile-example-output.txt)

### TEST PLAN:
- GHA

---------

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Commit 7ae99c2 (1 parent: b48fdbb), showing 8 changed files with 1,034 additions and 4 deletions.
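As background for the summary above: the layerwise view comes from running the model in eager mode under the stock kineto-based `torch.profiler` with stack and module recording enabled. The snippet below is a minimal illustrative sketch of those primitives only (the toy model and inputs are placeholders, and it assumes a CUDA device); it is not the `nm_profile` implementation added by this commit.

```python
# Illustrative only: the torch.profiler/kineto primitives a layerwise profiler
# builds on. The toy model and inputs are placeholders; requires a CUDA device.
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
inputs = torch.randn(8, 64, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,  # shapes are only captured for ops registered as TorchOps
             with_stack=True,     # stack traces let kernels be attributed to nn.Modules
             with_modules=True) as prof:
    model(inputs)  # eager-mode forward, so the event tree follows the module hierarchy

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```

Under CUDA graphs the same trace would collapse under the capture/replay step, which is why the script below defaults to eager execution (`enforce_eager=not context.allow_cuda_graphs`).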
examples/offline_profile.py (new file, +241 lines):
import argparse
import torch
import sys
import json
import inspect

from dataclasses import dataclass, asdict
from typing import Optional
from vllm import LLM, SamplingParams
from vllm.profiler import nm_profile

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
MAX_SEQ_LEN_DEFAULT = 1024


@dataclass
class ProfileContext:
    model: str
    model_revision: str
    sparsity: str
    quantization: str
    max_seq_len: int
    max_num_batched_tokens: int
    prompt_len: int
    batch_size: int
    tensor_parallel_size: int
    allow_cuda_graphs: bool

def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f" {key} = {value}")

    # Create sampling params
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8)

    # Create LLM
    llm = LLM(
        model=context.model,
        revision=context.model_revision,
        sparsity=context.sparsity,
        enforce_eager=not context.allow_cuda_graphs,
        tensor_parallel_size=context.tensor_parallel_size,
        gpu_memory_utilization=0.9,
        max_model_len=context.max_seq_len,
        quantization=context.quantization,
        max_num_batched_tokens=context.max_num_batched_tokens,
    )

    batch_size = context.batch_size
    prompt_len = context.prompt_len

    scheduler_config = llm.llm_engine.scheduler_config
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max_num_batched_tokens")
        sys.exit(-1)
    if batch_size >= max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)

    for i in range(batch_size):
        llm.llm_engine.add_request(
            request_id=f"seq{i}",
            prompt=None,
            prompt_token_ids=torch.randint(
                128,  # 128 to skip over special tokens
                llm.llm_engine.model_config.get_vocab_size() // 2,
                size=(prompt_len, )).tolist(),
            sampling_params=sampling_params)

    with nm_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    with nm_profile() as decode_prof:
        llm.llm_engine.step()

    prefill_results = prefill_prof.results
    decode_results = decode_prof.results

    print("=" * 80)
    print(f"= Prefill Model Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    prefill_results.print_model_table()
    print()
    print("=" * 80)
    print(f"= Decode Model Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    decode_results.print_model_table()
    print()
    print("=" * 80)
    print(f"= Prefill Summary Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    prefill_results.print_summary_table()
    print()
    print("=" * 80)
    print(f"= Decode Summary Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    decode_results.print_summary_table()

    if csv_output:
        csv_filename_base = csv_output.rstrip(".csv")
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv")
        prefill_results.export_summary_stats_table_csv(
            csv_filename_base + "_prefill_summary_table.csv")
        decode_results.export_model_stats_table_csv(
            csv_filename_base + "_decode_model_table.csv")
        decode_results.export_summary_stats_table_csv(
            csv_filename_base + "_decode_summary_table.csv")

    if json_output:
        cuda_devices = [
            torch.cuda.get_device_properties(dev_idx)
            for dev_idx in range(torch.cuda.device_count())
        ]

        json_dict = {
            "context": {
                "python_version": f"{sys.version}",
                "torch_version": f"{torch.__version__}",
                "torch_cuda_version": f"{torch.version.cuda}",
                "cuda_devices": f"{cuda_devices}",
                **asdict(context)
            },
            "prefill": prefill_results.convert_stats_to_dict(),
            "decode": decode_results.convert_stats_to_dict()
        }

        with open(json_output.rstrip(".json") + ".json", "w+") as f:
            json.dump(json_dict, f, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help='The name or path of a HuggingFace Transformers model.')
    parser.add_argument("--model-revision", type=str, default=None)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv files. This should be the "
        "root filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument(
        "--sparsity",
        "-s",
        type=str,
        choices=[None, 'sparse_w16a16', 'semi_structured_sparse_w16a16'],
        help="Method used to compress sparse weights. If "
        "None, we first check the `sparsity_config` attribute "
        "in the model config file. If that is None we assume "
        "the model weights are dense.")
    parser.add_argument(
        "--quantization",
        "-q",
        type=str,
        choices=['awq', 'gptq', 'squeezellm', 'marlin', None],
        default=None,
        help="The method used to quantize the model weights, "
        "options are \"marlin\", \"awq\", \"gptq\" and \"squeezellm\"")
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=MAX_SEQ_LEN_DEFAULT,
        help=f"Maximum length of a sequence (including prompt and output), "
        f"default={MAX_SEQ_LEN_DEFAULT}")
    parser.add_argument(
        "--max-num-batched-tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to be processed in a single iteration. "
        "Should be greater than batch-size * prompt-len so the prefill can "
        "run in a single iteration.")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")
    parser.add_argument("--tensor-parallel-size",
                        "-tp",
                        type=int,
                        default=1,
                        help="Number of GPUs to use, i.e. tensor parallelism, "
                        "default=1")
    parser.add_argument(
        "--allow-cuda-graphs",
        action='store_true',
        help="Enables cuda graphs to be used; this will remove a lot of the "
        "module-level info in the profiler results since almost everything "
        "runs in the graph, where we do not have access to an informative "
        "stack trace")

    args = parser.parse_args()

    context = ProfileContext(
        **{
            k: v
            for k, v in vars(args).items()
            if k in inspect.signature(ProfileContext).parameters
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)
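For orientation, here is a small illustrative sketch (not part of the commit) of loading the JSON written by run_profile above. The top-level "context"/"prefill"/"decode" keys mirror json_dict, and the per-phase "summary_stats"/"model_stats" entries are the ones consumed by print_table.py below; the filename is simply the one used in the example commands.

```python
# Sketch: inspect the trace written by
#   python examples/offline_profile.py ... --json openhermes7b-dense
# Top-level keys mirror json_dict above; the phase dicts carry the trees
# read back by print_table.py.
import json

with open("openhermes7b-dense.json") as f:
    trace = json.load(f)

print(trace["context"]["model"])      # ProfileContext fields plus python/torch/cuda versions
print(list(trace["prefill"].keys()))  # expected to include "summary_stats" and "model_stats"
print(list(trace["decode"].keys()))
```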
neuralmagic/tools/profiler/print_table.py (new file, +77 lines):
import argparse
import json

from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry
from vllm.profiler.utils import indent_string, TablePrinter
from typing import Dict


def flatten_entries(entry_cls, profile_dict: Dict):
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))

        for child in node["children"]:
            get_entries(
                child,
                curr_depth=curr_depth + 1,
            )

    for root in profile_dict:
        get_entries(root)

    return entries_and_depth

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by "
                        "examples/offline_profile.py")
    parser.add_argument("--phase",
                        type=str,
                        choices=["prefill", "decode"],
                        required=True,
                        help="The phase to print the table for.")
    parser.add_argument("--table",
                        type=str,
                        choices=["summary", "model"],
                        default="summary",
                        help="Which table to print, the summary table or the "
                        "layerwise model table")

    args = parser.parse_args()

    with open(args.json_trace, "r") as f:
        profile_data = json.load(f)

    if args.table == "summary":
        entries_and_depths = flatten_entries(
            SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
        column_widths = dict(name=80,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             invocations=15)
    elif args.table == "model":
        entries_and_depths = flatten_entries(
            ModelStatsEntry, profile_data[args.phase]["model_stats"])
        column_widths = dict(name=60,
                             cpu_time_us=12,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             trace=60)

    # indent entry names based on the depth
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ")
        entries.append(entry)

    TablePrinter(type(entries[0]), column_widths).print_table(entries)
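As a usage note, flatten_entries only assumes each node is a dict with an "entry" payload and a "children" list. The sketch below (a hypothetical Entry dataclass with made-up timings, run with flatten_entries from this file in scope) shows the depth annotation it produces:

```python
# Hypothetical mini-tree demonstrating flatten_entries' depth-first walk;
# Entry stands in for SummaryStatsEntry/ModelStatsEntry and the timings are made up.
from dataclasses import dataclass


@dataclass
class Entry:
    name: str
    cuda_time_us: float


tree = [{
    "entry": {"name": "decoder_layer", "cuda_time_us": 120.0},
    "children": [
        {"entry": {"name": "attention", "cuda_time_us": 70.0}, "children": []},
        {"entry": {"name": "mlp", "cuda_time_us": 50.0}, "children": []},
    ],
}]

for entry, depth in flatten_entries(Entry, tree):
    print("  " * depth + f"{entry.name}: {entry.cuda_time_us} us")
# decoder_layer: 120.0 us
#   attention: 70.0 us
#   mlp: 50.0 us
```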