Use int64 in pdist kernel to handle batches >= 46342 #30583 #31593

Closed
ptrblck wants to merge 12 commits

Conversation

@ptrblck (Collaborator) commented Dec 24, 2019

Currently `torch.pdist` yields an illegal CUDA memory access for batch sizes >= 46342, as reported by @ssnl in #30583.
Thanks for the minimal code reproduction, btw! ;)

Reason for this bug:
The calculation of `i` in [`pdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112) can overflow if a tensor with a batch size >= 46342 is passed to `torch.pdist`.

Detailed description:

* `result` is resized to `n * (n - 1) / 2 = 1073767311` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/Distance.cpp#L140))
* `grid` is initialized as `result.numel()` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L246))
* `k` is assigned the value of `blockIdx.x` as an `int32` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L108))
* `i` is calculated using `2 * k >= 2147534622` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112)), which overflows, since `2147534622 > 2147483647 (int32_max)`.

Using `const int64_t k = blockIdx.x;` would solve the illegal memory access. The same approach already seems to be used in [`cdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L198-L201).
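
For illustration, here is a minimal Python sketch of the overflow, emulating 32-bit wraparound by hand; the index math is simplified and this is not the actual kernel code:

```python
# Minimal sketch of the indexing overflow described above (not the CUDA kernel).
INT32_MAX = 2**31 - 1

def wrap_int32(x):
    """Emulate two's-complement 32-bit integer wraparound."""
    return (x + 2**31) % 2**32 - 2**31

n = 46342                      # smallest failing batch size from #30583
num_pairs = n * (n - 1) // 2   # 1073767311 elements in the pdist result
k = num_pairs - 1              # largest linear index a block has to handle

print(2 * k)                   # 2147534620, already past INT32_MAX
print(wrap_int32(2 * k))       # negative after int32 wraparound -> bad index
print(2 * k <= INT32_MAX)      # False; with an int64 `k` the value stays valid
```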

However, we might expect a slowdown, so I've timed the current PyTorch master vs. this PR:
(tested with `x = torch.randn(x.size(0), 128)` on a V100)

| x.size(0) | int32 idx | int64 idx | slowdown |
|-----------|-----------|-----------|----------|
| 50000 | - | 4.4460 | - |
| 25000 | 1.02522 | 1.10869 | 7.53% |
| 12500 | 0.25182 | 0.27277 | 7.68% |
| 6250 | 0.06291 | 0.06817 | 7.72% |
| 3125 | 0.01573 | 0.01704 | 7.69% |
| 1562 | 0.00393 | 0.00426 | 7.75% |
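
For reference, a forward-only timing loop along the following lines produces numbers comparable to the table above. This is a sketch; the exact script used for the forward timings isn't shown in the thread, so the iteration count and the printed unit are assumptions:

```python
import time
import torch

# Sketch of a forward-only pdist timing loop (assumes a CUDA device is available).
nb_iters = 10
sizes = [50000, 25000, 12500, 6250, 3125, 1562]

for size in sizes:
    x = torch.randn(size, 128, device='cuda')
    for _ in range(nb_iters):        # warmup
        torch.pdist(x)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(nb_iters):        # timed region
        torch.pdist(x)
    torch.cuda.synchronize()
    print('size {}, time {:.5f}s/iter'.format(size, (time.time() - t0) / nb_iters))
```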

While checking the backward kernel, it seems I'm triggering another error at a certain size limit:

```python
x = torch.randn(1449, 1, device='cuda', requires_grad=True)
out = torch.pdist(x)
out.mean().backward()
> RuntimeError: CUDA error: invalid configuration argument
```

while `[<=1448, 1]` works.

I'll take another look at this issue. Let me know if the potential fix should go into this PR or if I should open a new issue.

CC @ngimel, @csarofeen

@kostmo (Member) commented Dec 24, 2019

💊 CircleCI build failures summary and remediations

As of commit 1fc79be:

  • 1/1 failures introduced in this PR

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_xla_linux_xenial_py3_6_clang7_test (1/1)

Step: "Test" (full log | pattern match details)

```
Feb 09 04:08:57 caused by: Connection refused (os error 111)
Feb 09 04:08:57 +++ eval 'extract_trap_cmd '
Feb 09 04:08:57 ++++ extract_trap_cmd
Feb 09 04:08:57 ++++ printf '%s\n' ''
Feb 09 04:08:57 +++ printf '%s\n' cleanup
Feb 09 04:08:57 ++ trap -- '
Feb 09 04:08:57 cleanup' EXIT
Feb 09 04:08:57 ++ which sccache
Feb 09 04:08:57 ++ sccache --stop-server
Feb 09 04:08:57 Stopping sccache server...
Feb 09 04:08:57 error: couldn't connect to server
Feb 09 04:08:57 caused by: Connection refused (os error 111)
Feb 09 04:08:57 ++ true
Feb 09 04:08:57 ++ rm /var/lib/jenkins/sccache_error.log
Feb 09 04:08:57 ++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log
Feb 09 04:08:57 ++ SCCACHE_IDLE_TIMEOUT=1200
Feb 09 04:08:57 ++ RUST_LOG=sccache::server=error
Feb 09 04:08:57 ++ sccache --start-server
Feb 09 04:08:57 Starting sccache server...
Feb 09 04:08:58 ++ sccache --zero-stats
Feb 09 04:08:58 Compile requests                 0
Feb 09 04:08:58 Compile requests executed        0
```

This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

This comment has been revised 52 times.

@ptrblck (Collaborator, Author) commented Dec 24, 2019

It seems the backward error comes from a wrong CUDA launch config in this line of code: `grid_y` exceeds 65535 (for an input tensor of `[1449, 1]` it comes out to 65568, which is invalid according to the CUDA programming guide).
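
For illustration, a small Python sketch of the launch-config arithmetic; the `block_y = 16` value is an assumption, chosen because it reproduces the 65568 figure quoted above, and the real block size comes from the kernel's launch code:

```python
import math

CUDA_MAX_GRID_Y = 65535  # max y-dimension of a grid per the CUDA programming guide

def backward_grid_y(n, block_y=16):
    """Sketch of the backward grid-y computation; block_y=16 is an assumption
    chosen to reproduce the grid_y = 65568 reported for an input of [1449, 1]."""
    num_pairs = n * (n - 1) // 2          # numel of the pdist result
    return math.ceil(num_pairs / block_y)

print(backward_grid_y(1448), backward_grid_y(1448) <= CUDA_MAX_GRID_Y)  # 65477 True
print(backward_grid_y(1449), backward_grid_y(1449) <= CUDA_MAX_GRID_Y)  # 65568 False
```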

A similar error is raised for cdist (but at a different size; I need to triage that bug as well).

I'll open a new issue and try to fix the launch configs there.

EDIT: might also be related to #24345

EDIT2: PR for cdist #31167 which might be also applicable for pdist.

@ngimel (Collaborator) commented Dec 24, 2019

@ptrblck 8% perf regression is acceptable. Please fix backward in this PR also. We can leave cdist separate for now.

@ngimel (Collaborator) commented Jan 7, 2020

And add tests please. I don't understand how just swapping the grid and block sizes would work without corresponding changes in the kernel itself.

@ptrblck (Collaborator, Author) commented Jan 7, 2020

I've added tests for the failing use cases and kept the sizes "reasonably" small, but let me know if I should expand the test cases.

Benchmarking for the backward pass on V100-SXM2 32GB (time in ms/iter):

Using `input = torch.randn(size, 128, device='cuda')`:

| x.size(0) | before PR | after PR |
|-----------|-----------|----------|
| 5600 | - | 78.8309 |
| 2800 | - | 19.9385 |
| 1400 | 6.0119 | 5.0028 |
| 700 | 1.5210 | 1.2706 |
| 350 | 0.4104 | 0.3429 |
| 175 | 0.1673 | 0.1648 |

Using `input = torch.randn(size, 1, device='cuda')`:

| x.size(0) | before PR | after PR |
|-----------|-----------|----------|
| 50000 | - | 4541.2841 |
| 25000 | - | 1133.9128 |
| 12500 | - | 283.3532 |
| 6250 | - | 70.8232 |
| 3125 | - | 17.7395 |
| 1562 | - | 4.4551 |
| 781 | 1.3270 | 1.1294 |
| 390 | 0.3595 | 0.3087 |
| 195 | 0.1651 | 0.2168 |

Code used for benchmarking:

```python
import time
import torch

nb_iters = 10
sizes = [int(50000 / 2**i) for i in range(10)]

for size in sizes:
    x = torch.randn(size, 1, device='cuda', requires_grad=True)
    # warmup
    for _ in range(nb_iters):
        out = torch.pdist(x)
        out.mean().backward()
    # print(torch.cuda.memory_allocated() / 1024**3)

    # timed region
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(nb_iters):
        out = torch.pdist(x)
        out.mean().backward()
    torch.cuda.synchronize()
    t1 = time.time()
    print('size {}, time {:.4f}ms/iter'.format(size, 1000 * (t1 - t0) / nb_iters))
```

@ngimel (Collaborator) commented Jan 7, 2020

1. The test failures are real.
2. The tests are smoke tests and don't test correctness.
3. This is concerning, because I don't see how just flipping the grid dimensions without changing the actual kernels can work.

@ptrblck (Collaborator, Author) commented Jan 7, 2020

Sorry for the confusion, but the kernel changes weren't included in the commit. 😕
I'll add correctness tests and commit the kernel changes.

@ptrblck (Collaborator, Author) commented Jan 8, 2020

The forward test for `[50000, 1]` uses approx. 23 GB when comparing against `brute_pdist`; just the smoke test without the comparison uses ~4.66 GB.
Should we fall back to the smoke test for this shape? I could decorate this test with `LARGE_TENSOR` and try to run it on our CI.

I've added the gradient check for the other, smaller shapes as well.
Let me know if I should remove them and keep only the backward check for `[1500, 1]`.
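
For context, a brute-force comparison along the following lines (a rough sketch, not the actual `brute_pdist` helper from the test suite) has to materialize the full `n x n` distance matrix, which is where the memory blow-up for `[50000, 1]` comes from:

```python
import torch

def brute_pdist_reference(x, p=2.0):
    """Naive reference for torch.pdist: build the full (n, n) distance matrix
    and keep the upper triangle. A sketch, not the test-suite brute_pdist."""
    n = x.size(0)
    dist = torch.cdist(x, x, p=p)                                   # O(n^2) memory
    row, col = torch.triu_indices(n, n, offset=1, device=x.device)  # above-diagonal indices
    return dist[row, col]                                           # same ordering as torch.pdist

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1000, 1, device=device)
assert torch.allclose(torch.pdist(x), brute_pdist_reference(x), atol=1e-4)
```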

@ptrblck ptrblck changed the title Use int64 in pdist kernel to handle batches >= 46342 #30583 [WIP] Use int64 in pdist kernel to handle batches >= 46342 #30583 Jan 8, 2020
@ptrblck ptrblck changed the title [WIP] Use int64 in pdist kernel to handle batches >= 46342 #30583 Use int64 in pdist kernel to handle batches >= 46342 #30583 Jan 8, 2020
@zou3519 zou3519 requested a review from ngimel January 9, 2020 19:10
@zou3519 zou3519 added the triaged label Jan 9, 2020
@ptrblck (Collaborator, Author) commented Jan 24, 2020

@pytorchbot retest this please

@ngimel (Collaborator) commented Jan 24, 2020

At least the ROCm failure is real; I have not looked at the other ones.

@ptrblck (Collaborator, Author) commented Jan 27, 2020

@pytorchbot retest this please

@ngimel (Collaborator) commented Feb 3, 2020

@ptrblck can you please split the pdist test into two: one testing forward, and another testing backward that would be completely skipped on ROCm using the @skipIfRocm decorator? We are trying to deprecate the TEST_WITH_ROCM flag.

@ptrblck (Collaborator, Author) commented Feb 4, 2020

@ngimel Sure, I'll split the tests and skip the backward for ROCm.
However, ROCm now also seems to fail in the forward pass for a tensor of `[50000, 1]` (the initial commit to fix the forward-pass size limitation by using `const int64_t k = blockIdx.x;`) with:

```
00:38:18 ======================================================================
00:38:18 FAIL: test_pdist_norm_cuda (__main__.TestTorchDeviceTypeCUDA)
00:38:18 ----------------------------------------------------------------------
00:38:18 Traceback (most recent call last):
00:38:18   File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 681, in wrapper
00:38:18     method(*args, **kwargs)
00:38:18   File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 180, in instantiated_test
00:38:18     return test(self, device_arg)
00:38:18   File "test_torch.py", line 11002, in test_pdist_norm
00:38:18     self.assertTrue(torch.allclose(expected_cpu, actual_gpu.cpu()))
00:38:18 AssertionError: False is not true
```

This test compares the GPU result for p=2 against the CPU result.

Should I skip this test for ROCm as well, or wait for a review?

@ngimel (Collaborator) commented Feb 6, 2020

@ptrblck please skip the forward test also and file an issue for ROCm

@ngimel (Collaborator) left a review comment:


The things I'd like to see in the tests:

1. No TEST_WITH_ROCM, only decorators.
2. A separate test for the large size you are adding.
3. Disable the ROCm tests as needed (as long as you are not breaking anything, and those are tests that you are adding), but file an issue for ROCm detailing the failures.

@ptrblck (Collaborator, Author) commented Feb 9, 2020

I've removed the `TEST_WITH_ROCM` usage and added the decorators:
`test_pdist_norm_forward` runs on all devices, while `test_pdist_norm_backward` and `test_pdist_norm_large` are skipped on ROCm.
I've also moved `pdist_single` to `common_utils` (where `brute_pdist` is also located) to avoid code duplication.

I'll create a new issue with the information about the failing ROCm tests.
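
Roughly, the split looks like the sketch below. It is illustrative only: the shapes, assertions, and class name are made up here, the `skipIfRocm` import path is assumed to be `torch.testing._internal.common_utils`, and the tests that actually landed in `test_torch.py` differ:

```python
import unittest
import torch
from torch.testing._internal.common_utils import skipIfRocm  # decorators instead of TEST_WITH_ROCM

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

class TestPdistSplitSketch(unittest.TestCase):
    def test_pdist_norm_forward(self):
        # forward smoke/shape check, runs on all devices
        x = torch.randn(1000, 2, device=DEVICE)
        self.assertEqual(torch.pdist(x).numel(), 1000 * 999 // 2)

    @skipIfRocm  # backward currently fails on ROCm; tracked in a separate issue
    def test_pdist_norm_backward(self):
        x = torch.randn(1500, 1, device=DEVICE, requires_grad=True)
        torch.pdist(x).mean().backward()
        self.assertIsNotNone(x.grad)

    @skipIfRocm  # the large [50000, 1] forward also fails on ROCm (see above); GPU-sized workload
    def test_pdist_norm_large(self):
        x = torch.randn(50000, 1, device=DEVICE)
        self.assertEqual(torch.pdist(x).numel(), 50000 * 49999 // 2)

if __name__ == '__main__':
    unittest.main()
```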

@facebook-github-bot (Contributor) left a comment:


@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:
@ngimel merged this pull request in a64d0ff.

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Use int64 in pdist kernel to handle batches >= 46342 (pytorch#31593)

Pull Request resolved: pytorch#31593

Differential Revision: D19825571

Pulled By: ngimel

fbshipit-source-id: ace9ccab49f3cf0ce894cdb6daef0795e2e8ec03
Labels: Merged, open source, triaged
Projects: None yet
7 participants