This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Bugfix] Update profile example to new add request interface + fix profiler not picking up kernels within cudagraphs #332

Merged
merged 6 commits into main from lwilkinson/profiler-bugfixes
Jun 26, 2024

Conversation

LucasWilkinson
Collaborator

This PR updates examples/offline_profile.py to use the new add_requests interface. In addition, it fixes issues with the profiler not picking up kernels run from within a CUDA graph (i.e. when profiling with --allow-cuda-graphs). There are two main issues:

  1. Changes from the initial profiler PR (Initial Layerwise Profiler #124) were wiped out by Upstream sync 2024 05 05 #224, namely the changes in model_runner.py that converted CUDAGraphRunner into an nn.Module so the profiler can pick it up.
  2. Many kernels within the graph shared the same correlation ID, so we were always picking the first of potentially many kernels to display; keying on the kernel name in addition to the correlation ID seems to resolve this (see the sketch below), but is potentially fragile.
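
As a rough illustration of the idea in item 2 (a sketch only; kernel_events, correlation_id, and name are assumed attribute names here, not the profiler's actual data structures):

from collections import defaultdict

def group_kernel_events(kernel_events):
    # Keying on the correlation ID alone collapses every graph-replayed kernel
    # that shares an ID onto the first one seen; adding the kernel name keeps
    # them distinct, at the cost of being fragile if two distinct launches
    # really do share both the ID and the name.
    grouped = defaultdict(list)
    for evt in kernel_events:
        grouped[(evt.correlation_id, evt.name)].append(evt)
    return grouped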

Before the PR:

================================================================================
= Decode Summary Table (prompt_len=1, batch_size=1)
================================================================================

name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
LogitsProcessor                                                                  |       350.00 |        57.76 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.00 |         0.50 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       347.00 |        57.26 |            1.00
Sampler                                                                          |       256.00 |        42.24 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         2.97 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.66 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.83 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        54.00 |         8.91 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        43.00 |         7.10 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         4.00 |         0.66 |            2.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        13.00 |         2.15 |            4.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        29.00 |         4.79 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         1.65 |            5.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         1.00 |         0.17 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.50 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        15.00 |         2.48 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.17 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.17 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         5.94 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         1.32 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         1.00 |         0.17 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.33 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.66 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.66 |            1.00

After the PR:

name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
CUDAGraphRunner                                                                  |      4238.00 |        84.41 |            1.00
|- Memcpy DtoD (Device -> Device)                                                |         5.00 |         0.10 |            5.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<... |         2.00 |         0.04 |            2.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         2.00 |         0.04 |            1.00
|- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFloat16*, c10::BFloat16 co... |         4.00 |         0.08 |            1.00
|- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*, c10... |       256.00 |         5.10 |          128.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x64x128_warpgroupsize1x... |      1440.00 |        28.68 |           96.00
|- void vllm::rotary_embedding_kernel<c10::BFloat16, true>(long const*, c10::... |        96.00 |         1.91 |           32.00
|- void vllm::reshape_and_cache_flash_kernel<c10::BFloat16>(c10::BFloat16 con... |        64.00 |         1.27 |           32.00
|- void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<128, 64, 128, 4, fal... |       160.00 |         3.19 |           32.00
|- void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<128, 64, 128... |       160.00 |         3.19 |           32.00
|- memcpy32_post                                                                 |        33.00 |         0.66 |           33.00
|- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, voi... |       128.00 |         2.55 |           64.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x128x128_warpgroupsize1... |      1664.00 |        33.14 |           32.00
|- void vllm::act_and_mul_kernel<c10::BFloat16, &(c10::BFloat16 vllm::silu_ke... |       224.00 |         4.46 |           32.00
LogitsProcessor                                                                  |       351.00 |         6.99 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.00 |         0.06 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       348.00 |         6.93 |            1.00
Sampler                                                                          |       432.00 |         8.60 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         0.36 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.08 |            1.00
|- at::native::(anonymous namespace)::fill_reverse_indices_kernel(long*, int,... |         2.00 |         0.04 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortUpsweepKernel<at_cuda_detail::cub... |        12.00 |         0.24 |            3.00
|- void at_cuda_detail::cub::RadixSortScanBinsKernel<at_cuda_detail::cub::Dev... |         7.00 |         0.14 |            3.00
|- void at_cuda_detail::cub::DeviceRadixSortDownsweepKernel<at_cuda_detail::c... |        46.00 |         0.92 |            3.00
|- Memcpy DtoD (Device -> Device)                                                |         1.00 |         0.02 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.10 |            3.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         6.00 |         0.12 |            2.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         6.00 |         0.12 |            2.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous n... |         2.00 |         0.04 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<8, c10::BFloat... |        39.00 |         0.78 |            1.00
|- void at_cuda_detail::cub::DeviceScanInitKernel<at_cuda_detail::cub::ScanTi... |         1.00 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceScanKernel<at_cuda_detail::cub::DeviceScan... |         6.00 |         0.12 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::FillFunctor<bool>... |         1.00 |         0.02 |            1.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         3.00 |         0.06 |            2.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         5.00 |         0.10 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.10 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        86.00 |         1.71 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        43.00 |         0.86 |            1.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        14.00 |         0.28 |            4.00
|- void at::native::(anonymous namespace)::distribution_elementwise_grid_stri... |         3.00 |         0.06 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::BinaryFuncto... |         2.00 |         0.04 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        28.00 |         0.56 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         0.20 |            5.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.06 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        15.00 |         0.30 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         0.72 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         0.16 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         1.00 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.04 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.08 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.08 |            1.00

@varun-sundar-rabindranath

Hey @LucasWilkinson
It looks like the CUDAGraphRunner decode doesn't report all of the GEMM kernels? PFA an eager-mode trace:

================================================================================
= Decode Summary Table (prompt_len=16, batch_size=1)
================================================================================

name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
LlamaForCausalLM                                                                 |      6027.00 |        90.92 |            1.00
|- LlamaModel                                                                    |      6027.00 |        90.92 |            1.00
|-- VocabParallelEmbedding(weight=bfloat16[128256, 4096])                        |         3.00 |         0.05 |            1.00
|--- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloa... |         3.00 |         0.05 |            1.00
|-- LlamaDecoderLayer                                                            |      6021.00 |        90.83 |           32.00
|--- RMSNorm(weight=bfloat16[4096])                                              |       180.00 |         2.72 |           64.00
|---- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFloat16*, c10::BFloat16... |         4.00 |         0.06 |            1.00
|---- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, ... |       176.00 |         2.66 |           63.00
|--- LlamaAttention                                                              |      1616.00 |        24.38 |           32.00
|---- QKVParallelLinear(weight=bfloat16[6144, 4096])                             |       682.00 |        10.29 |           32.00
|----- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsiz... |       682.00 |        10.29 |           32.00
|---- RotaryEmbedding                                                            |        99.00 |         1.49 |           32.00
|----- void vllm::rotary_embedding_kernel<c10::BFloat16, true>(long const*, c... |        99.00 |         1.49 |           32.00
|---- Attention                                                                  |       288.00 |         4.34 |           32.00
|----- void vllm::reshape_and_cache_flash_kernel<c10::BFloat16>(c10::BFloat16... |        64.00 |         0.97 |           32.00
|----- void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<128, 64, 128, 4,... |       192.00 |         2.90 |           32.00
|----- Memcpy DtoD (Device -> Device)                                            |        32.00 |         0.48 |           32.00
|---- RowParallelLinear(weight=bfloat16[4096, 4096])                             |       547.00 |         8.25 |           32.00
|----- void gemv2T_kernel_val<int, int, __nv_bfloat16, __nv_bfloat16, __nv_bf... |       547.00 |         8.25 |           32.00
|--- LlamaMLP                                                                    |      4225.00 |        63.74 |           32.00
|---- MergedColumnParallelLinear(weight=bfloat16[28672, 4096])                   |      2577.00 |        38.87 |           32.00
|----- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsi... |      2577.00 |        38.87 |           32.00
|---- SiluAndMul                                                                 |       258.00 |         3.89 |           32.00
|----- void vllm::act_and_mul_kernel<c10::BFloat16, &(c10::BFloat16 vllm::sil... |       258.00 |         3.89 |           32.00
|---- RowParallelLinear(weight=bfloat16[4096, 14336])                            |      1390.00 |        20.97 |           32.00
|----- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsiz... |      1390.00 |        20.97 |           32.00
|-- RMSNorm(weight=bfloat16[4096])                                               |         3.00 |         0.05 |            1.00
|--- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, v... |         3.00 |         0.05 |            1.00
LogitsProcessor                                                                  |       347.00 |         5.23 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         2.00 |         0.03 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       345.00 |         5.20 |            1.00
Sampler                                                                          |       255.00 |         3.85 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         0.27 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.06 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.08 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        54.00 |         0.81 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        44.00 |         0.66 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         3.00 |         0.05 |            2.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        11.00 |         0.17 |            4.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        28.00 |         0.42 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         0.15 |            5.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         1.00 |         0.02 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.05 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        16.00 |         0.24 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         0.54 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         0.12 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         2.00 |         0.03 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.03 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.06 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.06 |            1.00

I see 4 GEMMs under LlamaForCausalLM but only 2 under CUDAGraphRunner in the PR description.

I am okay with landing this as-is and debugging this further.

@LucasWilkinson
Collaborator Author

LucasWilkinson commented Jun 25, 2024

In eager mode we can uniquely identify GEMMs by their calling module. When CUDAGraphRunner is the module, two GEMMs that were distinct in eager mode (because they came from different modules) but call the same kernel get merged. For example, from the trace you shared:

|---- QKVParallelLinear(weight=bfloat16[6144, 4096])                             |       682.00 |        10.29 |           32.00
|----- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsiz... |       682.00 |        10.29 |           32.00
....
|---- RowParallelLinear(weight=bfloat16[4096, 14336])                            |      1390.00 |        20.97 |           32.00
|----- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsiz... |      1390.00 |        20.97 |           32.00

under CUDA graphs these are lumped into a single aggregate line:

|----- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsiz... 

You can see this from the invocation counts: some of the GEMMs are doubled.

Unfortunately I have not figured out a way to correlate kernel calls from within a CUDA graph back to a specific module, so it's hard to separate these GEMMs out by layer when using CUDA graphs.
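
To make the lumping concrete, a small illustrative aggregation (hypothetical data and function, not the profiler's actual code): in eager mode the calling module is part of the key, so identical kernels from different modules stay separate; under a CUDA graph only the kernel name is available, so their invocation counts merge.

from collections import Counter

def aggregate(rows, key_on_module):
    # Sum invocation counts per key; the key optionally includes the module.
    counts = Counter()
    for module, kernel, invocations in rows:
        key = (module, kernel) if key_on_module else kernel
        counts[key] += invocations
    return counts

rows = [
    ("QKVParallelLinear", "sm90_xmma_gemm_bf16...tilesize64x64x64...", 32),
    ("RowParallelLinear", "sm90_xmma_gemm_bf16...tilesize64x64x64...", 32),
]

print(aggregate(rows, key_on_module=True))   # two separate entries of 32 (eager mode)
print(aggregate(rows, key_on_module=False))  # one merged entry of 64 (CUDA graph view)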

@tlrmchlsmth
Member

Unfortunately I have not figured out a way to correlate kernel calls from within a CUDA graph back to a specific module, so it's hard to separate these GEMMs out by layer when using CUDA graphs.

This is OK for what we need to do right now, fortunately.

Comment on lines 78 to 79
128, # 128 to skip over special tokens
llm.llm_engine.model_config.get_vocab_size() // 2,
Member


is there a standard less hacky way to do this?

Collaborator Author


Not that I know of, but it looks like there is an ignore_eos option in SamplingParams which can be set so we don't have to worry about special tokens (basically we were getting premature EOSes ending the profiling, and restricting the vocab to exclude special tokens fixed it, but ignore_eos is a more robust way of solving this).
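
For illustration, roughly what the more robust alternative might look like in the example script (a sketch under the assumption that the script builds SamplingParams directly; output_len stands in for whatever token budget the script uses):

from vllm import SamplingParams

# With ignore_eos=True the engine keeps generating even if an EOS/special token
# is sampled, so the profile covers the full decode length without having to
# restrict the sampled vocabulary to avoid special tokens.
sampling_params = SamplingParams(
    max_tokens=output_len,  # output_len: assumed token budget from the script
    ignore_eos=True,
)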

vllm/worker/model_runner.py (review thread resolved)
@LucasWilkinson LucasWilkinson merged commit 89a0e3c into main Jun 26, 2024
38 checks passed
@LucasWilkinson LucasWilkinson deleted the lwilkinson/profiler-bugfixes branch June 26, 2024 15:32