[CUTLASS] Add parallel split-k support to wgrad #10185
Conversation
If you want to investigate the accuracy issue, I suggest you compare both cutlass and cuDNN with a naive fp64 or fp32 version.
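The naive-reference comparison suggested here could be sketched like this. This is a hypothetical fp64 wgrad reference for NCHW layout, stride 1, no padding; `wgrad_ref` is an illustrative name, not a TVM or CUTLASS API:

```python
import numpy as np

# Naive fp64 wgrad reference (NCHW, stride 1, no padding), for triaging
# accuracy: compare both the cutlass and cuDNN outputs against this and
# see whose relative error is larger. Purely illustrative.
def wgrad_ref(x, dy):
    n, c, h, w = x.shape       # input activations
    _, k, p, q = dy.shape      # output gradient
    r, s = h - p + 1, w - q + 1  # filter spatial dims for a stride-1 valid conv
    dw = np.zeros((k, c, r, s), dtype=np.float64)
    for i in range(r):
        for j in range(s):
            patch = x[:, :, i:i + p, j:j + q].astype(np.float64)
            # dw[k, c, i, j] = sum over N, P, Q of dy * x patch
            dw[:, :, i, j] = np.einsum("nkpq,ncpq->kc", dy.astype(np.float64), patch)
    return dw
```

With this in hand, one can compute relative errors of both backends against the fp64 result instead of comparing the two fp16/fp32 backends against each other.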
Commits (all authored by Masahiro Masuda <masahi129@gmail.com>, times +0900):
- 60b73a91b79d644d8c95f682eedaf47a89abba0d (Tue Feb 8 10:43:11 2022): pylint
- ae2e718 (Sun Feb 6 14:51:52 2022): Add split-k support for wgrad
- 43820d5 (Sun Feb 6 10:07:34 2022): fix and add doc
- 446a95b (Sun Feb 6 09:48:38 2022): dw conv2d properly supported for wgrad
- adc4e22 (Sat Feb 5 16:32:42 2022): fix overwriting template
- 040eab0 (Sat Feb 5 16:06:27 2022): black
- e5a07c2 (Sat Feb 5 16:03:10 2022): add reduction in profiler
- be89334 (Sat Feb 5 06:58:03 2022): adding split k reduction to conv2d profiler
- ae09b0f (Fri Feb 4 11:52:59 2022): fixed conv2d_backward_weight typerel for dw conv2d
- 16fe531 (Thu Feb 3 12:59:22 2022): wip
- 2167c25 (Thu Feb 3 04:22:19 2022): fix conv2d type rel for depth wise and grouped conv2d
- 14b12e5 (Fri Feb 4 05:01:03 2022): remove split_k.py
- b141271 (Fri Feb 4 04:48:21 2022): workaround for invalid split_k_slice
- 6e4c7e1 (Fri Feb 4 02:43:58 2022): support split k in profiler
- 2eb1cf4 (Fri Feb 4 02:31:03 2022): improvement
- 0bce8f3 (Thu Feb 3 18:20:12 2022): fixed for fp16 output
- 30df1bd (Thu Feb 3 17:50:33 2022): fp32 output works
- 7a51995 (Thu Feb 3 14:30:22 2022): fix
- 4a383e2 (Thu Feb 3 14:05:24 2022): update c++ codegen
- 6206e38 (Thu Feb 3 13:46:05 2022): wip
- 0ece49b (Thu Feb 3 03:05:21 2022): wip
- 08a6147 (Wed Feb 2 13:10:21 2022): test worked with fp32 output
- 084d5c4 (Wed Feb 2 12:35:18 2022): fix compile error for fprop
- 31f2543 (Wed Feb 2 12:18:06 2022): compiled
- c2098e7 (Wed Feb 2 11:11:43 2022): wip
- a145850 (Sun Feb 6 14:46:16 2022): fixed for sm75
- 6151506 (Sun Feb 6 14:32:46 2022): all tests work
- 041c094 (Sun Feb 6 14:19:09 2022): dw conv2d properly supported for wgrad
- 2191918 (Wed Feb 2 09:14:05 2022): wgrad tests now work under pytest
- 78f76df (Wed Feb 2 07:31:54 2022): run black
- 0a82149 (Wed Feb 2 06:12:39 2022): [CUTLASS] Add wgrad support (without split-k)
Force-pushed from 5ec73cb to d11f9bc.
cc @mbs-octoml, an interesting example for perf work.
Hi Masa, this is amazing progress. Some questions on the known issues:
Hi Manish,
The benchmark results I linked above show the accuracy difference in the last two columns. Most workloads have some differences, except for some deeper layers at batch = 8, which showed an exact match. It seems deeper layers, those having small spatial size and large channels, generally have fewer accuracy problems. The differences become much bigger at batch = 256. So it kind of works, but not quite, and it is very hard to debug. The profiler in cutlass doesn't report any accuracy problem, which is another mystery. It could be that TVM's use of cuDNN wgrad has some issues.
The issue is memory reuse across multiple calls. The way we integrate cuDNN and the way we integrate cutlass are significantly different. I tried to apply a memory management strategy similar to the one we use for cuDNN to the JIT-generated cutlass code, but as I said above, I'm running into strange issues.
Yes, I haven't grokked your note in that thread yet. I just tried a dumb strategy in my benchmark and it already shows good performance. I didn't pursue perf improvement further, since the accuracy problem was more concerning.
On accuracy: floating point additions are not associative, so changing the order of accumulation can change the result. Parallel reduction does change the order of accumulation over GEMM-K (NPQ). Thus, some change between runs is expected. I don't have guidance on what threshold to set when checking relative error. I would take Haicheng's suggestion here and follow it:
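The non-associativity point can be demonstrated in a few lines. This is a minimal numpy sketch, not the actual CUTLASS reduction: a long serial fp32 accumulation versus a split-k style reduction of partial sums generally differ slightly, while both stay close to the fp64 result.

```python
import numpy as np

rng = np.random.default_rng(0)
k = 4096
x = rng.standard_normal(k).astype(np.float32)

# One long serial chain, left to right (like a single accumulator over GEMM-K).
serial = np.float32(0.0)
for v in x:
    serial = np.float32(serial + v)

# Split-k style: partial sums over interleaved K slices, then a final reduction.
slices = 8
partials = [x[i::slices].sum(dtype=np.float32) for i in range(slices)]
split_k = np.float32(sum(partials))

# The two orderings usually disagree in the low bits; neither is "wrong".
print(serial, split_k, serial - split_k)
```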
The CUTLASS profiler uses integer input to initialize tensors and matrices. This makes error checking easier. You can also use the CUTLASS profiler approach to make sure there are no functional errors, i.e., try the operation on integer input.
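A minimal sketch of the integer-input trick (numpy matmul stands in for the kernel under test): with small integers, every product and partial sum is exactly representable in fp32, so reordering cannot change the result, and any mismatch against an exact integer reference is a functional bug rather than rounding noise.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small integers in [-3, 3]: products are at most 9, and with K = 32 the
# accumulated sums stay far below 2**24, so fp32 arithmetic is exact here.
a = rng.integers(-3, 4, size=(64, 32)).astype(np.float32)
b = rng.integers(-3, 4, size=(32, 16)).astype(np.float32)

out_fp32 = a @ b                               # stand-in for the kernel under test
ref = a.astype(np.int64) @ b.astype(np.int64)  # exact integer reference

# Bit-exact match is expected regardless of accumulation order.
assert np.array_equal(out_fp32, ref.astype(np.float32))
```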
Actually, the accuracy difference was there even before I added parallel split-k to wgrad, and the result got closer to cuDNN after adding split-k. So I believe the issue is not in the parallel reduction; something is off elsewhere. I have seen some workloads where cuDNN uses cutlass's wgrad and reduction kernels, and even in that case there was a difference. Probably I should look at how TVM is using cuDNN wgrad first. I haven't applied fp32 wgrad to large inputs; for the small ones we used in the unit tests, the results looked good. We also have the option of comparing against TVM native results, which I have only looked at briefly.
That's very interesting... I didn't know that. I can definitely try, thanks.
Building on #10177, this adds parallel split-k support to wgrad.
@comaniac @Laurawly @junrushao1994 @vinx13 @YuchenJin @hwu36 @manishucsd
Split-k is described in https://github.com/NVIDIA/cutlass/blob/master/media/docs/efficient_gemm.md#parallelized-reductions.
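The parallelized-reduction flavor of split-k described in that doc can be sketched in miniature (numpy stand-in; the real implementation writes per-slice partial products to a device workspace and launches a separate reduction kernel):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 64)).astype(np.float32)  # M x K
b = rng.standard_normal((64, 4)).astype(np.float32)  # K x N

slices = 4
bounds = np.linspace(0, a.shape[1], slices + 1, dtype=int)

# "Workspace": one partial result per K slice. This is exactly what grows
# linearly with the number of split-k slices.
workspace = np.stack(
    [a[:, lo:hi] @ b[lo:hi, :] for lo, hi in zip(bounds, bounds[1:])]
)

# The separate reduction step sums the partials into the final output.
out = workspace.sum(axis=0)

assert np.allclose(out, a @ b, rtol=1e-4, atol=1e-4)
```

Each slice's partial GEMM is independent, which is where the extra parallelism comes from when the K dimension dwarfs M and N.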
This is my first experience using split-k in cutlass or any other API. Wgrad is particularly interesting for split-k since the implicit GEMM K dimension is really large in wgrad (`N * P * Q`, where `P` and `Q` are the output H and W). Without split-k, wgrad on large spatial inputs is extremely slow.

For now, I'm not trying anything smart to pick the split-k parameter; instead we ask users to provide possible candidates. I tuned over `[1, 4, 8, 16, 32, 64]` below and that already showed excellent performance. The benchmark code is here.

Benchmark results against cuDNN: note that there are currently non-trivial differences between the cuDNN and TVM + cutlass outputs, especially for the larger batch size. I didn't find anything obviously wrong in the generated code and gave up fixing the accuracy difference at some point. Also note that the difference is not due to parallel split-k; even in the normal case the results were different (and actually improved after split-k lol).
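To make the "K dimension is really large" point concrete, here is a rough shape calculation (illustrative helper, assuming the standard implicit-GEMM mapping for wgrad where GEMM M = K, GEMM N = R*S*C, and GEMM K = N*P*Q):

```python
# Implicit-GEMM shapes for wgrad: dW[K, R, S, C] from dY[N, P, Q, K] and
# X[N, H, W, C]. The GEMM K dimension is the reduction over batch and
# output spatial positions, which is why split-k matters here.
def wgrad_gemm_shape(n, p, q, k, r, s, c):
    return k, r * s * c, n * p * q  # (GEMM M, GEMM N, GEMM K)

# A ResNet-like 3x3 conv layer at batch 256:
m, n_dim, gemm_k = wgrad_gemm_shape(n=256, p=56, q=56, k=64, r=3, s=3, c=64)
print(m, n_dim, gemm_k)  # GEMM K = 256*56*56 = 802816, dwarfing M=64 and N=576
```

With M and N this small, a non-split-k kernel launches far too few threadblocks to fill the GPU, so splitting the huge K dimension is the natural source of parallelism.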
The results show cutlass winning across the board (the `Profiler time` vs `cuDNN` columns; but again, the results do not match exactly). However, there is a serious problem when cutlass wgrad + split-k kernels are called from TVM (the `TVM + CUTLASS` column): split-k requires a large workspace, and the space requirement grows linearly with the `split_k_slices` parameter. Right now we naively allocate the workspace on every cutlass kernel call on each run, while for cuDNN we have a simple workspace memory reuse mechanism (together with a thread local storage) implemented in `tvm/src/runtime/contrib/cudnn/cudnn_utils.cc` (lines 153 to 161 in 211291f).
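For reference, the grow-only, thread-local reuse strategy could be sketched like this (a hypothetical Python analogue with illustrative names; the real cuDNN-side implementation is in C++ in cudnn_utils.cc and manages device memory, not host bytearrays):

```python
import threading

# Thread-local, grow-only workspace cache: a kernel call reuses the cached
# buffer when it is big enough, and reallocates only when a call asks for
# more space than any previous call on this thread.
class WorkspaceCache(threading.local):
    def __init__(self):
        self.buf = bytearray(0)

    def get(self, nbytes):
        if nbytes > len(self.buf):       # grow-only policy
            self.buf = bytearray(nbytes)
        return memoryview(self.buf)[:nbytes]

cache = WorkspaceCache()
ws1 = cache.get(1024)  # allocates 1024 bytes
ws2 = cache.get(512)   # reuses the existing 1024-byte allocation
```

Being thread-local sidesteps locking, but as noted below it is unclear how such a manager should behave across multiple JIT-compiled translation units.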
I attempted adding simple workspace memory management in https://github.com/masahi/tvm/compare/cutlass-split-k...masahi:cutlass-workspace?expand=1, and it kind of works in terms of the expected perf improvement. However, I'm getting segfaults and other strange issues. I'm a bit confused as to what the right behavior should be for a thread-local memory manager in the context of multiple JIT-generated and compiled translation units. Let me know if you have any thoughts on this issue.
Known issues and TODO