[NPU] Support npu op flatten_contiguous_range_grad
WorgenZhang committed Aug 11, 2021
1 parent 6a9fac1 commit 02263e4
Showing 2 changed files with 48 additions and 3 deletions.
33 changes: 33 additions & 0 deletions paddle/fluid/operators/flatten_op_npu.cc
@@ -78,6 +78,25 @@ class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class FlattenContiguousRangeGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto *d_out =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));

    // XShape's dims are the input's dims with a leading 0 prepended;
    // slice off that sentinel to recover the original input shape.
    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());

    // Flatten is a pure reshape, so the gradient is the incoming gradient
    // copied as-is and then resized back to the input's shape.
    d_x->mutable_data(ctx.GetPlace(), d_out->type());
    framework::TensorCopy(
        *d_out, ctx.GetPlace(),
        ctx.template device_context<paddle::platform::NPUDeviceContext>(), d_x);
    d_x->Resize(x_dims);
  }
};

} // namespace operators
} // namespace paddle
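For intuition about the grad kernel above: XShape is the auxiliary output saved by the forward op, whose dims are the input's dims with a sentinel 0 prepended, so the backward pass only needs to copy d_out and resize it. A rough NumPy analogue (illustrative only; these names are not Paddle APIs):

import numpy as np

# XShape dims as saved by the forward op for an input of shape (3, 2, 5, 4):
xshape_dims = (0, 3, 2, 5, 4)
x_dims = xshape_dims[1:]  # slice off the sentinel 0 -> (3, 2, 5, 4)

# d_out arrives with axes 0..1 flattened; the gradient w.r.t. X is the
# same data reshaped back to the original dims -- a copy plus a Resize.
d_out = np.random.rand(6, 5, 4).astype("float32")
d_x = d_out.reshape(x_dims)
assert d_x.shape == (3, 2, 5, 4)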

@@ -110,3 +129,17 @@ REGISTER_OP_NPU_KERNEL(
                                     int8_t>,
    ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                         int64_t>);
REGISTER_OP_NPU_KERNEL(
    flatten_contiguous_range_grad,
    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
                                             float>,
    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
                                             double>,
    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
                                             uint8_t>,
    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
                                             int>,
    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
                                             int8_t>,
    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
                                             int64_t>);
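With both kernels registered, backward through paddle.flatten can run end to end on an NPU place. A minimal sketch, assuming a PaddlePaddle build with Ascend NPU support (the "npu:0" device string is an assumption of such a build):

import paddle

paddle.set_device("npu:0")  # assumption: an NPU-enabled Paddle build

x = paddle.rand([3, 2, 5, 4], dtype="float32")
x.stop_gradient = False

# Forward runs flatten_contiguous_range: axes 0..1 collapse to shape [6, 5, 4].
out = paddle.flatten(x, start_axis=0, stop_axis=1)

# Backward dispatches to the newly registered flatten_contiguous_range_grad,
# which copies the gradient and resizes it to x's original shape.
out.sum().backward()
print(x.grad.shape)  # [3, 2, 5, 4]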
18 changes: 15 additions & 3 deletions python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
File mode changed: 100644 → 100755
@@ -49,7 +49,7 @@ def test_check_output(self):
        self.check_output_with_place(self.place, no_check_set=["XShape"])

    def test_check_grad(self):
-        pass
+        self.check_grad_with_place(self.place, ["X"], "Out")

    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
@@ -163,20 +163,23 @@ def init_attrs(self):
        }


-class TestFlattenOp_int(TestFlattenOp):
+class TestFlattenOp_int32(TestFlattenOp):
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 0
        self.stop_axis = 1
        self.new_shape = (6, 5, 4)
-        self.dtype = np.int
+        self.dtype = np.int32

    def init_attrs(self):
        self.attrs = {
            "start_axis": self.start_axis,
            "stop_axis": self.stop_axis
        }

    def test_check_grad(self):
        pass

class TestFlattenOp_uint8(TestFlattenOp):
    def init_test_case(self):
@@ -192,6 +195,9 @@ def init_attrs(self):
            "stop_axis": self.stop_axis
        }

    def test_check_grad(self):
        pass


class TestFlattenOp_int8(TestFlattenOp):
    def init_test_case(self):
@@ -207,6 +213,9 @@ def init_attrs(self):
            "stop_axis": self.stop_axis
        }

    def test_check_grad(self):
        pass


class TestFlattenOp_int64(TestFlattenOp):
    def init_test_case(self):
@@ -222,6 +231,9 @@ def init_attrs(self):
            "stop_axis": self.stop_axis
        }

    def test_check_grad(self):
        pass


class TestFlatten2OpError(unittest.TestCase):
    def test_errors(self):
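The enabled test_check_grad is what exercises the new kernel: check_grad_with_place compares the analytic gradient from flatten_contiguous_range_grad against a numeric estimate, which is also why the integer-dtype subclasses above override it with pass (finite differences are only meaningful for floating-point inputs). A simplified sketch of the idea, not OpTest's actual implementation:

import numpy as np

def numeric_grad(f, x, eps=1e-5):
    # Central-difference estimate of d f(x) / d x, one element at a time.
    grad = np.zeros_like(x)
    flat_x, flat_g = x.reshape(-1), grad.reshape(-1)
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + eps
        hi = f(x)
        flat_x[i] = orig - eps
        lo = f(x)
        flat_x[i] = orig
        flat_g[i] = (hi - lo) / (2 * eps)
    return grad

# Flatten is a pure reshape, so the gradient of sum(flatten(x)) is all ones --
# matching the copy-plus-Resize grad kernel. float64 keeps the estimate stable.
x = np.random.rand(3, 2, 5, 4)
g = numeric_grad(lambda t: t.reshape(6, 5, 4).sum(), x)
assert np.allclose(g, np.ones_like(x))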

1 comment on commit 02263e4

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can ask the reviewer(s) to approve and merge. 🎉
