26 changes: 26 additions & 0 deletions test/optests_failures_dict.json
@@ -2,6 +2,32 @@
"_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit",
"_version": 1,
"data": {
"torchvision::ps_roi_align": {
"TestPSRoIAlign.test_aot_dispatch_dynamic__test_mps_error_inputs": {
"comment": "RuntimeError: MPS does not support ps_roi_align backward with float16 inputs",
"status": "xfail"
},
"TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0]": {
"comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}",
"status": "xfail"
Member Author commented:

CC @zou3519

I just opened pytorch/pytorch#111797, which I believe could be a fix for the problem I'm facing here:

This test (and a bunch of others)

vision/test/test_ops.py, lines 186 to 189 (at commit 3fb88b3):

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous, deterministic=False):

is parametrized over CPU, CUDA, and MPS, and fails on MPS.

I'm trying to xfail the MPS parametrization as above (lines 10-12), but optests complains with:

E RuntimeError: In failures dict, got test name 'TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0]'. We parsed this as running test 'test_autograd_registration' on 'test_backward[True-mps-0]', but test_backward[True-mps-0] does not exist on the TestCase 'TestPSRoIAlign]. Maybe you need to change the test name?

The problem is, if I replace TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0] with TestPSRoIAlign.test_autograd_registration__test_backward in line 10, then I get an "unexpected success":

E torch.testing._internal.optests.generate_tests.OpCheckError: generate_opcheck_tests: Unexpected success for operator torchvision::ps_roi_align on test TestPSRoIAlign.test_autograd_registration__test_backward. This may mean that you have fixed this test failure. Please rerun the test with PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner or manually remove the expected failure in the failure dict at /home/nicolashug/dev/vision/test/optests_failures_dict.jsonFor more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit
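
A minimal sketch of the key parsing that the RuntimeError above implies (an assumption inferred from the error message, not taken from the optests source):

# Hedged sketch: how the failures-dict key appears to be split, judging from
# the RuntimeError quoted above.
name = "TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0]"
classname, _, rest = name.partition(".")
generated_test, _, original_test = rest.partition("__")
print(classname, generated_test, original_test)
# -> TestPSRoIAlign test_autograd_registration test_backward[True-mps-0]
# The last piece, parametrization id included, must name a test that exists on
# the TestCase, and a pytest-parametrized id like test_backward[True-mps-0] is
# not an attribute of the class -- hence the lookup failure.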

Contributor commented:

@NicolasHug can the tests be run under unittest, or are they pytest only?

Member Author replied:

It's pytest only

},
"TestPSRoIAlign.test_autograd_registration__test_mps_error_inputs": {
"comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}",
"status": "xfail"
},
"TestPSRoIAlign.test_faketensor__test_backward[True-mps-0]": {
"comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!",
"status": "xfail"
},
"TestPSRoIAlign.test_faketensor__test_forward[x_dtype0-True-mps]": {
"comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!",
"status": "xfail"
},
"TestPSRoIAlign.test_faketensor__test_mps_error_inputs": {
"comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!",
"status": "xfail"
}
},
"torchvision::roi_align": {
"TestRoIAlign.test_aot_dispatch_dynamic__test_mps_error_inputs": {
"comment": "RuntimeError: MPS does not support roi_align backward with float16 inputs",
…
14 changes: 13 additions & 1 deletion test/test_ops.py
@@ -117,8 +117,9 @@ class RoIOpTester(ABC):
torch.float32,
torch.float64,
),
ids=str,
# ids=str,
)
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
if device == "mps" and x_dtype is torch.float64:
pytest.skip("MPS does not support float64")
@@ -186,6 +187,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.opcheck_only_one()
def test_backward(self, seed, device, contiguous, deterministic=False):
atol = self.mps_backward_atol if device == "mps" else 1e-05
dtype = self.mps_dtype if device == "mps" else self.dtype
@@ -228,6 +230,7 @@ def func(z):
@needs_cuda
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
@pytest.mark.opcheck_only_one()
def test_autocast(self, x_dtype, rois_dtype):
with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
@@ -659,6 +662,15 @@ def test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_align)


optests.generate_opcheck_tests(
testcase=TestPSRoIAlign,
namespaces=["torchvision"],
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
additional_decorators=[],
test_utils=OPTESTS,
)
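
For reference, a sketch of how the generated variants are typically invoked; the command lines below are assumptions, with only the PYTORCH_OPCHECK_ACCEPT=1 workflow confirmed by the error message quoted in the review thread above:

# Run just the generated faketensor variants for this TestCase:
#   pytest test/test_ops.py -k "TestPSRoIAlign and test_faketensor"
# Rerun and auto-update optests_failures_dict.json after fixing a failure:
#   PYTORCH_OPCHECK_ACCEPT=1 pytest test/test_ops.py -k TestPSRoIAlign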


class TestMultiScaleRoIAlign:
def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
if fmap_names is None:
…
47 changes: 46 additions & 1 deletion torchvision/_meta_registrations.py
@@ -33,7 +33,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
),
)
num_rois = rois.size(0)
_, channels, height, width = input.size()
channels = input.size(1)
return input.new_empty((num_rois, channels, pooled_height, pooled_width))


@@ -51,6 +51,51 @@ def meta_roi_align_backward(
return grad.new_empty((batch_size, channels, height, width))


@register_meta("ps_roi_align")
def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
channels = input.size(1)
torch._check(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width",
)

num_rois = rois.size(0)
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")


@register_meta("_ps_roi_align_backward")
def meta_ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width,
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))


@torch._custom_ops.impl_abstract("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
…
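
As a sanity check, here is a minimal sketch of exercising the new ps_roi_align meta kernel under FakeTensorMode (assuming that importing torchvision registers the op, and choosing shapes so that channels is divisible by pooled_height * pooled_width):

import torch
import torchvision  # noqa: F401  -- registers torchvision::ps_roi_align
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 8, 16, 16)  # channels = 8
    rois = torch.empty(4, 5)       # K = 4 rois, same dtype as x
    out, mapping = torch.ops.torchvision.ps_roi_align(x, rois, 1.0, 2, 2, 2)
    # out has shape (4, 8 // (2 * 2), 2, 2) == (4, 2, 2, 2); mapping is int32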
54 changes: 27 additions & 27 deletions torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
@@ -16,16 +16,16 @@ class PSROIAlignFunction
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
ctx->saved_data["input_shape"] = input.sym_sizes();
at::AutoDispatchBelowADInplaceOrView g;
auto result = ps_roi_align(
auto result = ps_roi_align_symint(
input,
rois,
spatial_scale,
@@ -48,19 +48,19 @@
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_ps_roi_align_backward(
auto input_shape = ctx->saved_data["input_shape"].toList();
auto grad_in = detail::_ps_roi_align_backward_symint(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
ctx->saved_data["pooled_height"].toSymInt(),
ctx->saved_data["pooled_width"].toSymInt(),
ctx->saved_data["sampling_ratio"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
input_shape[0].get().toSymInt(),
input_shape[1].get().toSymInt(),
input_shape[2].get().toSymInt(),
input_shape[3].get().toSymInt());

return {
grad_in,
@@ -82,15 +82,15 @@ class PSROIAlignBackwardFunction
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
at::AutoDispatchBelowADInplaceOrView g;
auto grad_in = detail::_ps_roi_align_backward(
auto grad_in = detail::_ps_roi_align_backward_symint(
grad,
rois,
channel_mapping,
@@ -117,8 +117,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -131,13 +131,13 @@ at::Tensor ps_roi_align_backward_autograd(
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
…
49 changes: 47 additions & 2 deletions torchvision/csrc/ops/ps_roi_align.cpp
@@ -22,6 +22,21 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::ps_roi_align", "")
.typed<decltype(ps_roi_align_symint)>();
return op.call(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

namespace detail {

at::Tensor _ps_roi_align_backward(
@@ -54,13 +69,43 @@ at::Tensor _ps_roi_align_backward(
width);
}

at::Tensor _ps_roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
.typed<decltype(_ps_roi_align_backward_symint)>();
return op.call(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
"torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"));
"torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
}

} // namespace ops
…
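
The switch from int to SymInt in the schema above is what allows the generated test_aot_dispatch_dynamic variants to trace the op without specializing on sizes. A minimal sketch, assuming a recent PyTorch with torch.compile and a torchvision build containing these kernels:

import torch
import torchvision  # noqa: F401

def f(x, rois):
    return torch.ops.torchvision.ps_roi_align(x, rois, 1.0, 2, 2, 2)

compiled = torch.compile(f, dynamic=True)
x = torch.rand(2, 8, 16, 16)                      # channels = 8
rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]])  # (batch_idx, x1, y1, x2, y2)
out, mapping = compiled(x, rois)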
21 changes: 21 additions & 0 deletions torchvision/csrc/ops/ps_roi_align.h
@@ -14,6 +14,14 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align(
int64_t pooled_width,
int64_t sampling_ratio);

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio);

namespace detail {

at::Tensor _ps_roi_align_backward(
@@ -29,6 +37,19 @@ at::Tensor _ps_roi_align_backward(
int64_t height,
int64_t width);

at::Tensor _ps_roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width);

} // namespace detail

} // namespace ops
…