Skip to content

Commit

Permalink
feat: add rsqrt composite rule (#51432)
Browse files Browse the repository at this point in the history
* feat: add relu composite rule

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, add python api of relu

* feat: add relu composite rule, commit hook

* fix: maximum type error & ban cinn test

* fix: maximum input sequence bugs

* resolve conflicts

* fix: code style bugs

* add: relu fp16 test

* feat: add rsqrt composite rule

* feat: add rsqrt composite rule

* resolve conflicts of composite rule

* fix: delete check eager
  • Loading branch information
Miracle2333 authored Mar 15, 2023
1 parent 0e492e4 commit c9ca7c3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/paddle/fluid/tests/unittests/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ def test_check_grad(self):
class TestRsqrt(TestActivation):
def setUp(self):
self.op_type = "rsqrt"
self.prim_op_type = "comp"
self.python_api = paddle.rsqrt
self.init_dtype()
self.init_shape()
Expand All @@ -1248,14 +1249,23 @@ def setUp(self):

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.enable_cinn = True

def init_shape(self):
self.shape = [10, 12]

def test_check_output(self):
self.check_output(check_prim=True)

def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.0005)
self.check_grad(
['X'],
'Out',
max_relative_error=0.0005,
check_prim=True,
)


'''
Expand Down
8 changes: 8 additions & 0 deletions python/paddle/incubate/autograd/composite_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,11 @@ def unsqueeze_composite(x, axis):
)
out = reshape(x, x_shape)
return [out, None]


@REGISTER_COMPOSITE('rsqrt')
def rsqrt_composite(x):
"""define composite rule of op rsqrt."""
# rsqrt(x) = x^(-0.5)
y = full(x.shape, -0.5, x.dtype)
return pow(x, y)

0 comments on commit c9ca7c3

Please sign in to comment.