Skip to content

Commit

Permalink
Modify LayerNorm Composite Rule (PaddlePaddle#52712)
Browse files Browse the repository at this point in the history
* [Do NOT merge] Expr PR on Composite

* Expr PR on Composite

* Revert some composite experiment

* Remove unnecessary composite code

* Add rsqrt as sub primitives
  • Loading branch information
zhhsplendid authored Apr 12, 2023
1 parent b0f17d0 commit a206056
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/paddle/incubate/autograd/composite_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
var_tmp1 = difference * difference
variance = mean(var_tmp1, axis=axis, keepdim=True)
var_tmp3 = variance + epsilon
sqrt_var = sqrt(var_tmp3)
out = difference / sqrt_var
rsqrt_var = rsqrt(var_tmp3)
out = difference * rsqrt_var

if scale is not None:
scale = reshape(scale, x.shape[begin_norm_axis:])
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/incubate/autograd/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from paddle.tensor import pow # noqa: F401
from paddle.tensor import prod # noqa: F401
from paddle.tensor import reshape # noqa: F401
from paddle.tensor import rsqrt # noqa: F401
from paddle.tensor import sign # noqa: F401
from paddle.tensor import sin # noqa: F401
from paddle.tensor import sinh # noqa: F401
Expand Down Expand Up @@ -117,6 +118,7 @@
'ones',
'zeros',
'sqrt',
'rsqrt',
]
others = [
Expand Down

0 comments on commit a206056

Please sign in to comment.