diff --git a/examples/offline_profile.py b/examples/offline_profile.py new file mode 100644 index 0000000000000..e7f66876ff68b --- /dev/null +++ b/examples/offline_profile.py @@ -0,0 +1,241 @@ +import argparse +import torch +import sys +import json +import inspect + +from dataclasses import dataclass, asdict +from typing import Optional +from vllm import LLM, SamplingParams +from vllm.profiler import nm_profile + +BATCH_SIZE_DEFAULT = 1 +PROMPT_LEN_DEFAULT = 256 +MAX_SEQ_LEN_DEFAULT = 1024 + + +@dataclass +class ProfileContext: + model: str + model_revision: str + sparsity: str + quantization: str + max_seq_len: int + max_num_batched_tokens: int + prompt_len: int + batch_size: int + tensor_parallel_size: int + allow_cuda_graphs: bool + + +def run_profile(context: ProfileContext, csv_output: Optional[str], + json_output: Optional[str]): + print("Run profile with:") + for key, value in asdict(context).items(): + print(f" {key} = {value}") + + # Create sampling params + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8) + + # Create LLM + llm = LLM( + model=context.model, + revision=context.model_revision, + sparsity=context.sparsity, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_seq_len, + quantization=context.quantization, + max_num_batched_tokens=context.max_num_batched_tokens, + ) + + batch_size = context.batch_size + prompt_len = context.prompt_len + + scheduler_config = llm.llm_engine.scheduler_config + max_num_batched_tokens = scheduler_config.max_num_batched_tokens + max_num_seqs = scheduler_config.max_num_seqs + + if batch_size * prompt_len > max_num_batched_tokens: + print(f"ERROR: chosen batch_size * prompt_len " + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " + f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " + f"and therefore cannot be run in a single profile step, please " + f"choose a smaller batch size or prompt length, or increase " + f"--max_num_batched_tokens") + sys.exit(-1) + if batch_size >= max_num_seqs: + print( + f"ERROR: chosen batch_size ({batch_size}) is larger than " + f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " + f"single profile step, please choose a smaller batch size") + sys.exit(-1) + + for i in range(batch_size): + llm.llm_engine.add_request( + request_id=f"seq{i}", + prompt=None, + prompt_token_ids=torch.randint( + 128, # 128 to skip over special tokens + llm.llm_engine.model_config.get_vocab_size() // 2, + size=(prompt_len, )).tolist(), + sampling_params=sampling_params) + + with nm_profile() as prefill_prof: + llm.llm_engine.step() # First step is prefill + + with nm_profile() as decode_prof: + llm.llm_engine.step() + + prefill_results = prefill_prof.results + decode_results = decode_prof.results + + print("=" * 80) + print(f"= Prefill Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + prefill_results.print_model_table() + print() + print("=" * 80) + print(f"= Decode Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results.print_model_table() + print() + print("=" * 80) + print(f"= Prefill Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + prefill_results.print_summary_table() + print() + print("=" * 80) + print(f"= Decode Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + 
print("=" * 80) + print() + decode_results.print_summary_table() + + if csv_output: + csv_filename_base = csv_output.rstrip(".csv") + prefill_results.export_model_stats_table_csv( + csv_filename_base + "_prefill_model_table.csv") + prefill_results.export_summary_stats_table_csv( + csv_filename_base + "_prefill_summary_table.csv") + decode_results.export_model_stats_table_csv(\ + csv_filename_base + "_decode_model_table.csv") + decode_results.export_summary_stats_table_csv( + csv_filename_base + "_decode_summary_table.csv") + + if json_output: + cuda_devices = [ + torch.cuda.get_device_properties(dev_idx) + for dev_idx in range(torch.cuda.device_count()) + ] + + json_dict = { + "context": { + "python_version": f"{sys.version}", + "torch_version": f"{torch.__version__}", + "torch_cuda_version": f"{torch.version.cuda}", + "cuda_devices": f"{cuda_devices}", + **asdict(context) + }, + "prefill": prefill_results.convert_stats_to_dict(), + "decode": decode_results.convert_stats_to_dict() + } + + with open(json_output.rstrip(".json") + ".json", "w+") as f: + json.dump(json_dict, f, indent=2) + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model", + type=str, + required=True, + help='The name or path of a HuggingFace Transformers model.') + parser.add_argument("--model-revision", type=str, default=None) + parser.add_argument( + "--csv", + type=str, + default=None, + help="Export the results as multiple csv file. This should be the root " + "filename, will create _prefill_model_table.csv, " + "_prefill_summary_table.csv, " + "_decode_model_table.csv, and " + "_decode_summary_table.csv") + parser.add_argument( + "--json", + type=str, + default=None, + help="Export the results as a json file. This should be the filename") + parser.add_argument( + "--sparsity", + "-s", + type=str, + choices=[None, 'sparse_w16a16', 'semi_structured_sparse_w16a16'], + help="Method used to compress sparse weights. If " + "None, we first check the `sparsity_config` attribute" + "in the model config file. If that is None we assume" + "the model weights are dense") + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=['awq', 'gptq', 'squeezellm', 'marlin', None], + default=None, + help="The method used to quantize the model weights, " + "options are \"marlin\", \"awq\", \"gptq\" and \"squeezellm\"") + parser.add_argument( + "--max-seq-len", + type=int, + default=MAX_SEQ_LEN_DEFAULT, + help=f"Maximum length of a sequence (including prompt and output), " + f"default={MAX_SEQ_LEN_DEFAULT}") + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=None, + help="Maximum number of tokens to be processed in a single iteration. " + " Should be greater than batch-size * prompt-len so the prefill can " + " run in a single iteration.") + parser.add_argument( + "--prompt-len", + type=int, + default=PROMPT_LEN_DEFAULT, + help=f"Length of the random prompt to use when profiling, all batched " + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") + parser.add_argument("--batch-size", + type=int, + default=BATCH_SIZE_DEFAULT, + help=f"Number of requests to run as a single batch, " + f"default={BATCH_SIZE_DEFAULT}") + parser.add_argument("--tensor-parallel-size", + "-tp", + type=int, + default=1, + help="Number of GPUs to use i.e. 
tensor parallelism, " + "default=1") + parser.add_argument( + "--allow-cuda-graphs", + action='store_true', + help="Enables cuda graphs to be used, well remove a lot of the module " + "level info in the profiler results since almost everything runs in " + "the graph where we do not have access to an informative stack trace") + + args = parser.parse_args() + + context = ProfileContext( + **{ + k: v + for k, v in vars(args).items() + if k in inspect.signature(ProfileContext).parameters + }) + run_profile(context, csv_output=args.csv, json_output=args.json) diff --git a/neuralmagic/tools/profiler/print_table.py b/neuralmagic/tools/profiler/print_table.py new file mode 100644 index 0000000000000..728a5e3860895 --- /dev/null +++ b/neuralmagic/tools/profiler/print_table.py @@ -0,0 +1,77 @@ +import argparse +import json + +from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry +from vllm.profiler.utils import indent_string, TablePrinter +from typing import Dict + + +def flatten_entries(entry_cls, profile_dict: Dict): + entries_and_depth = [] + + def get_entries(node, curr_depth=0): + entries_and_depth.append((entry_cls(**node["entry"]), curr_depth)) + + for child in node["children"]: + get_entries( + child, + curr_depth=curr_depth + 1, + ) + + for root in profile_dict: + get_entries(root) + + return entries_and_depth + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--json-trace", + type=str, + required=True, + help="json trace file output by " + "examples/offline_profile.py") + parser.add_argument("--phase", + type=str, + choices=["prefill", "decode"], + required=True, + help="The phase to print the table for.") + parser.add_argument("--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the " + "layerwise model table") + + args = parser.parse_args() + + with open(args.json_trace, "r") as f: + profile_data = json.load(f) + + if args.table == "summary": + entries_and_depths = flatten_entries( + SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) + column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + elif args.table == "model": + entries_and_depths = flatten_entries( + ModelStatsEntry, profile_data[args.phase]["model_stats"]) + column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + + # ident entry names based on the depth + entries = [] + for entry, depth in entries_and_depths: + entry.name = indent_string( + entry.name, + indent=depth, + indent_style=lambda indent: "|" + "-" * indent + " ") + entries.append(entry) + + TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py new file mode 100644 index 0000000000000..9d0f7f3285436 --- /dev/null +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -0,0 +1,209 @@ +import argparse +import json +import pandas as pd +import matplotlib.pyplot as plt + + +def trim_string_back(string: str, width: int): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." 
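+            # trailing "..." signals that the name was cut down to fit the legend width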
+ return string + + +def abbreviate_known_names(name: str): + abbreviations = { + "MergedColumnParallelLinear": "MCPLinear", + "QKVParallelLinear": "QKVPLinear", + "RowParallelLinear": "RPLinear", + "weight=": "w=", + "bfloat16": "bf16", + "float16": "f16", + } + for key, value in abbreviations.items(): + name = name.replace(key, value) + return name + + +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_profile.py") + parser.add_argument( + "--output", + type=str, + required=False, + help="Output figure file, should be a image file such as pdf, " + "jpeg, png, etc., defaults to .pdf") + parser.add_argument("--level", + type=str, + default="module", + choices=["module", "kernel"]) + parser.add_argument("--top_k", + type=int, + default=9, + help="Only graph the top `top_k` entries by time.") + parser.add_argument("--ignore_sampler", + action='store_true', + help="Ignore everything under the \"Sampler\" module") + + args = parser.parse_args() + + ignore_sampler = args.ignore_sampler + make_names_unique = False + top_k = args.top_k + + if args.level == "module": + depth = -2 + make_names_unique = True + elif args.level == "kernel": + depth = -1 + else: + raise Exception(f"Unexpected level value ({args.level})") + + if ignore_sampler: + print("WARNING: ignoring Sampler time so the pct_cuda_time will not " + "add up to 100%") + + json_trace = args.json_trace + output = args.output if args.output else json_trace.strip(".json") + ".pdf" + + with open(json_trace, "r") as f: + profile_data = json.load(f) + + prefill_entries_and_traces = [] + decode_entries_and_traces = [] + + def largest_dist_from_leaf(node, depth=0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + def get_entries_at_depth(depth, + entries_and_traces, + node, + curr_depth=0, + trace=()): + if ignore_sampler and node["entry"]["name"] == "Sampler": + return + + if (depth >= 0 and depth == curr_depth) or ( + depth < 0 + and largest_dist_from_leaf(node) == (abs(depth) - 1)): + entries_and_traces.append((node["entry"], trace)) + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + for root in profile_data["prefill"]["summary_stats"]: + get_entries_at_depth(depth, prefill_entries_and_traces, root) + for root in profile_data["decode"]["summary_stats"]: + get_entries_at_depth(depth, decode_entries_and_traces, root) + + def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [ + (entry, trace) for entry, trace in entries_and_traces + if entry["name"] == name + ] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), 
None) + + if first_trace_difference is None: + # can't create a unique name, leave them names as the + # are they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + trace[:first_trace_difference + 1]) + + if make_names_unique: + attempt_to_make_names_unique(prefill_entries_and_traces) + attempt_to_make_names_unique(decode_entries_and_traces) + + def keep_only_top_entries(df, metric, top_k=9): + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + + prefill_df = pd.DataFrame( + [entry for entry, _ in prefill_entries_and_traces]) + prefill_df["phase"] = "prefill" + decode_df = pd.DataFrame([entry for entry, _ in decode_entries_and_traces]) + decode_df["phase"] = "decode" + + if top_k: + keep_only_top_entries(prefill_df, "cuda_time_us", top_k) + keep_only_top_entries(decode_df, "cuda_time_us", top_k) + + df = pd.concat([prefill_df, decode_df]) + df["cuda_time_ms"] = df["cuda_time_us"] / 1000 + + fig, axes = plt.subplots(2, figsize=(5, 8), sharex=True) + + def plot_metric(metric: str, ax, add_totals=False): + pivoted_df = df.pivot_table(index="phase", + columns="name", + values=metric, + aggfunc="sum") + pivoted_df.plot.bar(stacked=True, legend=False, ax=ax) + ax.set_ylabel(metric) + + if add_totals: + ax.bar_label(ax.containers[-1]) + + plot_metric("cuda_time_ms", ax=axes[0], add_totals=True) + plot_metric("pct_cuda_time", ax=axes[1]) + + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(0.93, 0.5)) + shorten_plot_legend_strings(legend, 50) + + context = profile_data["context"] + plt.suptitle( + f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + context['sparsity'] if context['sparsity'] else ''}" + ) + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) diff --git a/requirements-dev.txt b/requirements-dev.txt index 00fa132b14c21..53b0af9a92f4c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -28,3 +28,8 @@ peft # Benchmarking aiohttp + +# Profiling +matplotlib +pandas +pyarrow diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py new file mode 100644 index 0000000000000..93ec4a800e600 --- /dev/null +++ b/vllm/profiler/__init__.py @@ -0,0 +1,5 @@ +from .nm_profile import nm_profile + +__all__ = [ + "nm_profile", +] diff --git a/vllm/profiler/nm_profile.py b/vllm/profiler/nm_profile.py new file mode 100644 index 0000000000000..1912a300c02ba --- /dev/null +++ b/vllm/profiler/nm_profile.py @@ -0,0 +1,346 @@ +import pandas as pd +import copy + +from collections import defaultdict +from dataclasses import dataclass, field, asdict +from vllm.profiler.utils import (indent_string, TablePrinter, event_has_module, + event_is_torch_op, event_module_repr, + event_torch_op_stack_trace) +from typing import Dict, List, Union, Optional, Tuple, Callable, TypeAlias +from torch.profiler import profile, ProfilerActivity +from torch.autograd.profiler import FunctionEvent +from torch._C._autograd import _ProfilerResult, _KinetoEvent, DeviceType +from torch._C._profiler import _EventType, _ProfilerEvent, _ExperimentalConfig + + +@dataclass +class _ModuleTreeNode: + event: _ProfilerEvent + parent: Optional['_ModuleTreeNode'] = None + children: List['_ModuleTreeNode'] = field(default_factory=list) + trace: str = "" + + @property + def 
is_leaf(self): + return (self.event.children is None or len(self.event.children) == 0) + + @property + def is_torch_op(self): + return event_is_torch_op(self.event) + + @property + def is_cuda(self): + return (self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA) + + +@dataclass +class SummaryStatsEntry: + name: str + cuda_time_us: float + pct_cuda_time: float + invocations: int + + +@dataclass +class ModelStatsEntry: + name: str + cpu_time_us: float + cuda_time_us: float + pct_cuda_time: float + trace: str + + +StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] + + +@dataclass +class _StatsTreeNode: + entry: StatsEntry + children: List[StatsEntry] + parent: Optional[StatsEntry] + + +@dataclass +class NMProfileResults(profile): + _kineto_results: _ProfilerResult + _kineto_event_correlation_map: Dict[int, + List[_KinetoEvent]] = field(init=False) + _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) + _module_tree: List[_ModuleTreeNode] = field(init=False) + _model_stats_tree: List[_StatsTreeNode] = field(init=False) + _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + + def __post_init__(self): + self._build_correlation_map() + self._build_module_tree() + self._build_stats_trees() + + def print_model_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + if column_widths: + _column_widths.update(**column_widths) + filtered_model_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._model_stats_tree) + if row.cuda_time_us > 0 or row.cpu_time_us > 0 + ] + TablePrinter(ModelStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_model_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def print_summary_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + if column_widths: + _column_widths.update(**column_widths) + filtered_summary_table = [(depth, row) + for depth, row in self._flatten_stats_tree( + self._summary_stats_tree) + if row.cuda_time_us > 0] + TablePrinter(SummaryStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_summary_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def export_model_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._model_stats_tree) + ]) + df.to_csv(filename) + + def export_summary_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ]) + df.to_csv(filename) + + def convert_stats_to_dict(self) -> str: + return { + "summary_stats": + self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": + self._convert_stats_tree_to_dict(self._model_stats_tree) + } + + @staticmethod + def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + StatsEntry]], + indent_style: Union[Callable[[int], + str], + str] = " "): + indented_rows = [] + for depth, row in depths_rows: + if row.cuda_time_us == 0: + continue + indented_row = copy.deepcopy(row) + indented_row.name = indent_string(indented_row.name, depth, + indent_style) + indented_rows.append(indented_row) + return indented_rows + + def _build_correlation_map(self): + self._kineto_event_correlation_map = defaultdict(list) + for 
event in self._kineto_results.events(): + self._kineto_event_correlation_map[event.correlation_id()].append( + event) + + def _build_module_tree(self): + self._module_tree = [] + event_tree = self._kineto_results.experimental_event_tree() + + def _df_traversal(event: _ProfilerEvent, + curr_node: Optional[_ModuleTreeNode] = None): + if event_has_module(event): + node = _ModuleTreeNode(event=event, parent=curr_node) + if curr_node: + curr_node.children.append(node) + else: + self._module_tree.append(node) + curr_node = node + + is_leaf = (event.children is None or len(event.children) == 0) + if is_leaf and curr_node: + node = _ModuleTreeNode( + event=event, + parent=curr_node, + trace=event_torch_op_stack_trace( + event, until=lambda x: event_has_module(x))) + curr_node.children.append(node) + curr_node = node + + for child in event.children: + _df_traversal(child, curr_node) + + for root in event_tree: + _df_traversal(root) + + def _get_kineto_gpu_event(self, node: _ModuleTreeNode): + if node.event.tag != _EventType.Kineto: + return None + correlated_kineto_events = self._kineto_event_correlation_map.get( + node.event.correlation_id, []) + iterator = (x for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA) + return next(iterator, None) + + def _cumulative_cuda_time(self, node: _ModuleTreeNode): + + def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): + if node.is_leaf and (gpu_kineto_event := + self._get_kineto_gpu_event(node)): + return gpu_kineto_event.duration_us() + else: + cumulative_cuda_time = 0 + for child in node.children: + cumulative_cuda_time += _cumulative_cuda_time_recursive( + child) + return cumulative_cuda_time + + return _cumulative_cuda_time_recursive(node) + + def _total_cuda_time(self): + return sum( + [self._cumulative_cuda_time(root) for root in self._module_tree]) + + def _build_stats_trees(self): + summary_dict: Dict[str, self.StatsTreeNode] = {} + total_cuda_time = self._total_cuda_time() + + def pct_cuda_time(cuda_time_us): + return (cuda_time_us / total_cuda_time) * 100 + + def build_summary_stats_tree_df( + node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None, + summary_trace: Tuple[str] = ()): + + if event_has_module(node.event): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_us() + else: + return None + + summary_trace = summary_trace + (name, ) + if summary_trace in summary_dict: + entry = summary_dict[summary_trace].entry + entry.cuda_time_us += cuda_time_us + entry.invocations += 1 + entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) + else: + new_node = _StatsTreeNode(entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1), + children=[], + parent=parent) + if parent: + parent.children.append(new_node) + summary_dict[summary_trace] = new_node + + for child in node.children: + build_summary_stats_tree_df(child, summary_dict[summary_trace], + summary_trace) + + return summary_dict[summary_trace] + + self._summary_stats_tree = [] + for root in self._module_tree: + self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + + def build_model_stats_tree_df(node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None): + if event_has_module(node.event, ): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + cpu_time_us 
= node.event.duration_time_ns / 1000 + trace = "" + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_us() + cpu_time_us = 0 + trace = node.trace + else: + return None + + new_node = _StatsTreeNode(entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace), + parent=parent, + children=[]) + if parent: + parent.children.append(new_node) + + for child in node.children: + build_model_stats_tree_df(child, new_node) + + return new_node + + self._model_stats_tree = [] + for root in self._module_tree: + self._model_stats_tree.append(build_model_stats_tree_df(root)) + + def _flatten_stats_tree( + self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: + entries: List[Tuple[int, StatsEntry]] = [] + + def df_traversal(node: _StatsTreeNode, depth=0): + entries.append((depth, node.entry)) + for child in node.children: + df_traversal(child, depth=depth + 1) + + for root in tree: + df_traversal(root) + + return entries + + def _convert_stats_tree_to_dict(self, + tree: List[_StatsTreeNode]) -> List[Dict]: + root_dicts: List[Dict] = [] + + def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + curr_json_list.append({ + "entry": asdict(node.entry), + "children": [] + }) + for child in node.children: + df_traversal(child, curr_json_list[-1]["children"]) + + for root in tree: + df_traversal(root, root_dicts) + + return root_dicts + + +class nm_profile(profile): + + def __init__(self): + super().__init__( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + with_modules=True, + experimental_config=_ExperimentalConfig(verbose=True)) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.results = NMProfileResults(self.profiler.kineto_results) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py new file mode 100644 index 0000000000000..f8ead593d178b --- /dev/null +++ b/vllm/profiler/utils.py @@ -0,0 +1,146 @@ +import dataclasses + +from typing import Callable, Dict, Type, List, Union + +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata + +# +# String / Print Manipulation +# + + +def trim_string_front(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[offset:] + if len(string) > 3: + string = "..." + string[3:] + return string + + +def trim_string_back(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." 
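+            # trailing "..." marks values that were truncated to fit the column width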
+ return string + + +class TablePrinter: + + def __init__(self, row_cls: Type[dataclasses.dataclass], + column_widths: Dict[str, int]): + self.row_cls = row_cls + self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] + self.column_widths = column_widths + assert set(self.column_widths.keys()) == set(self.fieldnames) + + def print_table(self, rows: List[dataclasses.dataclass]): + self._print_header() + self._print_line() + for row in rows: + self._print_row(row) + + def _print_header(self): + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + print(trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n") + + def _print_row(self, row): + assert isinstance(row, self.row_cls) + + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + val = getattr(row, f) + + val_str = "" + if isinstance(val, str): + val_str = trim_string_back(val, col_width).ljust(col_width) + elif type(val) in [float, int]: + val_str = f"{float(val):>.2f}".rjust(col_width) + else: + val_str = f"{val}".rjust(col_width) + print(val_str, end=" | " if not last else "\n") + + def _print_line(self): + total_col_width = 0 + for column_width in self.column_widths.values(): + total_col_width += column_width + print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) + + +def indent_string(string: str, + indent: int, + indent_style: Union[Callable[[int], str], str] = " ") -> str: + if indent: + if isinstance(indent_style, str): + return indent_style * indent + string + else: + return indent_style(indent) + string + else: + return string + + +# +# _ProfilerEvent utils +# + + +def event_has_module(event: _ProfilerEvent) -> bool: + event_type, typed_event = event.typed + if event_type == _EventType.PyCall: + return typed_event.module is not None + return False + + +def event_is_torch_op(event: _ProfilerEvent) -> bool: + return event.tag == _EventType.TorchOp + + +def event_arg_repr(arg) -> str: + if arg is None or type(arg) in [float, int, bool, str]: + return f"{arg}" + elif isinstance(arg, list): + return f"[{', '.join([event_arg_repr(x) for x in arg])}]" + elif isinstance(arg, tuple): + return f"({', '.join([event_arg_repr(x) for x in arg])})" + else: + assert isinstance(arg, + _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ', '.join([str(x) for x in arg.sizes]) + return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" + + +def event_torch_op_repr(event: _ProfilerEvent) -> str: + assert event.tag == _EventType.TorchOp + args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + return f"{event.name}({args_str})".replace("aten::", "") + + +def event_module_repr(event: _ProfilerEvent) -> str: + assert event_has_module(event) + module = event.typed[1].module + if module.parameters and len(module.parameters) > 0: + args_str = ', '.join( + [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + return f"{module.cls_name}({args_str})" + else: + return module.cls_name + + +def event_torch_op_stack_trace(curr_event: _ProfilerEvent, + until: Callable[[_ProfilerEvent], bool]) -> str: + trace = "" + curr_event = curr_event.parent + while curr_event and not until(curr_event): + if event_is_torch_op(curr_event): + if len(trace) > 0: + trace += " <- " + trace += event_torch_op_repr(curr_event) + curr_event = curr_event.parent + + return trace diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 
1ef783da6d08e..5a89de43c61ba 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,3 +1,5 @@ +# This file has been modified by Neural Magic + import contextlib import dataclasses import time @@ -758,9 +760,11 @@ def vocab_size(self) -> int: return self.model_config.get_vocab_size() -class CUDAGraphRunner: +class CUDAGraphRunner(nn.Module): def __init__(self, model: nn.Module): + super().__init__() + self.model = model self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} @@ -839,9 +843,6 @@ def forward( # Return the output tensor. return self.output_buffers["hidden_states"] - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - @contextlib.contextmanager def _maybe_cupy_nccl():
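
For quick reference, the profiler added above can also be driven directly from Python rather than through examples/offline_profile.py. The sketch below is illustrative only: the model name, prompt token ids, and output filename are assumptions, and it simply mirrors the prefill/decode stepping pattern used in the example script.

    from vllm import LLM, SamplingParams
    from vllm.profiler import nm_profile

    # Assumed setup: any small model works; "facebook/opt-125m" is illustrative.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    llm.llm_engine.add_request(
        request_id="seq0",
        prompt=None,
        prompt_token_ids=[128] * 256,  # fixed dummy prompt, 256 tokens
        sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8))

    with nm_profile() as prefill_prof:
        llm.llm_engine.step()  # first step runs the prefill
    with nm_profile() as decode_prof:
        llm.llm_engine.step()  # next step runs a single decode

    prefill_prof.results.print_summary_table()
    decode_prof.results.export_summary_stats_table_csv("decode_summary_table.csv")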