-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【complex op No.8】add complex support for Rsqrt #63720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f118485
a468601
1fd3797
19a75d8
0aa378f
2565b9a
c957264
7996f85
1beb392
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1870,6 +1870,11 @@ def setUp(self): | |
|
|
||
| np.random.seed(1024) | ||
| x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) | ||
| if self.dtype == np.complex64 or self.dtype == np.complex128: | ||
| x = ( | ||
| np.random.uniform(0.1, 1, self.shape) | ||
| + 1j * np.random.uniform(0.1, 1, self.shape) | ||
| ).astype(self.dtype) | ||
| out = 1.0 / np.sqrt(x) | ||
|
|
||
| self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)} | ||
|
|
@@ -1910,6 +1915,54 @@ def if_enable_cinn(self): | |
| self.enable_cinn = False | ||
|
|
||
|
|
||
| class TestRsqrt_Complex64(TestRsqrt): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 没有complex128类型的test There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| def init_dtype(self): | ||
| self.dtype = np.complex64 | ||
|
|
||
| def test_check_grad(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GPU单测没运行,加下在GPU下运行的单测 |
||
| self.check_grad( | ||
| ['X'], | ||
| 'Out', | ||
| check_pir=True, | ||
| max_relative_error=0.007, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个设置有点大,看能否改小点 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 测试的时候误差是0.006多,这里设成了0.007 |
||
| check_pir_onednn=self.check_pir_onednn, | ||
| ) | ||
|
|
||
| def test_api_complex(self): | ||
| with dynamic_guard(): | ||
| for device in devices: | ||
| if device == 'cpu' or ( | ||
| device == 'gpu' and paddle.is_compiled_with_cuda() | ||
| ): | ||
| np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=self.dtype) | ||
| x = paddle.to_tensor(np_x, dtype=self.dtype, place=device) | ||
| y = paddle.rsqrt(x) | ||
| x_expect = 1.0 / np.sqrt(np_x) | ||
| np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) | ||
|
|
||
| def test_grad_grad(self): | ||
| with dynamic_guard(): | ||
| x_numpy = ( | ||
| np.random.uniform(0.1, 1, self.shape) | ||
| + 1j * np.random.uniform(0.1, 1, self.shape) | ||
| ).astype(self.dtype) | ||
|
|
||
| expected_ddx = 3.0 / 4 * np.conj(np.power(x_numpy, -2.5)) | ||
|
|
||
| x = paddle.to_tensor(x_numpy, stop_gradient=False) | ||
| y = paddle.rsqrt(x) | ||
| dx = paddle.grad( | ||
| outputs=[y], inputs=[x], create_graph=True, retain_graph=True | ||
| )[0] | ||
| ddx = paddle.grad(outputs=[dx], inputs=[x], retain_graph=True)[0] | ||
| np.testing.assert_allclose(ddx.numpy(), expected_ddx, rtol=1e-3) | ||
|
|
||
|
|
||
| class TestRsqrt_Complex128(TestRsqrt_Complex64): | ||
| def init_dtype(self): | ||
| self.dtype = np.complex128 | ||
|
|
||
|
|
||
| class TestAbs(TestActivation): | ||
| def setUp(self): | ||
| self.op_type = "abs" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring也添加上复数的两个类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.