[Bugfix] Update profile example to new add request interface + fix profiler not picking up kernels within cudagraphs (#332)

This PR updates `examples/offline_profile.py` to use the new
`add_request` interface. In addition, this PR fixes the profiler not
picking up kernels run from within a cudagraph (i.e. when profiling
with `--allow-cuda-graphs`). There are two main issues:

1) Changes from the initial profiler PR
(#124) were wiped out by
#224, namely the change in
`model_runner.py` converting `CUDAGraphRunner` to an `nn.Module`, which allows
the profiler to pick it up.
2) Many kernels within the graph had the same correlation id, so we were
always picking the first of potentially many kernels to display. Matching on
the kernel name in addition to the correlation id seems to resolve this issue,
but is potentially fragile.


Before the PR:
```
================================================================================
= Decode Summary Table (prompt_len=1, batch_size=1)
================================================================================

name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
LogitsProcessor                                                                  |       350.00 |        57.76 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.00 |         0.50 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       347.00 |        57.26 |            1.00
Sampler                                                                          |       256.00 |        42.24 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         2.97 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.66 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.83 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        54.00 |         8.91 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        43.00 |         7.10 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         4.00 |         0.66 |            2.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        13.00 |         2.15 |            4.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        29.00 |         4.79 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         1.65 |            5.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         1.00 |         0.17 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.50 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        15.00 |         2.48 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.17 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.17 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         5.94 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         1.32 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         1.00 |         0.17 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.33 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.66 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.66 |            1.00
```

After the PR:
```
name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
CUDAGraphRunner                                                                  |      4238.00 |        84.41 |            1.00
|- Memcpy DtoD (Device -> Device)                                                |         5.00 |         0.10 |            5.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<... |         2.00 |         0.04 |            2.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         2.00 |         0.04 |            1.00
|- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFloat16*, c10::BFloat16 co... |         4.00 |         0.08 |            1.00
|- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*, c10... |       256.00 |         5.10 |          128.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x64x128_warpgroupsize1x... |      1440.00 |        28.68 |           96.00
|- void vllm::rotary_embedding_kernel<c10::BFloat16, true>(long const*, c10::... |        96.00 |         1.91 |           32.00
|- void vllm::reshape_and_cache_flash_kernel<c10::BFloat16>(c10::BFloat16 con... |        64.00 |         1.27 |           32.00
|- void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<128, 64, 128, 4, fal... |       160.00 |         3.19 |           32.00
|- void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<128, 64, 128... |       160.00 |         3.19 |           32.00
|- memcpy32_post                                                                 |        33.00 |         0.66 |           33.00
|- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, voi... |       128.00 |         2.55 |           64.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x128x128_warpgroupsize1... |      1664.00 |        33.14 |           32.00
|- void vllm::act_and_mul_kernel<c10::BFloat16, &(c10::BFloat16 vllm::silu_ke... |       224.00 |         4.46 |           32.00
LogitsProcessor                                                                  |       351.00 |         6.99 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.00 |         0.06 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       348.00 |         6.93 |            1.00
Sampler                                                                          |       432.00 |         8.60 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         0.36 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.08 |            1.00
|- at::native::(anonymous namespace)::fill_reverse_indices_kernel(long*, int,... |         2.00 |         0.04 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortUpsweepKernel<at_cuda_detail::cub... |        12.00 |         0.24 |            3.00
|- void at_cuda_detail::cub::RadixSortScanBinsKernel<at_cuda_detail::cub::Dev... |         7.00 |         0.14 |            3.00
|- void at_cuda_detail::cub::DeviceRadixSortDownsweepKernel<at_cuda_detail::c... |        46.00 |         0.92 |            3.00
|- Memcpy DtoD (Device -> Device)                                                |         1.00 |         0.02 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.10 |            3.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         6.00 |         0.12 |            2.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         6.00 |         0.12 |            2.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous n... |         2.00 |         0.04 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<8, c10::BFloat... |        39.00 |         0.78 |            1.00
|- void at_cuda_detail::cub::DeviceScanInitKernel<at_cuda_detail::cub::ScanTi... |         1.00 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceScanKernel<at_cuda_detail::cub::DeviceScan... |         6.00 |         0.12 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::FillFunctor<bool>... |         1.00 |         0.02 |            1.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         3.00 |         0.06 |            2.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         5.00 |         0.10 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.10 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        86.00 |         1.71 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        43.00 |         0.86 |            1.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        14.00 |         0.28 |            4.00
|- void at::native::(anonymous namespace)::distribution_elementwise_grid_stri... |         3.00 |         0.06 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::BinaryFuncto... |         2.00 |         0.04 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        28.00 |         0.56 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         0.20 |            5.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.06 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        15.00 |         0.30 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         0.72 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         0.16 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         1.00 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.04 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.08 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.08 |            1.00
```

---------

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson authored Jun 26, 2024
1 parent b4ad97a commit 89a0e3c
Showing 3 changed files with 19 additions and 12 deletions.
17 changes: 10 additions & 7 deletions — examples/offline_profile.py
```diff
@@ -36,7 +36,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         print(f"  {key} = {value}")
 
     # Create sampling params
-    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8)
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=8,
+                                     ignore_eos=True)
 
     # Create LLM
     llm = LLM(
@@ -74,14 +77,14 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
         sys.exit(-1)
 
     for i in range(batch_size):
+        prompt_token_ids = torch.randint(
+            llm.llm_engine.model_config.get_vocab_size(),
+            size=(prompt_len, )).tolist()
+
         llm.llm_engine.add_request(
             request_id=f"seq{i}",
-            prompt=None,
-            prompt_token_ids=torch.randint(
-                128,  # 128 to skip over special tokens
-                llm.llm_engine.model_config.get_vocab_size() // 2,
-                size=(prompt_len, )).tolist(),
-            sampling_params=sampling_params)
+            inputs={'prompt_token_ids': prompt_token_ids},
+            params=sampling_params)
 
     with nm_profile() as prefill_prof:
         llm.llm_engine.step()  # First step is prefill
```
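
For context, a minimal standalone sketch of the new `add_request` interface as used in the diff above; the model name and prompt length are illustrative assumptions, not part of this commit:

```python
# Hedged sketch of the new add_request interface used in the diff above.
# The model name and prompt length are illustrative assumptions.
import torch

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

sampling_params = SamplingParams(temperature=0.8,
                                 top_p=0.95,
                                 max_tokens=8,
                                 ignore_eos=True)

# Random prompt token ids drawn from the model's full vocabulary.
prompt_token_ids = torch.randint(
    llm.llm_engine.model_config.get_vocab_size(), size=(16, )).tolist()

# New interface: token ids go in `inputs`; the old `sampling_params=`
# keyword is now `params=`, and `prompt=None` is no longer passed.
llm.llm_engine.add_request(request_id="seq0",
                           inputs={'prompt_token_ids': prompt_token_ids},
                           params=sampling_params)

llm.llm_engine.step()  # the first step runs the prefill
```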
3 changes: 2 additions & 1 deletion — vllm/profiler/nm_profile.py
```diff
@@ -191,7 +191,8 @@ def _get_kineto_gpu_event(self, node: _ModuleTreeNode):
         correlated_kineto_events = self._kineto_event_correlation_map.get(
             node.event.correlation_id, [])
         iterator = (x for x in correlated_kineto_events
-                    if x.device_type() == DeviceType.CUDA)
+                    if x.device_type() == DeviceType.CUDA
+                    and x.name() == node.event.name)
         return next(iterator, None)
 
     def _cumulative_cuda_time(self, node: _ModuleTreeNode):
```
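
In effect, the fix treats the pair (correlation id, kernel name) as the lookup key rather than the correlation id alone. A hedged sketch of the idea, where `events` and the module-tree `node` are hypothetical stand-ins for the profiler's internals:

```python
# Hedged sketch of the disambiguation above: within a replayed CUDA graph,
# many kineto events can share a single correlation id, so the kernel name
# is matched as well. `events` and `node` are hypothetical stand-ins for
# the profiler's kineto event list and module-tree node.
from collections import defaultdict

from torch._C._autograd import DeviceType


def build_correlation_map(events):
    # One correlation id may map to several GPU events inside a cudagraph.
    correlation_map = defaultdict(list)
    for event in events:
        correlation_map[event.correlation_id()].append(event)
    return correlation_map


def get_gpu_event(correlation_map, node):
    candidates = correlation_map.get(node.event.correlation_id, [])
    # Correlation id alone is ambiguous inside a graph; require a name match
    # too. Still potentially fragile: repeated launches of the same kernel
    # under one correlation id collapse to the first match.
    return next((e for e in candidates
                 if e.device_type() == DeviceType.CUDA
                 and e.name() == node.event.name), None)
```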
11 changes: 7 additions & 4 deletions — vllm/worker/model_runner.py
```diff
@@ -1,3 +1,5 @@
+# This file has been modified by Neural Magic
+
 import gc
 import time
 import warnings
@@ -986,9 +988,13 @@ def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
 
 
-class CUDAGraphRunner:
+# NOTE: this is nn.Module so the profiler can properly capture/group
+# kernel calls made within the graph
+class CUDAGraphRunner(nn.Module):
 
     def __init__(self, model: nn.Module):
+        super().__init__()
+
         self.model = model
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
@@ -1084,9 +1090,6 @@ def forward(
         # Return the output tensor.
         return self.output_buffers["hidden_states"]
 
-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
-
 
 def _get_graph_batch_size(batch_size: int) -> int:
     """Returns the padded batch size given actual batch size.
```
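
Why the `nn.Module` change matters: the profiler groups kernels under modules it observes through `nn.Module.__call__`, so a plain class with a hand-written `__call__` never appears as a node in the profile tree. A minimal sketch under that assumption (the class and tensors are illustrative, not the vLLM code):

```python
# Minimal sketch: a plain class with a hand-written __call__ is invisible to
# the profiler's module tree, while an nn.Module subclass is captured and
# groups the kernels launched inside it.
import torch
import torch.nn as nn


class GraphRunner(nn.Module):  # previously a plain class in model_runner.py
    def __init__(self, model: nn.Module):
        super().__init__()  # required once we subclass nn.Module
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Module.__call__ wraps forward() with the profiler hooks, so the
        # explicit __call__ shim from the old code can be deleted.
        return self.model(x)


runner = GraphRunner(nn.Linear(8, 8))
with torch.profiler.profile(with_modules=True) as prof:
    runner(torch.randn(1, 8))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```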
