From 5e2a3dbf836cc30fb6c2b66002e8965b862c91a3 Mon Sep 17 00:00:00 2001
From: kevin
Date: Tue, 2 Jan 2024 15:57:45 +0800
Subject: [PATCH] [Prim][PIR] support roll, gather, scatter, scatter_nd_add op
 backward in pir prim (#60481)

* prim gather op backward
* prim scatter op backward
* prim roll op backward
* prim scatter_nd op backward
---
 paddle/fluid/primitive/codegen/gen.py     |   4 +
 paddle/fluid/primitive/rule/vjp/details.h | 100 ++++++++++++++++++++++
 test/legacy_test/test_gather_op.py        |  11 ++-
 test/legacy_test/test_roll_op.py          |  18 +++-
 test/legacy_test/test_scatter_nd_op.py    |  26 +++++-
 test/legacy_test/test_scatter_op.py       |  38 ++++++--
 6 files changed, 183 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 01e760c2b33b2..005eae2959343 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -80,13 +80,17 @@
     'sum_grad',
     'cast_grad',
     'reshape_grad',
+    'roll_grad',
     'split_grad',
     'transpose_grad',
     'concat_grad',
     'expand_grad',
+    'gather_grad',
     'gather_nd_grad',
     'pad_grad',
     'max_grad',
+    'scatter_grad',
+    'scatter_nd_add_grad',
     'slice_grad',
     'tile_grad',
     'topk_grad',
diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index 60d51c6014627..1be68ba043e19 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -243,6 +243,23 @@ void reshape_grad(const Tensor& xshape,
   }
 }

+template <typename T>
+void roll_grad(const Tensor& x,
+               const Tensor& out_grad,
+               const IntArray& shifts,
+               const std::vector<int64_t>& axis,
+               Tensor* x_grad) {
+  if (x_grad) {
+    auto shifts_ = shifts.GetData();
+    int64_t nums = shifts_.size();
+    for (int64_t i = 0; i < nums; i++) {
+      shifts_[i] = 0 - shifts_[i];
+    }
+    auto x_grad_output = roll<T>(out_grad, shifts_, axis);
+    set_output<T>(x_grad_output, x_grad);
+  }
+}
+
 template <typename T>
 void transpose_grad(const Tensor& grad_out,
                     const std::vector<int>& perm,
@@ -262,6 +279,43 @@ void transpose_grad(const Tensor& grad_out,
   }
 }

+template <typename T>
+void scatter_grad(const Tensor& index,
+                  const Tensor& updates,
+                  const Tensor& out_grad,
+                  bool overwrite,
+                  Tensor* x_grad,
+                  Tensor* updates_grad) {
+  if (x_grad) {
+    auto zero_tensor =
+        full<T>(common::vectorize(updates.dims()), 0.0, updates.dtype());
+    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
+    set_output<T>(tmp_grad, x_grad);
+  }
+
+  if (updates_grad) {
+    Scalar tmp_zero = 0;
+    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
+    set_output<T>(tmp_updates_grad, updates_grad);
+  }
+}
+
+template <typename T>
+void scatter_nd_add_grad(const Tensor& index,
+                         const Tensor& updates,
+                         const Tensor& out_grad,
+                         Tensor* x_grad,
+                         Tensor* updates_grad) {
+  if (x_grad) {
+    by_pass<T>(out_grad, x_grad);
+  }
+  if (updates_grad) {
+    // Gradient by Gather: dUpdates = dO[Ids]
+    auto tmp_updates_grad = gather_nd<T>(out_grad, index);
+    set_output<T>(tmp_updates_grad, updates_grad);
+  }
+}
+
 template <typename T>
 void sin_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   auto x_grad_tmp = cos<T>(x) * out_grad;
@@ -818,6 +872,52 @@ void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   }
 }

+template <typename T>
+void gather_grad(const Tensor& x,
+                 const Tensor& index,
+                 const Tensor& out_grad,
+                 const Scalar& axis,
+                 Tensor* grad_x) {
+  auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
+  std::vector<int> tmp_perm;
+
+  // change axis to rank 0
+  int axis_value = axis.to<int>();
+  tmp_perm.push_back(axis_value);
+  // make other ranks
+  for (int i = 0; i < x.dims().size(); ++i) {
+    if (i != axis_value) {
+      tmp_perm.push_back(i);
+    }
+  }
+  std::vector<int> reverse_perm(tmp_perm);
+  // make origin ranks
+  for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
+    if (tmp_perm[i] >= 0) {
+      reverse_perm[tmp_perm[i]] = i;
+    } else {
+      reverse_perm[tmp_perm[i] + tmp_perm.size()] = i;
+    }
+  }
+
+  // transpose out_grad and zero grad to target rank.
+  auto tmp_zero_x_grad = zero_tensor;
+  auto tmp_out_grad = out_grad;
+  if (zero_tensor.dims().size() > 0) {
+    tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
+  }
+  if (out_grad.dims().size() > 0) {
+    tmp_out_grad = transpose<T>(out_grad, tmp_perm);
+  }
+  // scatter grad to grad_x
+  auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
+  auto tmp_grad_x_transposed = tmp_grad_x;
+  if (tmp_grad_x.dims().size() > 0) {
+    tmp_grad_x_transposed = transpose<T>(tmp_grad_x, reverse_perm);
+  }
+  set_output<T>(tmp_grad_x_transposed, grad_x);
+}
+
 template <typename T>
 void gather_nd_grad(const Tensor& x,
                     const Tensor& index,
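The gather_grad rule above decomposes the gather VJP as transpose -> scatter -> inverse transpose: the gathered axis is moved to the front, out_grad is scatter-added into a zero tensor shaped like x, and the result is permuted back. Below is a minimal NumPy sketch of that decomposition (illustrative only, not part of the patch; the helper name gather_grad_reference and the use of np.add.at as a stand-in for the scatter primitive are assumptions):

    import numpy as np

    def gather_grad_reference(x, index, out_grad, axis):
        rank = x.ndim
        # permutation that moves `axis` to position 0, mirroring tmp_perm above
        perm = [axis] + [i for i in range(rank) if i != axis]
        reverse_perm = np.argsort(perm)           # inverse permutation
        grad_t = np.zeros_like(x).transpose(perm)
        out_grad_t = out_grad.transpose(perm)
        np.add.at(grad_t, index, out_grad_t)      # scatter-add along the leading dim
        return grad_t.transpose(reverse_perm)

    x = np.random.rand(3, 4).astype('float32')
    index = np.array([0, 2, 2])                   # duplicate indices accumulate
    out_grad = np.ones((3, 3), dtype='float32')   # gradient of gather(x, index, axis=1)
    print(gather_grad_reference(x, index, out_grad, axis=1))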
diff --git a/test/legacy_test/test_gather_op.py b/test/legacy_test/test_gather_op.py
index 3ebb2de7b8560..f37af3a62ddb9 100644
--- a/test/legacy_test/test_gather_op.py
+++ b/test/legacy_test/test_gather_op.py
@@ -45,7 +45,9 @@ def test_check_output(self):
         self.check_output(check_pir=True)

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
+        self.check_grad(
+            ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True
+        )

     def config(self):
         """
@@ -119,7 +121,12 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad_with_place(
-            paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True
+            paddle.CUDAPlace(0),
+            ['X'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


diff --git a/test/legacy_test/test_roll_op.py b/test/legacy_test/test_roll_op.py
index 5512e248acbb1..e6057705e4987 100644
--- a/test/legacy_test/test_roll_op.py
+++ b/test/legacy_test/test_roll_op.py
@@ -52,7 +52,9 @@ def test_check_output(self):
         self.check_output(check_prim=True, check_pir=True)

     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
+        self.check_grad(
+            ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True
+        )


 class TestRollOpCase2(TestRollOp):
@@ -139,7 +141,12 @@ def test_check_output(self):

     def test_check_grad_normal(self):
         self.check_grad_with_place(
-            self.place, ['X'], 'Out', check_prim=True, check_pir=True
+            self.place,
+            ['X'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -163,7 +170,12 @@ def test_check_output(self):

     def test_check_grad_normal(self):
         self.check_grad_with_place(
-            self.place, ['X'], 'Out', check_prim=True, check_pir=True
+            self.place,
+            ['X'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


diff --git a/test/legacy_test/test_scatter_nd_op.py b/test/legacy_test/test_scatter_nd_op.py
index e9e541e09af67..6290c0b485c4f 100644
--- a/test/legacy_test/test_scatter_nd_op.py
+++ b/test/legacy_test/test_scatter_nd_op.py
@@ -98,7 +98,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
+            ['X', 'Updates'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -133,7 +137,12 @@ def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+                check_pir=True,
+                check_prim_pir=True,
             )


@@ -176,7 +185,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
+            ['X', 'Updates'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -211,7 +224,12 @@ def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
-                place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+                check_pir=True,
+                check_prim_pir=True,
             )

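For scatter_nd_add, the backward rule added in details.h above is a pass-through for x (dX = dOut) and a gather_nd for the updates (dUpdates = dO[Ids]). A small NumPy sketch of the updates path (illustrative only, not part of the patch; scatter_nd_add_updates_grad_reference is a hypothetical helper, and plain fancy indexing stands in for the gather_nd primitive):

    import numpy as np

    def scatter_nd_add_updates_grad_reference(out_grad, index):
        # dUpdates = dO[Ids]: pick the slices of out_grad addressed by each index row
        return out_grad[tuple(index[..., d] for d in range(index.shape[-1]))]

    out_grad = np.arange(12, dtype='float32').reshape(3, 4)
    index = np.array([[1], [0], [1]])   # three updates targeting rows 1, 0, 1
    print(scatter_nd_add_updates_grad_reference(out_grad, index))
    # dX needs no computation: it is out_grad itself, hence the by_pass in the C++ rule.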
diff --git a/test/legacy_test/test_scatter_op.py b/test/legacy_test/test_scatter_op.py
index d44982c6321d0..61b6b0b45f308 100644
--- a/test/legacy_test/test_scatter_op.py
+++ b/test/legacy_test/test_scatter_op.py
@@ -57,7 +57,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True, check_pir=True
+            ["X", "Updates"],
+            "Out",
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -92,6 +96,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -128,7 +133,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True, check_pir=True
+            ["X", "Updates"],
+            "Out",
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -163,6 +172,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -202,7 +212,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True, check_pir=True
+            ["X", "Updates"],
+            "Out",
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -237,6 +251,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -284,6 +299,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -356,6 +372,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -412,7 +429,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
+            ['X', 'Updates'],
+            'Out',
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -447,6 +468,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -494,6 +516,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )


@@ -550,7 +573,11 @@ def test_check_output(self):

     def test_check_grad(self):
         self.check_grad(
-            ["X", "Updates"], "Out", check_prim=True, check_pir=True
+            ["X", "Updates"],
+            "Out",
+            check_prim=True,
+            check_pir=True,
+            check_prim_pir=True,
         )


@@ -585,6 +612,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             check_pir=True,
+            check_prim_pir=True,
         )
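The remaining two rules in details.h are equally direct: roll_grad rolls out_grad by the negated shifts, and scatter_grad zeroes the scattered rows of out_grad for dX while gathering those rows for dUpdates. NumPy sketches of both (illustrative only, not part of the patch; the reference helpers below are assumptions, not Paddle APIs, and axis-0 row indexing stands in for the gather/scatter primitives):

    import numpy as np

    def roll_grad_reference(out_grad, shifts, axes):
        # backward of roll: roll the incoming gradient by the negated shifts
        return np.roll(out_grad, tuple(-s for s in shifts), axis=tuple(axes))

    def scatter_grad_reference(out_grad, index):
        x_grad = out_grad.copy()
        x_grad[index] = 0.0               # dX: rows written by updates receive no gradient through x
        updates_grad = out_grad[index]    # dUpdates: gather the corresponding rows of out_grad
        return x_grad, updates_grad

    out_grad = np.arange(8, dtype='float32').reshape(4, 2)
    print(roll_grad_reference(out_grad, shifts=[1], axes=[0]))
    print(scatter_grad_reference(out_grad, index=np.array([0, 2])))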