[Feature] Add Ascend support for RoIPoolGrad op (open-mmlab#2569)
* add roipoolgrad adapter

* amend
xinlianglalala authored and CokeDong committed Apr 6, 2023
1 parent b46d838 commit b0f10a3
Showing 2 changed files with 39 additions and 17 deletions.
31 changes: 30 additions & 1 deletion mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
@@ -11,7 +11,6 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
   int64_t pooled_channel = 1;
   at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
       {}, rois.options().dtype(at::kInt), rois);
-
   OpCommand cmd;
   cmd.Name("RoiPoolingWithArgMax")
       .Input(input)
@@ -27,8 +26,38 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
       .Run();
 }
 
+void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
+                           Tensor grad_input, int pooled_height,
+                           int pooled_width, float spatial_scale) {
+  int64_t pooled_height_64 = pooled_height;
+  int64_t pooled_width_64 = pooled_width;
+  int64_t pooled_channel = 1;
+  at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
+      {}, rois.options().dtype(at::kInt), rois);
+  at::Tensor x = at::ones_like(grad_input);
+  OpCommand cmd;
+  cmd.Name("RoiPoolingGradWithArgMax")
+      .Input(grad_output)
+      .Input(x)
+      .Input(rois)
+      .Input(roi_actual_num)
+      .Input(argmax)
+      .Output(grad_input)
+      .Attr("pooled_h", pooled_height_64)
+      .Attr("pooled_w", pooled_width_64)
+      .Attr("spatial_scale_h", spatial_scale)
+      .Attr("spatial_scale_w", spatial_scale)
+      .Attr("pool_channel", pooled_channel)
+      .Run();
+}
+
 void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax, int pooled_height, int pooled_width,
                            float spatial_scale);
 
+void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
+                            Tensor grad_input, int pooled_height,
+                            int pooled_width, float spatial_scale);
+
 REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu);
+REGISTER_NPU_IMPL(roi_pool_backward_impl, roi_pool_backward_npu);
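With the backward adapter registered above, roi_pool on Ascend NPU devices supports autograd end to end, which is what lets the test below drop its NPU-only forward-pass branch. A minimal usage sketch follows, assuming an Ascend environment with torch_npu and mmcv installed; the shapes and RoI values are illustrative, not taken from this commit:

    import torch
    import torch_npu  # noqa: F401  # registers the 'npu' device type with PyTorch
    from mmcv.ops import roi_pool

    # Illustrative input: one 16-channel 8x8 feature map and a single RoI
    # given as (batch_index, x1, y1, x2, y2).
    x = torch.rand(1, 16, 8, 8, device='npu', requires_grad=True)
    rois = torch.tensor([[0., 0., 0., 4., 4.]], device='npu')

    output = roi_pool(x, rois, (2, 2), spatial_scale=1.0)
    output.backward(torch.ones_like(output))  # dispatches to roi_pool_backward_npu
    print(x.grad.shape)  # gradients now reach the input feature map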
25 changes: 9 additions & 16 deletions tests/test_ops/test_roi_pool.py
@@ -69,20 +69,13 @@ def _test_roipool_allclose(self, device, dtype=torch.float):
         np_output = np.array(output[0])
         np_grad = np.array(output[1])
 
-        if device == 'npu':
-            import torch_npu  # noqa: F401
-            x = torch.tensor(np_input, dtype=dtype).npu()
-            rois = torch.tensor(np_rois, dtype=dtype).npu()
-            output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
-            assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
-        else:
-            x = torch.tensor(
-                np_input, dtype=dtype, device=device, requires_grad=True)
-            rois = torch.tensor(np_rois, dtype=dtype, device=device)
-            output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
-            output.backward(torch.ones_like(output))
-            assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
-            assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
+        x = torch.tensor(
+            np_input, dtype=dtype, device=device, requires_grad=True)
+        rois = torch.tensor(np_rois, dtype=dtype, device=device)
+        output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
+        output.backward(torch.ones_like(output))
+        assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
+        assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
 
     @pytest.mark.parametrize('device', [
         pytest.param(
@@ -103,8 +96,8 @@ def _test_roipool_allclose(self, device, dtype=torch.float):
         pytest.param(
             torch.double,
             marks=pytest.mark.skipif(
-                IS_MLU_AVAILABLE,
-                reason='MLU does not support for 64-bit floating point')),
+                IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
+                reason='MLU, NPU does not support for 64-bit floating point')),
         torch.half
     ])
     def test_roipool_allclose(self, device, dtype):