Skip to content

Commit

Permalink
[PIR] Open test_case ut (#60721)
Browse files Browse the repository at this point in the history
* fix

* fix
  • Loading branch information
zhangbo9674 authored Jan 11, 2024
1 parent 57fff3a commit 55558f1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/legacy_test/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,14 @@ def fn_3():
np.testing.assert_allclose(res[4], 2, rtol=1e-05)
self.assertEqual(res[4].shape, ())

# Todo(zhangbo): grad_list can not find dx in oir mode
# @test_with_pir_api
@test_with_pir_api
def test_0d_tensor_backward(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
x.stop_gradient = False
x.persistable = True
pred = paddle.full(shape=[], dtype='bool', fill_value=0)
# pred is False, so out = -x
out = paddle.static.nn.case(
Expand Down

0 comments on commit 55558f1

Please sign in to comment.