Use int64 in pdist kernel to handle batches >= 46342 pytorch#30583 (pytorch#31593)

Summary:
Currently `torch.pdist` yields an illegal CUDA memory access for batch sizes >= 46342, as reported by SsnL in pytorch#30583.
Thanks for the minimal code reproduction, btw! ;)
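
For reference, a minimal repro along the lines of that report (the feature dimension here is arbitrary, not taken from the issue):

```python
# Before this fix: any input with dim0 >= 46342 hits an illegal CUDA memory access in the forward kernel.
import torch

x = torch.randn(46342, 16, device='cuda')
d = torch.pdist(x)          # crashes on master for dim0 >= 46342
torch.cuda.synchronize()    # the kernel launch is async, so the error may only surface here
```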

Reason for this bug:
The calculation of `i` in [`pdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112) can overflow if a tensor with a batch size >= 46342 is passed to `torch.pdist`.

Detailed description:
* `result` is resized to `n * (n - 1) / 2 = 1073767311` elements ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/Distance.cpp#L140))
* `grid` is initialized as `result.numel()` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L246))
* `k` is assigned from `blockIdx.x` as an `int32` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L108))
* `i` is calculated using `2 * k` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112)), which approaches `2 * 1073767311 = 2147534622` and overflows, since `2147534622 > 2147483647 (int32_max)` (see the arithmetic sketch right after this list)
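
To make the overflow concrete, here is the arithmetic in plain Python (the kernel does the equivalent in 32-bit integers):

```python
n = 46342
numel = n * (n - 1) // 2        # 1073767311 -> result.numel(), which is also the grid size
INT32_MAX = 2**31 - 1           # 2147483647

two_k = 2 * (numel - 1)         # largest value of 2 * k reached for k = blockIdx.x
print(two_k > INT32_MAX)        # True: 2147534620 no longer fits in an int32

# what a 32-bit signed int ends up holding instead:
print((two_k + 2**31) % 2**32 - 2**31)   # -2147432676 -> a nonsensical row index i
```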

Using `const int64_t k = blockIdx.x;` resolves the illegal memory access. This also seems to be how [`cdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L198-L201) already handles it.

However, we might expect a slowdown from the wider index type, so I've timed the current PyTorch master against this PR
(tested with `x = torch.randn(x.size(0), 128)` on a V100; a sketch of the timing loop follows the table):

| x.size(0) | int32 idx | int64 idx | slowdown |
|-----------|-----------|-----------|----------|
| 50000     | -         | 4.4460    | -        |
| 25000     | 1.02522   | 1.10869   | 7.53%    |
| 12500     | 0.25182   | 0.27277   | 7.68%    |
| 6250      | 0.06291   | 0.06817   | 7.72%    |
| 3125      | 0.01573   | 0.01704   | 7.69%    |
| 1562      | 0.00393   | 0.00426   | 7.75%    |
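
(For completeness, a sketch of the kind of timing loop behind these numbers; the helper name and iteration count are mine, not taken from the PR:)

```python
import time
import torch

def time_pdist(dim0, iters=20):
    # average wall-clock time of torch.pdist on a (dim0, 128) CUDA tensor
    x = torch.randn(dim0, 128, device='cuda')
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        torch.pdist(x)
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

for dim0 in (1562, 3125, 6250, 12500, 25000, 50000):
    print(dim0, time_pdist(dim0))
```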

While checking the backward kernel, it seems I'm triggering another error with a size limit:
```python
x = torch.randn(1449, 1, device='cuda', requires_grad=True)
out = torch.pdist(x)
out.mean().backward()
> RuntimeError: CUDA error: invalid configuration argument
```
while shapes up to `[1448, 1]` work.
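
This looks consistent with the backward launch configuration on master (see the `DistanceKernel.cu` hunk below): `grid_y` is limited to 2^16, while it is computed as `(dist.numel() + block_y - 1) / block_y` with `block_y = 16`. A quick sanity check, assuming that is indeed the cause:

```python
# Back-of-the-envelope check of the backward launch-config limit on master (block_y = 16).
GRID_Y_MAX = 65535                     # CUDA's cap on gridDim.y
block_y = 16

def grid_y(dim0):
    numel = dim0 * (dim0 - 1) // 2     # number of pdist outputs
    return (numel + block_y - 1) // block_y

print(grid_y(1448) <= GRID_Y_MAX)   # True:  65477 fits, backward works
print(grid_y(1449) <= GRID_Y_MAX)   # False: 65568 exceeds the cap -> "invalid configuration argument"
```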

I'll take another look at this issue. Let me know if the potential fix should go into this PR or if I should open a new issue.

CC ngimel, csarofeen
Pull Request resolved: pytorch#31593

Differential Revision: D19825571

Pulled By: ngimel

fbshipit-source-id: ace9ccab49f3cf0ce894cdb6daef0795e2e8ec03
ptrblck authored and facebook-github-bot committed Feb 11, 2020
1 parent 367488b commit a64d0ff
Showing 3 changed files with 53 additions and 25 deletions.
aten/src/ATen/native/cuda/DistanceKernel.cu (9 additions, 10 deletions)

```diff
@@ -105,7 +105,7 @@ __device__ static inline scalar_t reduce_agg(scalar_t agg) {
 template <typename scalar_t, typename F>
 __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
                                               const double n2, const double n2_squared_minus_1) {
-  const int k = blockIdx.x;
+  const int64_t k = blockIdx.x;
   const int stride = blockDim.x;
 
   // The -1 accounts for floating point truncation issues
@@ -162,9 +162,9 @@ __global__ static void cdist_backward_kernel_cuda_impl(scalar_t * buffer, const
 template <typename scalar_t, typename F>
 __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p,
                                                        const double n2, const double n2_squared_minus_1) {
-  const int k = blockIdx.y * blockDim.y + threadIdx.y;
-  const int init = blockIdx.x * blockDim.x + threadIdx.x;
-  const int stride = blockDim.x * gridDim.x;
+  const int64_t k = blockIdx.x * blockDim.x + threadIdx.x;
+  const int init = blockIdx.y * blockDim.y + threadIdx.y;
+  const int stride = blockDim.y * gridDim.y;
 
   if (k >= combs) {
     return;
@@ -276,13 +276,12 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
 
   const int64_t n = result.size(0);
   int64_t m = self.size(1);
-  const int block_x = 64;
+  const int block_x = 16;
   // NB: be careful with changing block_y; as it's currently written, grid_y is limited to be 2^16.
-  // From binary search, block_y of 16 gives us max pdist dim0 of 1449,
-  // block_y of 4 gives us max pdist dim0 of 725.
-  const int block_y = 16;
-  const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
-  const int grid_y = (dist.numel() + block_y - 1) / block_y;
+  // block_y of 64 gives us max pdist dim1 of 2**24
+  const int block_y = 64;
+  const int grid_x = (dist.numel() + block_x - 1) / block_x;
+  const int grid_y = (m + block_y * 8 - 1) / (block_y * 8);
   const dim3 grid(grid_x, grid_y);
   const dim3 block(block_x, block_y);
   // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
```
test/test_torch.py (24 additions, 15 deletions)

```diff
@@ -25,8 +25,8 @@
 from torch.testing._internal.common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, TEST_WITH_ROCM, run_tests, skipIfNoLapack, suppress_warnings, \
     IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
-    skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext
+    IS_SANDCASTLE, load_tests, pdist_single, brute_cdist, slowTest, \
+    skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
@@ -10960,26 +10960,35 @@ def test_nonzero_non_diff(self, device):
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)
 
-    def test_pdist_norm(self, device):
-        def test_pdist_single(shape, device, p, dtype, trans):
-            x = torch.randn(shape, dtype=dtype, device=device)
-            if trans:
-                x.transpose_(-2, -1)
-            actual = torch.pdist(x, p=p)
-            expected = brute_pdist(x, p=p)
-            self.assertEqual(expected.shape, actual.shape)
-            self.assertTrue(torch.allclose(expected, actual))
-
-        for shape in [(4, 5), (3, 2), (2, 1)]:
+    def test_pdist_norm_forward(self, device):
+        for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]:
             for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                 for trans in [False, True]:
                     for dtype in [torch.float32, torch.float64]:
-                        test_pdist_single(shape, device, p, dtype, trans)
+                        pdist_single(self, shape, device, p, dtype, trans, grad_check=False)
 
         # do a simplified comparison with big inputs, see:
         # https://github.com/pytorch/pytorch/issues/15511
         for dtype in [torch.float32, torch.float64]:
-            test_pdist_single((1000, 2), device, 2, dtype, False)
+            pdist_single(self, (1000, 2), device, 2, dtype, trans=False, grad_check=False)
 
+    @skipIfRocm
+    def test_pdist_norm_backward(self, device):
+        for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]:
+            for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                for trans in [False, True]:
+                    pdist_single(self, shape, device, p, torch.float64, trans, grad_check=True)
+
+    @skipIfRocm
+    def test_pdist_norm_large(self, device):
+        # use dim0>=46342 for forward, see:
+        # https://github.com/pytorch/pytorch/issues/30583
+        # Compare output using GPU with the CPU implementation, as brute_pdist uses too much memory
+        if 'cuda' in device:
+            x = torch.randn(50000, 1, dtype=torch.float32)
+            expected_cpu = torch.pdist(x, p=2)
+            actual_gpu = torch.pdist(x.to(device), p=2)
+            self.assertTrue(torch.allclose(expected_cpu, actual_gpu.cpu()))
+
     def test_atan2(self, device):
         def _test_atan2_with_size(size, device):
```
torch/testing/_internal/common_utils.py (20 additions, 0 deletions)

```diff
@@ -1319,6 +1319,26 @@ def brute_pdist(inp, p=2):
     return unroll[..., inds.cumsum(0)]
 
 
+def pdist_single(self, shape, device, p, dtype, trans, grad_check=False):
+    x = torch.randn(shape, dtype=dtype, device=device)
+    if trans:
+        x.transpose_(-2, -1)
+    if grad_check:
+        x.requires_grad_()
+        y = x.detach().clone().requires_grad_()
+    else:
+        y = x
+    actual = torch.pdist(x, p=p)
+    expected = brute_pdist(y, p=p)
+    self.assertEqual(expected.shape, actual.shape)
+    self.assertTrue(torch.allclose(expected, actual))
+    if grad_check and expected.size() != torch.Size([0]):
+        g0 = torch.rand_like(actual)
+        actual.backward(g0)
+        expected.backward(g0)
+        self.assertTrue(torch.allclose(x.grad, y.grad))
+
+
 def brute_cdist(x, y, p=2):
     r1 = x.shape[-2]
     r2 = y.shape[-2]
```
