Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Initial Layerwise Profiler #124

Merged
merged 28 commits into from
Mar 26, 2024
Merged

Conversation

LucasWilkinson
Copy link
Collaborator

@LucasWilkinson LucasWilkinson commented Mar 14, 2024

SUMMARY:

Initial layerwise profiler leveraging the kineto base PyTorch profiler.

NOTE: we run in eager mode by default so that the stack-trace/event-tree contains nn.Module, otherwise if cuda-graphs is used all the kernels will be under the CUDAGraphRunner module (this is converted to an nn.Module in this PR so that the _build_module_tree code puts all those kernels under a CUDAGraphRunner)

NOTE: vllm kernels like vllm::reshape_and_cache_kernel have no trace or shape informat because they are not registered as a TorchOp (i.e. using TORCH_LIBRARY, they instead just a raw PYBIND11_MODULE module)

Example on how to use visualization:

pip install -r requirements-dev.txt

python examples/offline_profile.py --model nm-testing/OpenHermes-2.5-Mistral-7B-pruned50 --batch-size 4 --prompt-len 512 --json openhermes7b-dense

### For Breakdown Graphs
# For module level breakdown
python neuralmagic/tools/profiler/visualize_trace.py --json-trace openhermes7b-dense.json --output profile_dense.pdf
# For kernel level breakdown
python neuralmagic/tools/profiler/visualize_trace.py --json-trace openhermes7b-dense.json --output profile_dense_kernels.pdf --level kernel

### For table printing
# Decode Summary Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase decode
# Prefill Summary Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase prefill
# Decode Model Layerwise Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase decode --table model
# Prefill Model Layerwise Table
python neuralmagic/tools/profiler/print_table.py --json-trace openhermes7b-dense.json --phase prefill --table model

Example Output: profile-example-output.txt

TEST PLAN:

  • GHA

@LucasWilkinson LucasWilkinson marked this pull request as ready for review March 20, 2024 14:06
@LucasWilkinson LucasWilkinson changed the title [WIP] Initial Layerwise Profiler Initial Layerwise Profiler Mar 21, 2024
examples/offline_profile.py Show resolved Hide resolved
examples/offline_profile.py Show resolved Hide resolved
examples/offline_profile.py Outdated Show resolved Hide resolved
requirements-dev.txt Outdated Show resolved Hide resolved
sampling_params=sampling_params)

with nm_profile() as prefill_prof:
llm.llm_engine.step() # First step is prefill
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this always true?

Does the fact that max_num_batched_tokens < batch_size * prompt_len guarantee this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as of now yes, but its brittle in the sense that logic could change and this would fall out of sync

help=f"Maximum length of a sequence (including prompt and output), "
f"default={MAX_SEQ_LEN_DEFAULT}")
parser.add_argument(
"--max_num_batched_tokens",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be set to batch_size * prompt_len + eps?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

personally I like forcing the user to set this (in this case a dev since this tooling is targeting devs) since it means less surprises when the number of GPU blocks decreases and will make the user think twice before setting a very large batch + prompt-len (and/or making it less surprising when we go OOM). But I can be convinced otherwise.

@LucasWilkinson LucasWilkinson merged commit 7ae99c2 into main Mar 26, 2024
2 checks passed
@LucasWilkinson LucasWilkinson deleted the lwilkinson/layerwise-profiling branch March 26, 2024 15:08
LucasWilkinson added a commit that referenced this pull request Jun 26, 2024
…ofiler not picking up kernels within cudagraphs (#332)

This PR updates `examples/offline_profile.py` to use the new
`add_requests` interface. In addition, this PR fixes issues with
profiler not picking up kernels run from within a cudagraph (i.e. when
profiling with `--allow-cuda-graphs`, there's is too main issues:

1) Changes from the initial profiler PR
(#124) were wiped out by
#224, namely the changes in
`model_runner.py` converting `CUDAGraphRunner` to a `nn.Module` allowing
the profiler to pick it up
2) Many kernels within the graph had the same correlation id so we were
always picking the first of potentially many kernels to display, using
name in addition to correlation Id sees to resolve this issue but is
potentially fragile


Before the PR:
```
================================================================================
= Decode Summary Table (prompt_len=1, batch_size=1)
================================================================================

name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
LogitsProcessor                                                                  |       350.00 |        57.76 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.00 |         0.50 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       347.00 |        57.26 |            1.00
Sampler                                                                          |       256.00 |        42.24 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         2.97 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.66 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.83 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        54.00 |         8.91 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        43.00 |         7.10 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         4.00 |         0.66 |            2.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        13.00 |         2.15 |            4.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        29.00 |         4.79 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         1.65 |            5.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         1.00 |         0.17 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.50 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        15.00 |         2.48 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.17 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.17 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         5.94 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         1.32 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         1.00 |         0.17 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.33 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.66 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.66 |            1.00
```

After PR
```
name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
CUDAGraphRunner                                                                  |      4238.00 |        84.41 |            1.00
|- Memcpy DtoD (Device -> Device)                                                |         5.00 |         0.10 |            5.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<... |         2.00 |         0.04 |            2.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         2.00 |         0.04 |            1.00
|- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFloat16*, c10::BFloat16 co... |         4.00 |         0.08 |            1.00
|- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*, c10... |       256.00 |         5.10 |          128.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x64x128_warpgroupsize1x... |      1440.00 |        28.68 |           96.00
|- void vllm::rotary_embedding_kernel<c10::BFloat16, true>(long const*, c10::... |        96.00 |         1.91 |           32.00
|- void vllm::reshape_and_cache_flash_kernel<c10::BFloat16>(c10::BFloat16 con... |        64.00 |         1.27 |           32.00
|- void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<128, 64, 128, 4, fal... |       160.00 |         3.19 |           32.00
|- void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<128, 64, 128... |       160.00 |         3.19 |           32.00
|- memcpy32_post                                                                 |        33.00 |         0.66 |           33.00
|- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, voi... |       128.00 |         2.55 |           64.00
|- sm90_xmma_gemm_e4m3bf16_e4m3f32_f32_tn_n_tilesize64x128x128_warpgroupsize1... |      1664.00 |        33.14 |           32.00
|- void vllm::act_and_mul_kernel<c10::BFloat16, &(c10::BFloat16 vllm::silu_ke... |       224.00 |         4.46 |           32.00
LogitsProcessor                                                                  |       351.00 |         6.99 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.00 |         0.06 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       348.00 |         6.93 |            1.00
Sampler                                                                          |       432.00 |         8.60 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        18.00 |         0.36 |            9.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         4.00 |         0.08 |            1.00
|- at::native::(anonymous namespace)::fill_reverse_indices_kernel(long*, int,... |         2.00 |         0.04 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortUpsweepKernel<at_cuda_detail::cub... |        12.00 |         0.24 |            3.00
|- void at_cuda_detail::cub::RadixSortScanBinsKernel<at_cuda_detail::cub::Dev... |         7.00 |         0.14 |            3.00
|- void at_cuda_detail::cub::DeviceRadixSortDownsweepKernel<at_cuda_detail::c... |        46.00 |         0.92 |            3.00
|- Memcpy DtoD (Device -> Device)                                                |         1.00 |         0.02 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.10 |            3.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         6.00 |         0.12 |            2.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         6.00 |         0.12 |            2.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous n... |         2.00 |         0.04 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<8, c10::BFloat... |        39.00 |         0.78 |            1.00
|- void at_cuda_detail::cub::DeviceScanInitKernel<at_cuda_detail::cub::ScanTi... |         1.00 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceScanKernel<at_cuda_detail::cub::DeviceScan... |         6.00 |         0.12 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::FillFunctor<bool>... |         1.00 |         0.02 |            1.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         3.00 |         0.06 |            2.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         5.00 |         0.10 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         5.00 |         0.10 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        86.00 |         1.71 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        43.00 |         0.86 |            1.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |        14.00 |         0.28 |            4.00
|- void at::native::(anonymous namespace)::distribution_elementwise_grid_stri... |         3.00 |         0.06 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::BinaryFuncto... |         2.00 |         0.04 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        28.00 |         0.56 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |        10.00 |         0.20 |            5.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.00 |         0.06 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<long, at::nati... |        15.00 |         0.30 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::fill<unsigned int, unsigned int>(unsigned int*, u... |         1.00 |         0.02 |            1.00
|- void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned ... |        36.00 |         0.72 |            4.00
|- void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(unsig... |         8.00 |         0.16 |            4.00
|- void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(unsigned ... |         1.00 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::R... |         2.00 |         0.04 |            2.00
|- void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::Devic... |         4.00 |         0.08 |            2.00
|- void at::native::mbtopk::gatherTopK<float, unsigned int, 1>(at::cuda::deta... |         4.00 |         0.08 |            1.00
```

---------

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants