Merged
Changes from 23 commits
Commits (26)
c4fc01b
adopt `torch.testing.assert_close` in test suite
pmeier May 20, 2021
bfbe19b
revert some changes
pmeier May 20, 2021
09f86f4
add todo
pmeier May 20, 2021
86402f0
flake8
pmeier May 20, 2021
48d32e6
Hopefully fixed test_functional_tensor
NicolasHug May 20, 2021
15b50e3
hopefully fixed test_ops
NicolasHug May 20, 2021
a54880f
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug May 20, 2021
61874ac
Fix test_utils
NicolasHug May 20, 2021
30f20a3
revert unwanted changes to test_image
NicolasHug May 20, 2021
3a29ae3
maybe fixed test_transforms
NicolasHug May 20, 2021
d6d73d0
Merge branch 'master' into assert-close
NicolasHug May 20, 2021
863f144
fix test_datasets_video_utils
pmeier May 21, 2021
c8a5afa
fix test_transforms
pmeier May 21, 2021
e697e88
Merge branch 'master' into assert-close
pmeier May 21, 2021
93614f0
flake8
pmeier May 21, 2021
11caf01
Merge branch 'master' of github.com:pytorch/vision into assert-close
NicolasHug May 21, 2021
d7fde8c
Merge branch 'assert-close' of github.com:pmeier/vision into assert-c…
NicolasHug May 21, 2021
0b237c7
use cu102 see if the nightlies are actual nightlies?
NicolasHug May 21, 2021
c2ace86
obviously forgot to call regenerate.py
NicolasHug May 21, 2021
d78226a
not as obvious, reverting
NicolasHug May 21, 2021
bb543a7
Merge branch 'master' into assert-close
NicolasHug May 21, 2021
7507a0c
Merge branch 'master' into assert-close
NicolasHug May 21, 2021
1b4964a
revert everything but ops
NicolasHug May 21, 2021
b6424b3
remove comment and put back shape equality assertions
NicolasHug May 21, 2021
172148c
Merge branch 'master' into assert_close_ops
NicolasHug May 21, 2021
a7402c5
Merge branch 'master' into assert_close_ops
NicolasHug May 21, 2021
test/test_ops.py (83 changes: 32 additions & 51 deletions)
@@ -1,4 +1,5 @@
from common_utils import needs_cuda, cpu_only
from _assert_utils import assert_equal
import math
import unittest
import pytest
@@ -78,7 +79,8 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwa
sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs)

tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
# self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, ))

def _test_backward(self, device, contiguous):
pool_size = 2
@@ -363,7 +365,7 @@ def make_rois(num_rois=1000):

abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
t_scale = torch.full_like(abs_diff, fill_value=scale)
self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))
torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)

x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
@@ -555,7 +557,7 @@ def test_nms_cuda_float16(self):
iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
assert torch.all(torch.eq(keep32, keep16))
assert_equal(keep32, keep16)

@cpu_only
def test_batched_nms_implementations(self):
@@ -573,12 +575,13 @@ def test_batched_nms_implementations(self):
keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

err_msg = "The vanilla and the trick implementation yield different nms outputs."
assert torch.allclose(keep_vanilla, keep_trick), err_msg
torch.testing.assert_close(
keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
)

# Also make sure an empty tensor is returned if boxes is empty
empty = torch.empty((0,), dtype=torch.int64)
assert torch.allclose(empty, ops.batched_nms(empty, None, None, None))
torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))


class DeformConvTester(OpTester, unittest.TestCase):
@@ -690,15 +693,17 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
bias = layer.bias.data
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
)

# no modulation test
res = layer(x, offset)
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected)
)

# test for wrong sizes
with self.assertRaises(RuntimeError):
@@ -778,7 +783,7 @@ def test_compare_cpu_cuda_grads(self):
else:
self.assertTrue(init_weight.grad is not None)
res_grads = init_weight.grad.to("cpu")
self.assertTrue(true_cpu_grads.allclose(res_grads))
torch.testing.assert_close(true_cpu_grads, res_grads)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_autocast(self):
@@ -812,14 +817,14 @@ def test_frozenbatchnorm2d_eps(self):
bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
bn.load_state_dict(state_dict)
# Difference is expected to fall in an acceptable range
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)

# Check computation for eps > 0
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
fbn.load_state_dict(state_dict, strict=False)
bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
bn.load_state_dict(state_dict)
self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6))
torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)

def test_frozenbatchnorm2d_n_arg(self):
"""Ensure a warning is thrown when passing `n` kwarg
@@ -860,20 +865,11 @@ def test_bbox_same(self):
exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

box_same = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy")
self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
Review comment (Member, Author):

I think that we shouldn't remove self.assertEqual(exp_xyxy.size(), torch.Size([4, 4])), as it's not the same as comparing exp_xyxy.size() vs box_tensor.size(), so I'll put these back.
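A minimal sketch of that distinction (illustrative only, reusing the tensors from test_bbox_same; not part of this diff): the explicit size check pins the expected shape to a known constant, which a plain equality assertion between the two tensors does not do on its own.

```python
import torch
from torchvision import ops

# Values taken from test_bbox_same in this diff.
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                           [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                         [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

box_same = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy")

# Pins the expected shape independently of the converted output ...
assert exp_xyxy.size() == torch.Size([4, 4])
# ... whereas this only verifies that the two tensors agree with each other:
torch.testing.assert_close(box_same, exp_xyxy, rtol=0, atol=0)
```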

self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_same, exp_xyxy)).item()
assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)

box_same = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh")
self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_same, exp_xyxy)).item()
assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)

box_same = ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh")
self.assertEqual(exp_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(exp_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_same, exp_xyxy)).item()
assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)

def test_bbox_xyxy_xywh(self):
# Simple test convert boxes to xywh and back. Make sure they are same.
@@ -884,15 +880,11 @@ def test_bbox_xyxy_xywh(self):
[10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
self.assertEqual(exp_xywh.size(), torch.Size([4, 4]))
self.assertEqual(exp_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, exp_xywh)).item()
assert_equal(box_xywh, exp_xywh)

# Reverse conversion
box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
self.assertEqual(box_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(box_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xyxy, box_tensor)).item()
assert_equal(box_xyxy, box_tensor)

def test_bbox_xyxy_cxcywh(self):
# Simple test convert boxes to xywh and back. Make sure they are same.
@@ -903,15 +895,11 @@ def test_bbox_xyxy_cxcywh(self):
[20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)

box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4]))
self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item()
assert_equal(box_cxcywh, exp_cxcywh)

# Reverse conversion
box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
self.assertEqual(box_xyxy.size(), torch.Size([4, 4]))
self.assertEqual(box_xyxy.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xyxy, box_tensor)).item()
assert_equal(box_xyxy, box_tensor)

def test_bbox_xywh_cxcywh(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
@@ -922,15 +910,11 @@ def test_bbox_xywh_cxcywh(self):
[20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)

box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4]))
self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item()
assert_equal(box_cxcywh, exp_cxcywh)

# Reverse conversion
box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
self.assertEqual(box_xywh.size(), torch.Size([4, 4]))
self.assertEqual(box_xywh.dtype, box_tensor.dtype)
assert torch.all(torch.eq(box_xywh, box_tensor)).item()
assert_equal(box_xywh, box_tensor)

def test_bbox_invalid(self):
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
@@ -951,19 +935,18 @@ def test_bbox_convert_jit(self):

box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)
torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE)

box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)


class BoxAreaTester(unittest.TestCase):
def test_box_area(self):
def area_check(box, expected, tolerance=1e-4):
out = ops.box_area(box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

# Check for int boxes
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
@@ -991,8 +974,7 @@ class BoxIouTester(unittest.TestCase):
def test_iou(self):
def iou_check(box, expected, tolerance=1e-4):
out = ops.box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
@@ -1013,8 +995,7 @@ class GenBoxIouTester(unittest.TestCase):
def test_gen_iou(self):
def gen_iou_check(box, expected, tolerance=1e-4):
out = ops.generalized_box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]: