Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ROCm] Replace layer_norm_grad_input_kernel with cuComputeGradInput f…
…or ROCm (pytorch#87726) We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of pytorch#68238 (comment) on AMD GPUs. This PR is to replace layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel with some ROCm-specific parameter tuning when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: pytorch#87635 We used the same benchmark script in the previous PR and tested the optimized kernel with various input shapes on AMD MI100 GPU. **At [the previous PR](pytorch#87635 <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} .xl65 {color:windowtext;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148 50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761 200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284 802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946 200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449 1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758 6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989 6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615 200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208 1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741 6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078 6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514 200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683 1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598 6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035 6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871 200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514 1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264 6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609 6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997 200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138 1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464 6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338 6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018 200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845 1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823 6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719 6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012 128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848 256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678 512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804 1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699 2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708 4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617 8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643 16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291 32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423 </body> </html> ---- **At this PR:** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} .xl65 {color:windowtext;} .xl66 {background:yellow; mso-pattern:black none;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222 50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399 200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265 802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331 200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548 1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851 6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927 6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377 200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894 1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564 6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782 6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506 200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019 1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851 6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526 6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926 200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695 1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573 6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814 6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441 200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548 1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452 6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915 6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896 200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869 1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646 6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036 6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039 128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851 256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554 512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784 1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679 2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603 4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326 8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669 16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022 32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386 </body> </html> --- **Performance Improvement (%)** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} .xl65 {color:windowtext;} .xl66 {mso-number-format:"0\.000";} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32 -- | -- | -- | -- 50432 | 384 | 9.147 | 12.094 50176 | 384 | 4.710 | 8.230 200704 | 192 | 10.655 | 8.266 802816 | 64 | 14.000 | 9.042 200 | 256 | 1.263 | 2.542 1000 | 256 | 3.109 | -0.246 6000 | 256 | 1.183 | 2.796 6272 | 256 | -3.579 | -3.394 200 | 512 | -5.489 | -7.852 1000 | 512 | -3.270 | -2.240 6000 | 512 | -3.456 | -4.596 6272 | 512 | -0.392 | -2.644 200 | 1024 | -7.862 | -0.969 1000 | 1024 | -1.321 | 0.359 6000 | 1024 | -0.693 | -2.336 6272 | 1024 | -2.130 | -2.034 200 | 1536 | -5.287 | -5.151 1000 | 1536 | -0.683 | -0.829 6000 | 1536 | 2.792 | 6.989 6272 | 1536 | 0.051 | 2.132 200 | 2048 | -5.461 | -1.167 1000 | 2048 | 4.126 | 2.701 6000 | 2048 | -0.797 | 0.453 6272 | 2048 | 2.792 | 0.126 200 | 3072 | 0.024 | -0.063 1000 | 3072 | 1.820 | 2.956 6000 | 3072 | 2.531 | 0.275 6272 | 3072 | 1.054 | 7.929 128 | 2097152 | 2.564 | 0.963 256 | 1048576 | 0.077 | 0.582 512 | 524288 | 0.428 | 0.094 1024 | 262144 | 0.581 | -0.096 2048 | 131072 | -0.225 | 0.888 4096 | 65536 | 0.527 | 0.428 8192 | 32768 | 0.204 | 0.717 16384 | 16384 | -0.216 | -0.492 32768 | 8192 | 0.786 | 5.127 </body> </html> CC: @jeffdaily Pull Request resolved: pytorch#87726 Approved by: https://github.com/ngimel
- Loading branch information