|
| 1 | +import argparse |
| 2 | +import torch |
| 3 | +import sys |
| 4 | +import json |
| 5 | +import inspect |
| 6 | + |
| 7 | +from dataclasses import dataclass, asdict |
| 8 | +from typing import Optional |
| 9 | +from vllm import LLM, SamplingParams |
| 10 | +from vllm.profiler import nm_profile |
| 11 | + |
| 12 | +BATCH_SIZE_DEFAULT = 1 |
| 13 | +PROMPT_LEN_DEFAULT = 256 |
| 14 | +MAX_SEQ_LEN_DEFAULT = 1024 |
| 15 | + |
| 16 | + |
| 17 | +@dataclass |
| 18 | +class ProfileContext: |
| 19 | + model: str |
| 20 | + model_revision: str |
| 21 | + sparsity: str |
| 22 | + quantization: str |
| 23 | + max_seq_len: int |
| 24 | + max_num_batched_tokens: int |
| 25 | + prompt_len: int |
| 26 | + batch_size: int |
| 27 | + tensor_parallel_size: int |
| 28 | + allow_cuda_graphs: bool |
| 29 | + |
| 30 | + |
| 31 | +def run_profile(context: ProfileContext, csv_output: Optional[str], |
| 32 | + json_output: Optional[str]): |
| 33 | + print("Run profile with:") |
| 34 | + for key, value in asdict(context).items(): |
| 35 | + print(f" {key} = {value}") |
| 36 | + |
| 37 | + # Create sampling params |
| 38 | + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8) |
| 39 | + |
| 40 | + # Create LLM |
| 41 | + llm = LLM( |
| 42 | + model=context.model, |
| 43 | + revision=context.model_revision, |
| 44 | + sparsity=context.sparsity, |
| 45 | + enforce_eager=not context.allow_cuda_graphs, |
| 46 | + tensor_parallel_size=context.tensor_parallel_size, |
| 47 | + gpu_memory_utilization=0.9, |
| 48 | + max_model_len=context.max_seq_len, |
| 49 | + quantization=context.quantization, |
| 50 | + max_num_batched_tokens=context.max_num_batched_tokens, |
| 51 | + ) |
| 52 | + |
| 53 | + batch_size = context.batch_size |
| 54 | + prompt_len = context.prompt_len |
| 55 | + |
| 56 | + scheduler_config = llm.llm_engine.scheduler_config |
| 57 | + max_num_batched_tokens = scheduler_config.max_num_batched_tokens |
| 58 | + max_num_seqs = scheduler_config.max_num_seqs |
| 59 | + |
| 60 | + if batch_size * prompt_len > max_num_batched_tokens: |
| 61 | + print(f"ERROR: chosen batch_size * prompt_len " |
| 62 | + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " |
| 63 | + f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " |
| 64 | + f"and therefore cannot be run in a single profile step, please " |
| 65 | + f"choose a smaller batch size or prompt length, or increase " |
| 66 | + f"--max_num_batched_tokens") |
| 67 | + sys.exit(-1) |
| 68 | + if batch_size >= max_num_seqs: |
| 69 | + print( |
| 70 | + f"ERROR: chosen batch_size ({batch_size}) is larger than " |
| 71 | + f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " |
| 72 | + f"single profile step, please choose a smaller batch size") |
| 73 | + sys.exit(-1) |
| 74 | + |
| 75 | + for i in range(batch_size): |
| 76 | + llm.llm_engine.add_request( |
| 77 | + request_id=f"seq{i}", |
| 78 | + prompt=None, |
| 79 | + prompt_token_ids=torch.randint( |
| 80 | + 128, # 128 to skip over special tokens |
| 81 | + llm.llm_engine.model_config.get_vocab_size() // 2, |
| 82 | + size=(prompt_len, )).tolist(), |
| 83 | + sampling_params=sampling_params) |
| 84 | + |
| 85 | + with nm_profile() as prefill_prof: |
| 86 | + llm.llm_engine.step() # First step is prefill |
| 87 | + |
| 88 | + with nm_profile() as decode_prof: |
| 89 | + llm.llm_engine.step() |
| 90 | + |
| 91 | + prefill_results = prefill_prof.results |
| 92 | + decode_results = decode_prof.results |
| 93 | + |
| 94 | + print("=" * 80) |
| 95 | + print(f"= Prefill Model Table " |
| 96 | + f"(prompt_len={prompt_len}, batch_size={batch_size})") |
| 97 | + print("=" * 80) |
| 98 | + print() |
| 99 | + prefill_results.print_model_table() |
| 100 | + print() |
| 101 | + print("=" * 80) |
| 102 | + print(f"= Decode Model Table " |
| 103 | + f"(prompt_len={prompt_len}, batch_size={batch_size})") |
| 104 | + print("=" * 80) |
| 105 | + print() |
| 106 | + decode_results.print_model_table() |
| 107 | + print() |
| 108 | + print("=" * 80) |
| 109 | + print(f"= Prefill Summary Table " |
| 110 | + f"(prompt_len={prompt_len}, batch_size={batch_size})") |
| 111 | + print("=" * 80) |
| 112 | + print() |
| 113 | + prefill_results.print_summary_table() |
| 114 | + print() |
| 115 | + print("=" * 80) |
| 116 | + print(f"= Decode Summary Table " |
| 117 | + f"(prompt_len={prompt_len}, batch_size={batch_size})") |
| 118 | + print("=" * 80) |
| 119 | + print() |
| 120 | + decode_results.print_summary_table() |
| 121 | + |
| 122 | + if csv_output: |
| 123 | + csv_filename_base = csv_output.rstrip(".csv") |
| 124 | + prefill_results.export_model_stats_table_csv( |
| 125 | + csv_filename_base + "_prefill_model_table.csv") |
| 126 | + prefill_results.export_summary_stats_table_csv( |
| 127 | + csv_filename_base + "_prefill_summary_table.csv") |
| 128 | + decode_results.export_model_stats_table_csv(\ |
| 129 | + csv_filename_base + "_decode_model_table.csv") |
| 130 | + decode_results.export_summary_stats_table_csv( |
| 131 | + csv_filename_base + "_decode_summary_table.csv") |
| 132 | + |
| 133 | + if json_output: |
| 134 | + cuda_devices = [ |
| 135 | + torch.cuda.get_device_properties(dev_idx) |
| 136 | + for dev_idx in range(torch.cuda.device_count()) |
| 137 | + ] |
| 138 | + |
| 139 | + json_dict = { |
| 140 | + "context": { |
| 141 | + "python_version": f"{sys.version}", |
| 142 | + "torch_version": f"{torch.__version__}", |
| 143 | + "torch_cuda_version": f"{torch.version.cuda}", |
| 144 | + "cuda_devices": f"{cuda_devices}", |
| 145 | + **asdict(context) |
| 146 | + }, |
| 147 | + "prefill": prefill_results.convert_stats_to_dict(), |
| 148 | + "decode": decode_results.convert_stats_to_dict() |
| 149 | + } |
| 150 | + |
| 151 | + with open(json_output.rstrip(".json") + ".json", "w+") as f: |
| 152 | + json.dump(json_dict, f, indent=2) |
| 153 | + pass |
| 154 | + |
| 155 | + |
| 156 | +if __name__ == "__main__": |
| 157 | + parser = argparse.ArgumentParser() |
| 158 | + |
| 159 | + parser.add_argument( |
| 160 | + "--model", |
| 161 | + type=str, |
| 162 | + required=True, |
| 163 | + help='The name or path of a HuggingFace Transformers model.') |
| 164 | + parser.add_argument("--model-revision", type=str, default=None) |
| 165 | + parser.add_argument( |
| 166 | + "--csv", |
| 167 | + type=str, |
| 168 | + default=None, |
| 169 | + help="Export the results as multiple csv file. This should be the root " |
| 170 | + "filename, will create <filename>_prefill_model_table.csv, " |
| 171 | + "<filename>_prefill_summary_table.csv, " |
| 172 | + "<filename>_decode_model_table.csv, and " |
| 173 | + "<filename>_decode_summary_table.csv") |
| 174 | + parser.add_argument( |
| 175 | + "--json", |
| 176 | + type=str, |
| 177 | + default=None, |
| 178 | + help="Export the results as a json file. This should be the filename") |
| 179 | + parser.add_argument( |
| 180 | + "--sparsity", |
| 181 | + "-s", |
| 182 | + type=str, |
| 183 | + choices=[None, 'sparse_w16a16', 'semi_structured_sparse_w16a16'], |
| 184 | + help="Method used to compress sparse weights. If " |
| 185 | + "None, we first check the `sparsity_config` attribute" |
| 186 | + "in the model config file. If that is None we assume" |
| 187 | + "the model weights are dense") |
| 188 | + parser.add_argument( |
| 189 | + "--quantization", |
| 190 | + "-q", |
| 191 | + type=str, |
| 192 | + choices=['awq', 'gptq', 'squeezellm', 'marlin', None], |
| 193 | + default=None, |
| 194 | + help="The method used to quantize the model weights, " |
| 195 | + "options are \"marlin\", \"awq\", \"gptq\" and \"squeezellm\"") |
| 196 | + parser.add_argument( |
| 197 | + "--max-seq-len", |
| 198 | + type=int, |
| 199 | + default=MAX_SEQ_LEN_DEFAULT, |
| 200 | + help=f"Maximum length of a sequence (including prompt and output), " |
| 201 | + f"default={MAX_SEQ_LEN_DEFAULT}") |
| 202 | + parser.add_argument( |
| 203 | + "--max-num-batched-tokens", |
| 204 | + type=int, |
| 205 | + default=None, |
| 206 | + help="Maximum number of tokens to be processed in a single iteration. " |
| 207 | + " Should be greater than batch-size * prompt-len so the prefill can " |
| 208 | + " run in a single iteration.") |
| 209 | + parser.add_argument( |
| 210 | + "--prompt-len", |
| 211 | + type=int, |
| 212 | + default=PROMPT_LEN_DEFAULT, |
| 213 | + help=f"Length of the random prompt to use when profiling, all batched " |
| 214 | + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") |
| 215 | + parser.add_argument("--batch-size", |
| 216 | + type=int, |
| 217 | + default=BATCH_SIZE_DEFAULT, |
| 218 | + help=f"Number of requests to run as a single batch, " |
| 219 | + f"default={BATCH_SIZE_DEFAULT}") |
| 220 | + parser.add_argument("--tensor-parallel-size", |
| 221 | + "-tp", |
| 222 | + type=int, |
| 223 | + default=1, |
| 224 | + help="Number of GPUs to use i.e. tensor parallelism, " |
| 225 | + "default=1") |
| 226 | + parser.add_argument( |
| 227 | + "--allow-cuda-graphs", |
| 228 | + action='store_true', |
| 229 | + help="Enables cuda graphs to be used, well remove a lot of the module " |
| 230 | + "level info in the profiler results since almost everything runs in " |
| 231 | + "the graph where we do not have access to an informative stack trace") |
| 232 | + |
| 233 | + args = parser.parse_args() |
| 234 | + |
| 235 | + context = ProfileContext( |
| 236 | + **{ |
| 237 | + k: v |
| 238 | + for k, v in vars(args).items() |
| 239 | + if k in inspect.signature(ProfileContext).parameters |
| 240 | + }) |
| 241 | + run_profile(context, csv_output=args.csv, json_output=args.json) |
0 commit comments