Skip to content

Commit

Permalink
[PIR]Open more tests for bernoulli and celu (#60706)
Browse files Browse the repository at this point in the history
* bernoulli && celu

* celu test_error
  • Loading branch information
changeyoung98 authored Jan 11, 2024
1 parent 2c56dd4 commit bed33c3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 3 additions & 2 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3378,6 +3378,7 @@ def test_dygraph_api(self):
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)

@test_with_pir_api
def test_errors(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
Expand Down Expand Up @@ -4921,7 +4922,7 @@ def test_check_grad(self):
create_test_act_fp16_class(TestRelu6)
create_test_act_fp16_class(TestSoftRelu, check_dygraph=False)
create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestCELU)
create_test_act_fp16_class(TestCELU, check_pir=True)
create_test_act_fp16_class(TestReciprocal, check_pir=True)
create_test_act_fp16_class(TestLog, check_prim=True, check_pir=True)
if core.is_compiled_with_rocm():
Expand Down Expand Up @@ -5090,7 +5091,7 @@ def test_check_grad(self):
create_test_act_bf16_class(TestRelu6)
create_test_act_bf16_class(TestSoftRelu, check_dygraph=False)
create_test_act_bf16_class(TestELU)
create_test_act_bf16_class(TestCELU)
create_test_act_bf16_class(TestCELU, check_pir=True)
create_test_act_bf16_class(TestReciprocal, check_pir=True)
create_test_act_bf16_class(TestLog, check_prim=True, check_pir=True)
if core.is_compiled_with_rocm():
Expand Down
4 changes: 3 additions & 1 deletion test/legacy_test/test_bernoulli_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def init_dtype(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place_customized(self.verify_output, place)
self.check_output_with_place_customized(
self.verify_output, place, check_pir=True
)

def init_test_case(self):
self.x = convert_float_to_uint16(
Expand Down

0 comments on commit bed33c3

Please sign in to comment.