Skip to content

Commit

Permalink
add static mode backward test
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Nov 17, 2023
1 parent c833770 commit 6cc1f71
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions test/legacy_test/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1978,5 +1978,59 @@ def test_check_grad(self):
self.check_grad_with_place(place, ['Input'], 'Out', check_dygraph=False)


class TestSetValueWithScalarInStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shape = (10, 2)
self.exe = paddle.static.Executor()
self.train_program = paddle.static.Program()
self.startup_program = paddle.static.Program()

def test_value_input_is_scalar(self):
with paddle.static.program_guard(
self.train_program, self.startup_program
):
x = paddle.ones(self.shape)
x.stop_gradient = False
y = x * 1

# mock test case x[0, 0] = 10 with no ValueTensor input
inputs = {
'Input': y,
}
attrs = {
'axes': [0, 1],
'starts': [0, 0],
'ends': [1, 1],
'steps': [1, 1],
'values': [10],
'shape': [1],
}

helper = LayerHelper("set_value")
out = helper.create_variable_for_type_inference(dtype=y.dtype)

helper.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': out},
attrs=attrs,
)

np_data = np.ones(self.shape).astype('float32')

paddle.static.append_backward(out.sum())
res = self.exe.run(
self.train_program, fetch_list=[out, x.grad_name]
)

np_data[0, 0] = 10
expected_x_grad = np.ones(self.shape)
expected_x_grad[0, 0] = 0

np.testing.assert_array_equal(res[0], np_data)
np.testing.assert_array_equal(res[1], expected_x_grad)


if __name__ == '__main__':
unittest.main()

0 comments on commit 6cc1f71

Please sign in to comment.