From 89a0e3c4c43682d3a95db2f2c4f0bd8e58f737e5 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Wed, 26 Jun 2024 11:32:42 -0400
Subject: [PATCH] [Bugfix] Update profile example to new add request interface
 + fix profiler not picking up kernels within cudagraphs (#332)

This PR updates `examples/offline_profile.py` to use the new `add_request`
interface. In addition, it fixes the profiler not picking up kernels run
from within a CUDA graph (i.e. when profiling with `--allow-cuda-graphs`).
There were two main issues:

1) Changes from the initial profiler PR
   (https://github.com/neuralmagic/nm-vllm/pull/124) were wiped out by
   https://github.com/neuralmagic/nm-vllm/pull/224, namely the changes in
   `model_runner.py` converting `CUDAGraphRunner` to an `nn.Module`, which
   allows the profiler to pick it up.

2) Many kernels within the graph had the same correlation ID, so we were
   always picking the first of potentially many kernels to display. Matching
   on the kernel name in addition to the correlation ID seems to resolve
   this, but is potentially fragile (a standalone sketch of this
   disambiguation follows the patch).

Before the PR:

```
================================================================================
= Decode Summary Table (prompt_len=1, batch_size=1)
================================================================================
name | cuda_time_us | pct_cuda_... | invocations
================================================================================================================================
LogitsProcessor | 350.00 | 57.76 | 1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex Device) | 18.00 | 2.97 | 9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... | 4.00 | 0.66 | 1.00
|- void at::native::unrolled_elementwise_kernel Pageable) | 10.00 | 1.65 | 5.00
|- void (anonymous namespace)::elementwise_kernel_with_index(unsigned int*, u... | 1.00 | 0.17 | 1.00
|- void at::native::mbtopk::radixFindKthValues(unsig... | 8.00 | 1.32 | 4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts(unsigned ... | 1.00 | 0.17 | 1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel(at::cuda::deta... | 4.00 | 0.66 | 1.00
```

After the PR:

```
name | cuda_time_us | pct_cuda_... | invocations
================================================================================================================================
CUDAGraphRunner | 4238.00 | 84.41 | 1.00
|- Memcpy DtoD (Device -> Device) | 5.00 | 0.10 | 5.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<... | 2.00 | 0.04 | 2.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex(c10::BFloat16*, c10::BFloat16 co... | 4.00 | 0.08 | 1.00
|- void vllm::scaled_fp8_quant_kernel(c10::Float8_e4m3fn*, c10... | 256.00 | 5.10 | 128.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x64x128_warpgroupsize1x... | 1440.00 | 28.68 | 96.00
|- void vllm::rotary_embedding_kernel(long const*, c10::... | 96.00 | 1.91 | 32.00
|- void vllm::reshape_and_cache_flash_kernel(c10::BFloat16 con... | 64.00 | 1.27 | 32.00
|- void flash_fwd_splitkv_kernel(0)))&&vllm::_typeConvert::exists, voi... | 128.00 | 2.55 | 64.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x128x128_warpgroupsize1... | 1664.00 | 33.14 | 32.00
|- void vllm::act_and_mul_kernel Device) | 18.00 | 0.36 | 9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... | 4.00 | 0.08 | 1.00
|- at::native::(anonymous namespace)::fill_reverse_indices_kernel(long*, int,... | 2.00 | 0.04 | 1.00
|- void at_cuda_detail::cub::DeviceRadixSortUpsweepKernel Device) | 1.00 | 0.02 | 1.00
|- void at::native::unrolled_elementwise_kernel... | 1.00 | 0.02 | 1.00
|- void (anonymous namespace)::elementwise_kernel_with_index Pageable) | 10.00 | 0.20 | 5.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... | 3.00 | 0.06 | 1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp(unsigned int*, u... | 1.00 | 0.02 | 1.00
|- void at::native::mbtopk::radixFindKthValues(unsig... | 8.00 | 0.16 | 4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts(unsigned ... | 1.00 | 0.02 | 1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel(at::cuda::deta... | 4.00 | 0.08 | 1.00
```

---------

Co-authored-by: Lucas Wilkinson
---
 examples/offline_profile.py | 17 ++++++++++-------
 vllm/profiler/nm_profile.py |  3 ++-
 vllm/worker/model_runner.py | 11 +++++++----
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/examples/offline_profile.py b/examples/offline_profile.py
index 054c438036eb0..1c95b5bed451c 100644
--- a/examples/offline_profile.py
+++ b/examples/offline_profile.py
@@ -36,7 +36,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         print(f" {key} = {value}")
 
     # Create sampling params
-    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8)
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=8,
+                                     ignore_eos=True)
 
     # Create LLM
     llm = LLM(
@@ -74,14 +77,14 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         sys.exit(-1)
 
     for i in range(batch_size):
+        prompt_token_ids = torch.randint(
+            llm.llm_engine.model_config.get_vocab_size(),
+            size=(prompt_len, )).tolist()
+
         llm.llm_engine.add_request(
             request_id=f"seq{i}",
-            prompt=None,
-            prompt_token_ids=torch.randint(
-                128,  # 128 to skip over special tokens
-                llm.llm_engine.model_config.get_vocab_size() // 2,
-                size=(prompt_len, )).tolist(),
-            sampling_params=sampling_params)
+            inputs={'prompt_token_ids': prompt_token_ids},
+            params=sampling_params)
 
     with nm_profile() as prefill_prof:
         llm.llm_engine.step()  # First step is prefill
diff --git a/vllm/profiler/nm_profile.py b/vllm/profiler/nm_profile.py
index bc357978bec54..0cc98cd17f40e 100644
--- a/vllm/profiler/nm_profile.py
+++ b/vllm/profiler/nm_profile.py
@@ -191,7 +191,8 @@ def _get_kineto_gpu_event(self, node: _ModuleTreeNode):
         correlated_kineto_events = self._kineto_event_correlation_map.get(
             node.event.correlation_id, [])
         iterator = (x for x in correlated_kineto_events
-                    if x.device_type() == DeviceType.CUDA)
+                    if x.device_type() == DeviceType.CUDA
+                    and x.name() == node.event.name)
         return next(iterator, None)
 
     def _cumulative_cuda_time(self, node: _ModuleTreeNode):
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index a321eafce1a2f..ff8df9b3ea56e 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1,3 +1,5 @@
+# This file has been modified by Neural Magic
+
 import gc
 import time
 import warnings
@@ -986,9 +988,13 @@ def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
 
 
-class CUDAGraphRunner:
+# NOTE: this is nn.Module so the profiler can properly capture/group
+# kernel calls made within the graph
+class CUDAGraphRunner(nn.Module):
 
     def __init__(self, model: nn.Module):
+        super().__init__()
+
         self.model = model
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
@@ -1084,9 +1090,6 @@ def forward(
         # Return the output tensor.
         return self.output_buffers["hidden_states"]
 
-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
-
 
 def _get_graph_batch_size(batch_size: int) -> int:
     """Returns the padded batch size given actual batch size.
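
For reference, here is a minimal standalone sketch of the disambiguation fix
referenced in issue 2 above. This is toy code, not vllm code: `KinetoEvent` is
a simplified stand-in for the kineto event objects the profiler receives, the
kernel names and correlation ID are illustrative, and the two lookup functions
mirror the old and new filtering logic in `_get_kineto_gpu_event`.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass
class KinetoEvent:
    """Simplified stand-in for a kineto GPU event from the torch profiler."""
    name: str
    correlation_id: int
    cuda_time_us: float


# Under CUDA graph replay, many kernels can end up sharing one
# correlation ID (here, 7).
events = [
    KinetoEvent("scaled_fp8_quant_kernel", 7, 256.0),
    KinetoEvent("rotary_embedding_kernel", 7, 96.0),
    KinetoEvent("reshape_and_cache_flash_kernel", 7, 64.0),
]

correlation_map = defaultdict(list)
for ev in events:
    correlation_map[ev.correlation_id].append(ev)


def lookup_old(correlation_id: int) -> Optional[KinetoEvent]:
    # Old behavior: the first event with a matching correlation ID wins,
    # so every kernel in the graph resolved to the same (wrong) event.
    return next(iter(correlation_map.get(correlation_id, [])), None)


def lookup_new(correlation_id: int, name: str) -> Optional[KinetoEvent]:
    # New behavior: additionally require the event name to match.
    candidates = correlation_map.get(correlation_id, [])
    return next((ev for ev in candidates if ev.name == name), None)


print(lookup_old(7).name)                             # scaled_fp8_quant_kernel
print(lookup_new(7, "rotary_embedding_kernel").name)  # rotary_embedding_kernel
```

As the commit message notes, this is still potentially fragile: two kernels
sharing both the same name and the same correlation ID within one graph
replay would remain indistinguishable.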