This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 7ae99c2

Initial Layerwise Profiler (#124)
### SUMMARY:
Initial layerwise profiler leveraging the [kineto](https://github.com/pytorch/kineto)-based [PyTorch profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html).

NOTE: we run in eager mode by default so that the stack-trace/event-tree contains `nn.Module`s; otherwise, if cuda-graphs are used, all the kernels end up under the `CUDAGraphRunner` module (this is converted to an `nn.Module` in this PR so that the `_build_module_tree` code puts all those kernels under a `CUDAGraphRunner` node).

NOTE: vllm kernels like `vllm::reshape_and_cache_kernel` have no trace or shape information because they are not registered as a TorchOp (i.e. using `TORCH_LIBRARY`); they are instead exposed through a raw `PYBIND11_MODULE` (a registration sketch follows directly below this commit message).

Example of how to use the visualization:
```
pip install -r requirements-dev.txt
python examples/offline_profile.py --model nm-testing/OpenHermes-2.5-Mistral-7B-pruned50 --batch-size 4 --prompt-len 512 --json openhermes7b-dense

### For Breakdown Graphs
# For module level breakdown
python neuralmagic/tools/profiler/visualize_trace.py --json-trace openhermes7b-dense.json --output profile_dense.pdf
# For kernel level breakdown
python neuralmagic/tools/profiler/visualize_trace.py --json-trace openhermes7b-dense.json --output profile_dense_kernels.pdf --level kernel

### For table printing
# Decode Summary Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase decode
# Prefill Summary Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase prefill
# Decode Model Layerwise Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase decode --table model
# Prefill Model Layerwise Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase prefill --table model
```

Example output: [profile-example-output.txt](https://github.com/neuralmagic/nm-vllm/files/14596545/profile-example-output.txt)

### TEST PLAN:
- GHA

---------

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
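The note about `TORCH_LIBRARY` above is why kernels bound only through pybind11 carry no shape information in the trace. As a rough illustration (not code from this PR; the `vllm_demo::scale_cache` op and its implementation are made up), registering an op through `torch.library`, the Python front end of `TORCH_LIBRARY`, routes calls through the dispatcher, so the profiler records them as TorchOp events with input shapes:

```python
import torch

# Hypothetical namespace/op for illustration only; real vllm kernels are bound
# via a raw PYBIND11_MODULE, which bypasses the dispatcher and is therefore
# invisible to the profiler's TorchOp/shape recording.
lib = torch.library.Library("vllm_demo", "DEF")
lib.define("scale_cache(Tensor x, float scale) -> Tensor")


def scale_cache_impl(x, scale):
    # Stand-in for a real kernel.
    return x * scale


lib.impl("scale_cache", scale_cache_impl, "CompositeExplicitAutograd")

with torch.profiler.profile(record_shapes=True) as prof:
    torch.ops.vllm_demo.scale_cache(torch.randn(4, 8), 2.0)

# The op shows up as a TorchOp event, with recorded input shapes.
print(prof.key_averages(group_by_input_shape=True).table(row_limit=10))
```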
1 parent b48fdbb commit 7ae99c2

File tree

8 files changed (+1034, -4 lines)


examples/offline_profile.py

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
```python
import argparse
import torch
import sys
import json
import inspect

from dataclasses import dataclass, asdict
from typing import Optional
from vllm import LLM, SamplingParams
from vllm.profiler import nm_profile

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
MAX_SEQ_LEN_DEFAULT = 1024


@dataclass
class ProfileContext:
    model: str
    model_revision: str
    sparsity: str
    quantization: str
    max_seq_len: int
    max_num_batched_tokens: int
    prompt_len: int
    batch_size: int
    tensor_parallel_size: int
    allow_cuda_graphs: bool


def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f"  {key} = {value}")

    # Create sampling params
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8)

    # Create LLM
    llm = LLM(
        model=context.model,
        revision=context.model_revision,
        sparsity=context.sparsity,
        enforce_eager=not context.allow_cuda_graphs,
        tensor_parallel_size=context.tensor_parallel_size,
        gpu_memory_utilization=0.9,
        max_model_len=context.max_seq_len,
        quantization=context.quantization,
        max_num_batched_tokens=context.max_num_batched_tokens,
    )

    batch_size = context.batch_size
    prompt_len = context.prompt_len

    scheduler_config = llm.llm_engine.scheduler_config
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max_num_batched_tokens")
        sys.exit(-1)
    if batch_size >= max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size")
        sys.exit(-1)

    for i in range(batch_size):
        llm.llm_engine.add_request(
            request_id=f"seq{i}",
            prompt=None,
            prompt_token_ids=torch.randint(
                128,  # 128 to skip over special tokens
                llm.llm_engine.model_config.get_vocab_size() // 2,
                size=(prompt_len, )).tolist(),
            sampling_params=sampling_params)

    with nm_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    with nm_profile() as decode_prof:
        llm.llm_engine.step()

    prefill_results = prefill_prof.results
    decode_results = decode_prof.results

    print("=" * 80)
    print(f"= Prefill Model Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    prefill_results.print_model_table()
    print()
    print("=" * 80)
    print(f"= Decode Model Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    decode_results.print_model_table()
    print()
    print("=" * 80)
    print(f"= Prefill Summary Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    prefill_results.print_summary_table()
    print()
    print("=" * 80)
    print(f"= Decode Summary Table "
          f"(prompt_len={prompt_len}, batch_size={batch_size})")
    print("=" * 80)
    print()
    decode_results.print_summary_table()

    if csv_output:
        csv_filename_base = csv_output.rstrip(".csv")
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv")
        prefill_results.export_summary_stats_table_csv(
            csv_filename_base + "_prefill_summary_table.csv")
        decode_results.export_model_stats_table_csv(
            csv_filename_base + "_decode_model_table.csv")
        decode_results.export_summary_stats_table_csv(
            csv_filename_base + "_decode_summary_table.csv")

    if json_output:
        cuda_devices = [
            torch.cuda.get_device_properties(dev_idx)
            for dev_idx in range(torch.cuda.device_count())
        ]

        json_dict = {
            "context": {
                "python_version": f"{sys.version}",
                "torch_version": f"{torch.__version__}",
                "torch_cuda_version": f"{torch.version.cuda}",
                "cuda_devices": f"{cuda_devices}",
                **asdict(context)
            },
            "prefill": prefill_results.convert_stats_to_dict(),
            "decode": decode_results.convert_stats_to_dict()
        }

        with open(json_output.rstrip(".json") + ".json", "w+") as f:
            json.dump(json_dict, f, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help='The name or path of a HuggingFace Transformers model.')
    parser.add_argument("--model-revision", type=str, default=None)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv files. This should be the "
        "root filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument(
        "--sparsity",
        "-s",
        type=str,
        choices=[None, 'sparse_w16a16', 'semi_structured_sparse_w16a16'],
        help="Method used to compress sparse weights. If "
        "None, we first check the `sparsity_config` attribute "
        "in the model config file. If that is None we assume "
        "the model weights are dense")
    parser.add_argument(
        "--quantization",
        "-q",
        type=str,
        choices=['awq', 'gptq', 'squeezellm', 'marlin', None],
        default=None,
        help="The method used to quantize the model weights, "
        "options are \"marlin\", \"awq\", \"gptq\" and \"squeezellm\"")
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=MAX_SEQ_LEN_DEFAULT,
        help=f"Maximum length of a sequence (including prompt and output), "
        f"default={MAX_SEQ_LEN_DEFAULT}")
    parser.add_argument(
        "--max-num-batched-tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to be processed in a single iteration. "
        "Should be greater than batch-size * prompt-len so the prefill can "
        "run in a single iteration.")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")
    parser.add_argument("--tensor-parallel-size",
                        "-tp",
                        type=int,
                        default=1,
                        help="Number of GPUs to use, i.e. tensor parallelism, "
                        "default=1")
    parser.add_argument(
        "--allow-cuda-graphs",
        action='store_true',
        help="Enables cuda graphs to be used; this will remove a lot of the "
        "module-level info in the profiler results since almost everything "
        "runs in the graph, where we do not have access to an informative "
        "stack trace")

    args = parser.parse_args()

    context = ProfileContext(
        **{
            k: v
            for k, v in vars(args).items()
            if k in inspect.signature(ProfileContext).parameters
        })
    run_profile(context, csv_output=args.csv, json_output=args.json)
```

neuralmagic/tools/profiler/print_table.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
```python
import argparse
import json

from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry
from vllm.profiler.utils import indent_string, TablePrinter
from typing import Dict


def flatten_entries(entry_cls, profile_dict: Dict):
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))

        for child in node["children"]:
            get_entries(
                child,
                curr_depth=curr_depth + 1,
            )

    for root in profile_dict:
        get_entries(root)

    return entries_and_depth


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by "
                        "examples/offline_profile.py")
    parser.add_argument("--phase",
                        type=str,
                        choices=["prefill", "decode"],
                        required=True,
                        help="The phase to print the table for.")
    parser.add_argument("--table",
                        type=str,
                        choices=["summary", "model"],
                        default="summary",
                        help="Which table to print, the summary table or the "
                        "layerwise model table")

    args = parser.parse_args()

    with open(args.json_trace, "r") as f:
        profile_data = json.load(f)

    if args.table == "summary":
        entries_and_depths = flatten_entries(
            SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
        column_widths = dict(name=80,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             invocations=15)
    elif args.table == "model":
        entries_and_depths = flatten_entries(
            ModelStatsEntry, profile_data[args.phase]["model_stats"])
        column_widths = dict(name=60,
                             cpu_time_us=12,
                             cuda_time_us=12,
                             pct_cuda_time=12,
                             trace=60)

    # indent entry names based on the depth
    entries = []
    for entry, depth in entries_and_depths:
        entry.name = indent_string(
            entry.name,
            indent=depth,
            indent_style=lambda indent: "|" + "-" * indent + " ")
        entries.append(entry)

    TablePrinter(type(entries[0]), column_widths).print_table(entries)
```
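
For orientation, here is a small self-contained sketch (not code from this PR) of the nested layout that `flatten_entries` above walks: each phase section of the JSON written by `examples/offline_profile.py` is a list of root nodes, where every node carries an `"entry"` payload and a `"children"` list. The entry field names and module names below are illustrative stand-ins; the real entry classes are `SummaryStatsEntry`/`ModelStatsEntry` from `vllm.profiler.nm_profile`.

```python
from dataclasses import dataclass


@dataclass
class DemoEntry:
    # Illustrative fields only; the real entry classes carry more columns
    # (e.g. pct_cuda_time, invocations, trace).
    name: str
    cuda_time_us: float


def flatten_entries(entry_cls, profile_dict):
    # Same depth-first traversal as print_table.py above.
    entries_and_depth = []

    def get_entries(node, curr_depth=0):
        entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))
        for child in node["children"]:
            get_entries(child, curr_depth=curr_depth + 1)

    for root in profile_dict:
        get_entries(root)
    return entries_and_depth


# Hypothetical shape of profile_data["decode"]["summary_stats"]:
# a list of root nodes, each with an "entry" payload and nested "children".
toy_summary_stats = [{
    "entry": {"name": "ToyModel", "cuda_time_us": 1000.0},
    "children": [{
        "entry": {"name": "ToyDecoderLayer", "cuda_time_us": 900.0},
        "children": [],
    }],
}]

for entry, depth in flatten_entries(DemoEntry, toy_summary_stats):
    print("|" + "-" * depth + " " + entry.name, entry.cuda_time_us)
```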
