Use int64 in pdist kernel to handle batches >= 46342 #30583 #31593
Conversation
💊 CircleCI build failures summary and remediations (as of commit 1fc79be): one may explore the probable reasons each build failed interactively on the Dr. CI website. 🕵️ 1 new failure was recognized by patterns; it does not appear to be due to upstream breakage.
It seems the backward error comes from a wrong CUDA launch config for large inputs. A similar issue is thrown for `cdist`. I'll open a new issue and try to fix the launch configs there.

EDIT: might also be related to #24345
EDIT2: a PR for the launch-config fix is in progress.
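For background on the "invalid configuration argument" error discussed here: CUDA allows a very large number of blocks in grid.x but only 65535 in grid.y and grid.z, so a launch that puts a count proportional to the pdist output size into the wrong grid dimension fails once the input grows. The snippet below only illustrates that limit; the block shape of 16 is a hypothetical placeholder, not the real kernel's configuration.

```python
# Illustrative check of the CUDA grid.y limit for pdist-sized workloads
# (not the actual kernel launch code; block_y = 16 is a made-up placeholder).
MAX_GRID_YZ = 65535      # maximum blocks in grid.y / grid.z

def pdist_numel(n):
    # number of pairwise distances torch.pdist produces for a batch of size n
    return n * (n - 1) // 2

block_y = 16  # hypothetical block shape
for n in (1448, 1449, 46342):
    numel = pdist_numel(n)
    grid_y = (numel + block_y - 1) // block_y
    print(f"n={n:6d}  numel={numel:12d}  grid.y={grid_y:10d}  "
          f"fits={grid_y <= MAX_GRID_YZ}")
```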
@ptrblck 8% perf regression is acceptable. Please fix backward in this PR also. We can leave cdist separate for now.
And add tests please. I don't understand how just swapping the grid and block sizes would work without corresponding changes in the kernel itself.
I've added tests for the failed use cases and kept the sizes "reasonably" small, but let me know if I should expand the test cases. Benchmark results for the backward pass on a V100-SXM2 32GB (time in ms/iter) were collected with the code below.
Code used for benchmarking: import torch
import torch.nn as nn
import time
nb_iters = 10
sizes = [int(50000/2**i) for i in range(10)]
for size in sizes:
x = torch.randn(size, 1, device='cuda', requires_grad=True)
# warmup
for _ in range(nb_iters):
out = torch.pdist(x)
out.mean().backward()
#print(torch.cuda.memory_allocated()/1024**3)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
out = torch.pdist(x)
out.mean().backward()
torch.cuda.synchronize()
t1 = time.time()
print('size {}, time {:.4f}ms/iter'.format(size, 1000*(t1 - t0)/nb_iters)) |
Sorry for the confusion, but the kernel changes weren't in the commit. 😕
The forward test covers the large size, and I've added the gradient check for the other, smaller shapes as well.
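As a rough illustration of what such a gradient check can look like (not the test code from this PR; the shape is arbitrary and a CUDA device is assumed):

```python
import torch
from torch.autograd import gradcheck

# Hypothetical gradient check for torch.pdist on a small shape; gradcheck needs
# double-precision inputs, and this assumes a CUDA device is available.
x = torch.randn(8, 3, dtype=torch.double, device='cuda', requires_grad=True)
print(gradcheck(torch.pdist, (x,)))  # True if analytic and numerical grads match
```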
@pytorchbot retest this please
At least the ROCm failure is real; I haven't looked at the other ones.
@pytorchbot retest this please
@ptrblck can you please split the pdist test into 2: one testing forward, and another testing backward that would be completely skipped on ROCm using the @skipIfRocm decorator? We are trying to deprecate the TEST_WITH_ROCM approach.
@ngimel Sure, I'll split the tests and skip the backward for ROCm.
This test also checks the GPU result for the large input. Should I skip this test as well for ROCm, or wait for a review?
@ptrblck please skip the forward test also and file an issue for ROCm |
The things I'd like to see in the test (see the sketch after this list):
- no TEST_WITH_ROCM, only decorators
- separate test for the large size you are adding
- disable ROCm tests as needed (as long as you are not breaking anything, and those are the tests that you are adding), but file an issue for ROCm detailing the failures.
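Purely as a sketch of the requested split (not the code that landed in this PR), and assuming the test suite's `skipIfRocm` decorator and `TestCase`/`run_tests` from its common utilities, the structure could look roughly like this; the class and test names, sizes, and import path are placeholders for illustration:

```python
# Rough sketch: separate forward and backward pdist tests, each skipped on ROCm
# via a decorator instead of checking TEST_WITH_ROCM. Sizes mirror the failure
# thresholds reported in this thread (46342 forward, 1449 backward).
import torch
from torch.testing._internal.common_utils import TestCase, skipIfRocm, run_tests


class TestPdistLarge(TestCase):
    @skipIfRocm
    def test_pdist_forward_large(self):
        x = torch.randn(46342, 1, device='cuda')
        out = torch.pdist(x)
        self.assertEqual(out.numel(), 46342 * 46341 // 2)

    @skipIfRocm
    def test_pdist_backward_large(self):
        x = torch.randn(1449, 1, device='cuda', requires_grad=True)
        torch.pdist(x).mean().backward()
        self.assertEqual(x.grad.shape, x.shape)


if __name__ == '__main__':
    run_tests()
```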
I've removed the TEST_WITH_ROCM check in favor of the decorators. I'll create a new issue with the information about the failing ROCm tests.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Currently `torch.pdist` yields an illegal CUDA memory access for batch sizes >= 46342, as reported by SsnL in pytorch#30583. Thanks for the minimal code reproduction, btw! ;)

Reason for this bug: the calculation of `i` in [`pdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112) might overflow if a tensor with a batch size >= 46342 is passed to `torch.pdist`.

Detailed description:
* `result` is resized to `n * (n - 1) / 2 = 1073767311` elements ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/Distance.cpp#L140))
* `grid` is initialized as `result.numel()` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L246))
* `k` is assigned to `blockIdx.x` as an `int32` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L108))
* `i` is calculated using `2 * k >= 2147534622` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112)), which overflows, since `2147534622 > 2147483647 (int32_max)`

Using `const int64_t k = blockIdx.x;` would solve the illegal memory access. This also seems to be done for [`cdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L198-L201).

However, we might expect a slowdown, so I've timed the current PyTorch master vs. this PR (tested with `x = torch.randn(x.size(0), 128)` on a V100):

| x.size(0) | int32 idx | int64 idx | slowdown |
|-----------|-----------|-----------|----------|
| 50000 | - | 4.4460 | - |
| 25000 | 1.02522 | 1.10869 | 7.53% |
| 12500 | 0.25182 | 0.27277 | 7.68% |
| 6250 | 0.06291 | 0.06817 | 7.72% |
| 3125 | 0.01573 | 0.01704 | 7.69% |
| 1562 | 0.00393 | 0.00426 | 7.75% |

While checking the backward kernel, it seems I'm triggering another error with a size limit:

```python
x = torch.randn(1449, 1, device='cuda', requires_grad=True)
out = torch.pdist(x)
out.mean().backward()
> RuntimeError: CUDA error: invalid configuration argument
```

while `[<=1448, 1]` works. I'll take another look at this issue. Let me know if the potential fix should go into this PR or if I should open a new issue.

CC ngimel, csarofeen

Pull Request resolved: pytorch#31593
Differential Revision: D19825571
Pulled By: ngimel
fbshipit-source-id: ace9ccab49f3cf0ce894cdb6daef0795e2e8ec03
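To make the overflow numbers in the summary above concrete, here is a small arithmetic check (illustration only, not code from the PR) of why a 32-bit index breaks down at this batch size:

```python
# Arithmetic behind the int32 overflow described in the summary above.
INT32_MAX = 2**31 - 1            # 2147483647

n = 46342                        # smallest failing batch size
numel = n * (n - 1) // 2         # number of pairwise distances in the pdist result
print(numel)                     # 1073767311 -> still representable as int32
print(2 * numel)                 # 2147534622 -> larger than INT32_MAX
print(2 * numel > INT32_MAX)     # True: terms like `2 * k` overflow a 32-bit index
```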