From 02263e4838ec419326f43612915c3004355eefcf Mon Sep 17 00:00:00 2001
From: WorgenZhang
Date: Wed, 11 Aug 2021 14:24:50 +0800
Subject: [PATCH] [NPU] Support npu op flatten_contiguous_range_grad

---
 paddle/fluid/operators/flatten_op_npu.cc     | 33 +++++++++++++++++++
 .../test_flatten_contiguous_range_op_npu.py  | 18 ++++++++--
 2 files changed, 48 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py

diff --git a/paddle/fluid/operators/flatten_op_npu.cc b/paddle/fluid/operators/flatten_op_npu.cc
index 1569760fe3b96..9252716f3acfc 100644
--- a/paddle/fluid/operators/flatten_op_npu.cc
+++ b/paddle/fluid/operators/flatten_op_npu.cc
@@ -78,6 +78,25 @@ class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopy(
+        *d_out, ctx.GetPlace(),
+        ctx.template device_context<DeviceContext>(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -110,3 +129,17 @@ REGISTER_OP_NPU_KERNEL(
                                          int8_t>,
     ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                          int64_t>);
+REGISTER_OP_NPU_KERNEL(
+    flatten_contiguous_range_grad,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             float>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             double>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             uint8_t>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             int>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             int8_t>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             int64_t>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
old mode 100644
new mode 100755
index 88e711dcf068e..742d156c7f5f1
--- a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
@@ -49,7 +49,7 @@ def test_check_output(self):
         self.check_output_with_place(self.place, no_check_set=["XShape"])
 
     def test_check_grad(self):
-        pass
+        self.check_grad_with_place(self.place, ["X"], "Out")
 
     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
@@ -163,13 +163,13 @@ def init_attrs(self):
         }
 
 
-class TestFlattenOp_int(TestFlattenOp):
+class TestFlattenOp_int32(TestFlattenOp):
     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
         self.start_axis = 0
         self.stop_axis = 1
         self.new_shape = (6, 5, 4)
-        self.dtype = np.int
+        self.dtype = np.int32
 
     def init_attrs(self):
         self.attrs = {
@@ -177,6 +177,9 @@ def init_attrs(self):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlattenOp_uint8(TestFlattenOp):
     def init_test_case(self):
@@ -192,6 +195,9 @@ def init_attrs(self):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlattenOp_int8(TestFlattenOp):
     def init_test_case(self):
@@ -207,6 +213,9 @@ def init_attrs(self):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlattenOp_int64(TestFlattenOp):
     def init_test_case(self):
@@ -222,6 +231,9 @@ def init_attrs(self):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlatten2OpError(unittest.TestCase):
     def test_errors(self):
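
Usage note (not part of the patch): a minimal dygraph sketch of the path this kernel serves. paddle.flatten lowers to flatten_contiguous_range, and the backward pass dispatches the new flatten_contiguous_range_grad NPU kernel, which copies Out@GRAD and resizes it to the input shape recorded in XShape. This assumes a PaddlePaddle build with NPU support so that the "npu:0" device string is available.

    # Minimal sketch, assuming an NPU-enabled PaddlePaddle build.
    import paddle

    paddle.set_device("npu:0")  # assumption: NPU device available

    x = paddle.rand([3, 2, 5, 4])
    x.stop_gradient = False

    # Collapsing axes 0..1 turns shape [3, 2, 5, 4] into [6, 5, 4].
    y = paddle.flatten(x, start_axis=0, stop_axis=1)
    y.sum().backward()  # dispatches flatten_contiguous_range_grad

    # The grad kernel is effectively a reshape: the gradient keeps
    # d_out's values and recovers the original input shape.
    assert list(x.grad.shape) == list(x.shape)

Because flatten only reorganizes metadata, the gradient needs no arithmetic on the NPU: a TensorCopy of d_out followed by a Resize to the XShape-derived dims is sufficient, which is exactly what the kernel above does.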