
optimize layernorm forward #218

Merged
merged 4 commits into FlagOpen:master on Sep 20, 2024

Conversation

@iclementine (Collaborator) commented on Sep 19, 2024

PR Category

Operator

Type of Change

Performance Optimization

Description

Optimize the forward pass of layernorm. We now select one of three kernels depending on the reduction size (two illustrative sketches follow the list):

  1. Persistent multiline kernel: when reduction size <= 128;
  2. Persistent kernel: when 128 < reduction size <= 4096. It uses a 1-D tile and saves some indexing & masking;
  3. Loop kernel: when reduction size > 4096. It uses a variant of the Welford algorithm to compute the variance, which saves an extra load of the input. It also applies other tricks: reversing the second loop to increase the L2 hit rate when the reduction size is not too large; evicting more eagerly in the second loop; and specializing the last iteration to avoid masking on the other iterations.
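For the dispatch itself, here is a minimal sketch of the size-based selection described above. The thresholds 128 and 4096 come from this description; the kernel names are hypothetical placeholders, not the actual Triton entry points:

```python
# Hedged sketch of the size-based kernel selection; thresholds are from
# the PR description, kernel names are hypothetical placeholders.
def select_layernorm_kernel(reduction_size: int) -> str:
    if reduction_size <= 128:
        return "persistent_multiline_kernel"
    if reduction_size <= 4096:
        return "persistent_kernel"  # 1-D tile, less indexing & masking
    return "loop_kernel"            # one-pass Welford variant

assert select_layernorm_kernel(64) == "persistent_multiline_kernel"
assert select_layernorm_kernel(2048) == "persistent_kernel"
assert select_layernorm_kernel(8192) == "loop_kernel"
```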
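The one-pass variance in the loop kernel can be illustrated with a small NumPy reference sketch of a block-wise Welford/Chan-style merge: each tile contributes partial (count, mean, M2) statistics that are folded into a running accumulator, so the input is read once for both mean and variance. This is an illustration under assumed names (`combine`, `one_pass_mean_var`, `BLOCK`), not the actual Triton kernel:

```python
import numpy as np

def combine(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    """Merge two partial (count, mean, M2) accumulators (Chan et al.)."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

def one_pass_mean_var(row, BLOCK=4096):
    """One pass over `row`: each tile updates the running statistics,
    so mean and variance need only a single load of the input."""
    n, mean, m2 = 0, 0.0, 0.0
    for start in range(0, row.size, BLOCK):
        tile = row[start:start + BLOCK]
        t_mean = tile.mean()
        t_m2 = ((tile - t_mean) ** 2).sum()
        n, mean, m2 = combine(n, mean, m2, tile.size, t_mean, t_m2)
    return mean, m2 / n  # biased variance, as layernorm uses

x = np.random.randn(77824).astype(np.float32)
mean, var = one_pass_mean_var(x)
assert np.allclose([mean, var], [x.mean(), x.var()], atol=1e-3)
```

In the actual kernel a second loop then re-reads the input to normalize it; reversing that loop lets it start from the tiles the first pass touched last, which is the L2-hit-rate trick mentioned in item 3.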

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Here are the benchmark results on an RTX 3090.

Before

cpu-mode timing

test_reduction_perf.py Operator layernorm Performance Test (torch.float16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024             0.0102525            0.0792787           0.129
6144             0.0539834            0.0644283           0.838
11264            0.0992241            0.0794113            1.25
16384              0.14798             0.135344            1.09
21504              0.19665             0.194108            1.01
26624             0.244713             0.254988            0.96
31744             0.285166                0.307           0.929
36864             0.338827             0.357343           0.948
41984             0.387447             0.407603           0.951
47104             0.428396             0.458027           0.935
52224             0.475049             0.507055           0.937
57344             0.518136             0.556219           0.932
62464             0.569083             0.604323           0.942
67584             0.614141             0.656674           0.935
72704             0.653114             0.705179           0.926
77824             0.699226             0.756339           0.924
Operator layernorm Performance Test (torch.float32)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024             0.0150514            0.0633936           0.237
6144             0.0939577            0.0865922            1.09
11264              0.16951             0.202276           0.838
16384             0.242892               0.3167           0.767
21504             0.318454             0.416455           0.765
26624             0.397542             0.517825           0.768
31744             0.471809             0.614384           0.768
36864             0.561043             0.718137           0.781
41984             0.622455             0.814443           0.764
47104             0.694413              0.90966           0.763
52224             0.771752              1.00866           0.765
57344             0.849227              1.10915           0.766
62464              0.92143              1.20739           0.763
67584              1.00197              1.31215           0.764
72704              1.07565              1.40777           0.764
77824               1.1676              1.50991           0.773
Operator layernorm Performance Test (torch.bfloat16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024             0.0101285            0.0633092            0.16
6144             0.0543299            0.0636106           0.854
11264              0.10011            0.0803981            1.25
16384             0.151164             0.136883             1.1
21504             0.195308             0.195153             1.0
26624             0.247512             0.255802           0.968
31744             0.288948             0.308058           0.938
36864             0.339835              0.35763            0.95
41984             0.389078             0.408352           0.953
47104             0.431145             0.458213           0.941
52224             0.475008             0.506602           0.938
57344             0.515749             0.556207           0.927
62464             0.563612             0.607139           0.928
67584             0.612329             0.656966           0.932
72704             0.652153             0.706296           0.923
77824             0.700632             0.757162           0.925

cuda-mode timing

test_reduction_perf.py Operator layernorm Performance Test (torch.float16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.013312              0.01024             1.3
6144               0.05632             0.038912            1.45
11264             0.101376             0.080896            1.25
16384             0.149504              0.13824            1.08
21504              0.19968             0.195584            1.02
26624             0.246784             0.257024            0.96
31744             0.287744             0.309248            0.93
36864             0.342016             0.359424           0.952
41984             0.390144               0.4096           0.953
47104              0.43008             0.459776           0.935
52224             0.475136             0.508928           0.934
57344             0.519168              0.55808            0.93
62464             0.572416             0.606208           0.944
67584              0.61952             0.658432           0.941
72704             0.653312              0.70656           0.925
77824             0.710656              0.75776           0.938
Operator layernorm Performance Test (torch.float32)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.017408              0.01536            1.13
6144              0.096256             0.089088            1.08
11264             0.171008             0.203776           0.839
16384             0.244736             0.318464           0.768
21504             0.320512             0.418816           0.765
26624             0.397312             0.520192           0.764
31744             0.470016             0.615424           0.764
36864             0.557056             0.720896           0.773
41984             0.622592             0.816128           0.763
47104              0.69632             0.910336           0.765
52224             0.776192              1.01069           0.768
57344             0.848896              1.11002           0.765
62464             0.922624              1.20832           0.764
67584               1.0025              1.30867           0.766
72704              1.07725              1.40698           0.766
77824              1.16736              1.51245           0.772
Operator layernorm Performance Test (torch.bfloat16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.013312              0.01024             1.3
6144              0.057344             0.038912            1.47
11264             0.101376             0.082944            1.22
16384             0.149504              0.13824            1.08
21504             0.198656             0.197632            1.01
26624             0.246624             0.257024            0.96
31744             0.287744             0.309248            0.93
36864              0.33792             0.359424            0.94
41984             0.387072             0.410624           0.943
47104             0.438272             0.459776           0.953
52224             0.475136             0.509056           0.933
57344             0.519168             0.557056           0.932
62464             0.564224             0.610304           0.924
67584             0.610304             0.657408           0.928
72704              0.65024             0.707584           0.919
77824             0.703488             0.758816           0.927

After

cpu-mode timing

test_reduction_perf.py Operator layernorm Performance Test (torch.float16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024             0.0100327             0.066107           0.152
6144             0.0539818             0.124814           0.432
11264            0.0996938             0.074445            1.34
16384             0.145836             0.111036            1.31
21504             0.193913             0.151223            1.28
26624             0.243227             0.188586            1.29
31744             0.287275             0.226547            1.27
36864             0.340659              0.26407            1.29
41984             0.385496              0.30374            1.27
47104             0.430817             0.340443            1.27
52224             0.474201             0.378162            1.25
57344             0.516725             0.416777            1.24
62464             0.567326              0.45388            1.25
67584             0.606309             0.492827            1.23
72704             0.651885             0.529122            1.23
77824              0.71089             0.568332            1.25
Operator layernorm Performance Test (torch.float32)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024             0.0153615            0.0663347           0.232
6144             0.0943303            0.0731363            1.29
11264             0.168936             0.149472            1.13
16384             0.243501             0.223247            1.09
21504             0.318337             0.303555            1.05
26624             0.394454             0.380226            1.04
31744             0.470876             0.455038            1.03
36864             0.555541             0.531961            1.04
41984             0.622823             0.607893            1.02
47104             0.695291             0.680414            1.02
52224             0.774846             0.761748            1.02
57344             0.849177             0.838091            1.01
62464             0.922314             0.909388            1.01
67584              1.00176             0.991296            1.01
72704              1.07602              1.06251            1.01
77824               1.1634              1.14813            1.01
Operator layernorm Performance Test (torch.bfloat16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024             0.0101304            0.0635226           0.159
6144             0.0549207            0.0625584           0.878
11264             0.099654            0.0693154            1.44
16384             0.146019             0.106422            1.37
21504             0.194268             0.145978            1.33
26624             0.243556             0.183239            1.33
31744             0.289899             0.220377            1.32
36864             0.340016             0.258661            1.31
41984             0.387771              0.29838             1.3
47104             0.432248             0.335079            1.29
52224             0.472394             0.373223            1.27
57344             0.511883             0.410697            1.25
62464             0.571961             0.448678            1.27
67584             0.603934             0.486935            1.24
72704              0.65877             0.522965            1.26
77824             0.704553             0.562981            1.25

cuda-mode timing

test_reduction_perf.py Operator layernorm Performance Test (torch.float16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.013312              0.01024             1.3
6144              0.055296             0.039936            1.38
11264             0.101376               0.0768            1.32
16384             0.149504              0.11264            1.33
21504              0.19968             0.153408             1.3
26624              0.24576             0.190464            1.29
31744             0.287744             0.228352            1.26
36864             0.340992              0.26624            1.28
41984             0.390144             0.305152            1.28
47104             0.431008             0.342016            1.26
52224              0.47616             0.379904            1.25
57344             0.519168             0.417792            1.24
62464             0.572416              0.45568            1.26
67584              0.61952             0.493568            1.26
72704             0.653312             0.531456            1.23
77824             0.709664             0.571392            1.24
Operator layernorm Performance Test (torch.float32)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.017408             0.014336            1.21
6144              0.096256             0.074752            1.29
11264             0.171008             0.151552            1.13
16384             0.244736             0.226304            1.08
21504              0.32032             0.306176            1.05
26624             0.397312             0.381952            1.04
31744             0.470016              0.45568            1.03
36864             0.557056             0.531456            1.05
41984             0.622592             0.610304            1.02
47104              0.69632             0.683008            1.02
52224             0.776192             0.761856            1.02
57344             0.848896             0.839648            1.01
62464             0.923648              0.91136            1.01
67584               1.0025              0.99136            1.01
72704              1.07725              1.06394            1.01
77824              1.16736              1.14893            1.02
Operator layernorm Performance Test (torch.bfloat16)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.013312             0.011264            1.18
6144              0.057344             0.043008            1.33
11264             0.101376              0.07168            1.41
16384             0.149504              0.10752            1.39
21504             0.198656             0.147456            1.35
26624              0.24576             0.185344            1.33
31744             0.287744             0.222208            1.29
36864              0.33792              0.26112            1.29
41984             0.387072             0.300032            1.29
47104              0.44032             0.336896            1.31
52224             0.475136             0.374784            1.27
57344             0.519168             0.412672            1.26
62464             0.564224              0.45056            1.25
67584             0.610304             0.488448            1.25
72704              0.65024             0.524288            1.24
77824             0.703488             0.565248            1.24

@Bowen12992 (Collaborator) left a comment:

Great Job!👍

@iclementine merged commit da86496 into FlagOpen:master on Sep 20, 2024
4 checks passed
DuanYaQi pushed a commit that referenced this pull request Oct 8, 2024