Remove c/pdist tests from _internal/common_utils.py (pytorch#33409)
Summary:
* remove the brute-force c/pdist test helpers (`brute_pdist`, `pdist_single`, `brute_cdist`) from `torch/testing/_internal/common_utils.py`
* add them to `test_torch.py` as private methods of the test class (`_brute_pdist`, `_pdist_single`, `_brute_cdist`); a minimal sketch of the new layout is shown below
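
For illustration only (not part of the commit; the class `ExampleCdistTest` and its single test are hypothetical), this is roughly what the new layout looks like: the brute-force reference is a private method reached through `self` instead of a function imported from `common_utils`.

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

class ExampleCdistTest(TestCase):
    def _brute_cdist(self, x, y, p=2):
        # all pairwise distances via broadcasting, as in the relocated helper
        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)

    def test_cdist_matches_brute(self):
        x, y = torch.randn(4, 3), torch.randn(5, 3)
        self.assertTrue(torch.allclose(torch.cdist(x, y, p=2),
                                       self._brute_cdist(x, y, p=2),
                                       atol=1e-6))

if __name__ == '__main__':
    run_tests()
```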

CC ailzhang
Pull Request resolved: pytorch#33409

Differential Revision: D19951729

Pulled By: ailzhang

fbshipit-source-id: b1126aaf26fa64a0f17cbb582dc8038b79cfe3eb
ptrblck authored and facebook-github-bot committed Feb 19, 2020
1 parent 60339a3 commit 1e3664b
Showing 2 changed files with 56 additions and 59 deletions.
73 changes: 56 additions & 17 deletions test/test_torch.py
@@ -25,8 +25,8 @@
from torch.testing._internal.common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, TEST_WITH_ROCM, run_tests, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
IS_SANDCASTLE, load_tests, pdist_single, brute_cdist, slowTest, \
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm
IS_SANDCASTLE, load_tests, slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, \
BytesIOContext, skipIfRocm
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
@@ -9686,6 +9686,13 @@ def test_cdist_empty(self, device):
y = torch.randn((0, 0), device=device)
self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))

def _brute_cdist(self, x, y, p=2):
r1 = x.shape[-2]
r2 = y.shape[-2]
if r1 == 0 or r2 == 0:
return torch.empty(r1, r2, device=x.device)
return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)

def test_cdist_norm(self, device):
for r1 in [3, 4, 5, 6]:
for m in [2, 3, 4, 10]:
@@ -9696,11 +9703,11 @@ def test_cdist_norm(self, device):
if p == 2:
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(torch.allclose(expected, actual, rtol=0, atol=0.02))
else:
actual = torch.cdist(x, y, p=p)
expected = brute_cdist(x, y, p=p)
expected = self._brute_cdist(x, y, p=p)
self.assertTrue(torch.allclose(expected, actual))

def test_cdist_norm_batch(self, device):
@@ -9713,51 +9720,51 @@ def test_cdist_norm_batch(self, device):
if p == 2:
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(torch.allclose(expected, actual, rtol=0, atol=0.02))
else:
actual = torch.cdist(x, y, p=p)
expected = brute_cdist(x, y, p=p)
expected = self._brute_cdist(x, y, p=p)
self.assertTrue(torch.allclose(expected, actual))

def test_cdist_large(self, device):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(1000, 10, device=device)
y = torch.randn(1000, 10, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(torch.allclose(expected, actual))

def test_cdist_large_batch(self, device):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 1000, 10, device=device)
y = torch.randn(4, 3, 1000, 10, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(torch.allclose(expected, actual))

def test_cdist_non_contiguous(self, device):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(5, 7, device=device).transpose(-1, -2)
y = torch.randn(5, 3, device=device).transpose(-1, -2)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(7, 5, device=device)
y = torch.randn(5, 3, device=device).t()
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(5, 7, device=device).t()
y = torch.randn(3, 5, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))
@@ -9767,23 +9774,23 @@ def test_cdist_non_contiguous_batch(self, device):
x = torch.randn(4, 3, 2, 5, 7, device=device).transpose(-1, -2)
y = torch.randn(4, 3, 2, 5, 3, device=device).transpose(-1, -2)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(7, 2, 7, 5, device=device)
y = torch.randn(7, 2, 5, 3, device=device).transpose(-1, -2)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(4, 5, 7, device=device).transpose(-1, -2)
y = torch.randn(4, 3, 5, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = brute_cdist(x, y, p=2)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))
@@ -11096,24 +11103,56 @@ def test_nonzero_non_diff(self, device):
nz = x.nonzero()
self.assertFalse(nz.requires_grad)

def _brute_pdist(self, inp, p=2):
"""Computes the same as torch.pdist using primitives"""
n = inp.shape[-2]
k = n * (n - 1) // 2
if k == 0:
# torch complains about empty indices
return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device)
square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
unroll = square.view(square.shape[:-2] + (n * n,))
inds = torch.ones(k, dtype=torch.int)
inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
return unroll[..., inds.cumsum(0)]

def _pdist_single(self, shape, device, p, dtype, trans, grad_check=False):
x = torch.randn(shape, dtype=dtype, device=device)
if trans:
x.transpose_(-2, -1)
if grad_check:
x.requires_grad_()
y = x.detach().clone().requires_grad_()
else:
y = x
actual = torch.pdist(x, p=p)
expected = self._brute_pdist(y, p=p)
self.assertEqual(expected.shape, actual.shape)
self.assertTrue(torch.allclose(expected, actual))
if grad_check and expected.size() != torch.Size([0]):
g0 = torch.rand_like(actual)
actual.backward(g0)
expected.backward(g0)
self.assertTrue(torch.allclose(x.grad, y.grad))

def test_pdist_norm_forward(self, device):
for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]:
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
for trans in [False, True]:
for dtype in [torch.float32, torch.float64]:
pdist_single(self, shape, device, p, dtype, trans, grad_check=False)
self._pdist_single(shape, device, p, dtype, trans, grad_check=False)

# do a simplified comparison with big inputs, see:
# https://github.com/pytorch/pytorch/issues/15511
for dtype in [torch.float32, torch.float64]:
pdist_single(self, (1000, 2), device, 2, dtype, trans=False, grad_check=False)
self._pdist_single((1000, 2), device, 2, dtype, trans=False, grad_check=False)

@skipIfRocm
def test_pdist_norm_backward(self, device):
for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]:
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
for trans in [False, True]:
pdist_single(self, shape, device, p, torch.float64, trans, grad_check=True)
self._pdist_single(shape, device, p, torch.float64, trans, grad_check=True)

@skipIfRocm
def test_pdist_norm_large(self, device):
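
Side note, not part of the diff: `_brute_pdist` mirrors `torch.pdist`, which returns the strict upper triangle of the full pairwise-distance matrix in row-major ("condensed") order; the cumsum-based index trick above selects exactly those entries from the flattened n x n matrix. A hypothetical standalone check of that ordering, using `torch.triu_indices` in place of the index trick:

```python
import torch

x = torch.randn(5, 3)
full = torch.cdist(x, x, p=2)              # full 5 x 5 pairwise-distance matrix
i, j = torch.triu_indices(5, 5, offset=1)  # strict upper triangle, row-major order
condensed = full[i, j]                     # same ordering torch.pdist uses
print(torch.allclose(condensed, torch.pdist(x, p=2), atol=1e-6))  # expected: True
```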
42 changes: 0 additions & 42 deletions torch/testing/_internal/common_utils.py
@@ -1305,48 +1305,6 @@ def random_matrix(rows, columns, *batch_dims, **kwargs):
return u.matmul(s.expand(batch_dims + (rows, columns)).matmul(v.transpose(-2, -1)))


def brute_pdist(inp, p=2):
"""Computes the same as torch.pdist using primitives"""
n = inp.shape[-2]
k = n * (n - 1) // 2
if k == 0:
# torch complains about empty indices
return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device)
square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
unroll = square.view(square.shape[:-2] + (n * n,))
inds = torch.ones(k, dtype=torch.int)
inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
return unroll[..., inds.cumsum(0)]


def pdist_single(self, shape, device, p, dtype, trans, grad_check=False):
x = torch.randn(shape, dtype=dtype, device=device)
if trans:
x.transpose_(-2, -1)
if grad_check:
x.requires_grad_()
y = x.detach().clone().requires_grad_()
else:
y = x
actual = torch.pdist(x, p=p)
expected = brute_pdist(y, p=p)
self.assertEqual(expected.shape, actual.shape)
self.assertTrue(torch.allclose(expected, actual))
if grad_check and expected.size() != torch.Size([0]):
g0 = torch.rand_like(actual)
actual.backward(g0)
expected.backward(g0)
self.assertTrue(torch.allclose(x.grad, y.grad))


def brute_cdist(x, y, p=2):
r1 = x.shape[-2]
r2 = y.shape[-2]
if r1 == 0 or r2 == 0:
return torch.empty(r1, r2, device=x.device)
return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)


def do_test_dtypes(self, dtypes, layout, device):
for dtype in dtypes:
if dtype != torch.float16:
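
Side note, not part of the diff: the `grad_check` branch of the relocated `_pdist_single` runs the real operator on `x` and the brute-force reference on an independent leaf `y` holding the same values, backpropagates the same cotangent through both, and compares the resulting gradients. A hypothetical minimal sketch of that pattern:

```python
import torch

x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
y = x.detach().clone().requires_grad_()          # independent leaf with the same values

actual = torch.pdist(x, p=2)
i, j = torch.triu_indices(4, 4, offset=1)
expected = torch.norm(y[i] - y[j], p=2, dim=-1)  # brute-force condensed distances

g = torch.rand_like(actual)                      # shared cotangent
actual.backward(g)
expected.backward(g)
print(torch.allclose(x.grad, y.grad))            # expected: True
```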
