Cudnn bn size fix (pytorch#32763)
Summary:
Should fix pytorch#29744 by falling back to the native batch norm implementation if cuDNN cannot execute the provided shape.

Shape numbers were verified for cuDNN 7.6.5.32 with the following tensor shapes:
```python
# for spatial bn
x = torch.Size([880801, 256, 5])
x = torch.Size([65535, 256, 5])
x = torch.Size([880801, 64, 4, 4])
x = torch.Size([65535, 64, 4, 4])

# for per-act bn
x = torch.Size([131070, 2048])
x = torch.Size([262136, 2048])
```
for both `train()` and `eval()` mode, using `torch.float32` and `torch.float16`.
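A minimal sketch of the kind of check this enables, assuming a CUDA build with cuDNN enabled and a GPU with enough memory for these shapes (the `check_large_batch_bn` helper is invented for illustration, not part of the PR): with the fix, batches above the cuDNN limits should fall back to the native kernels instead of raising a cuDNN error.

```python
import torch
import torch.nn as nn

# Illustrative helper (not part of the PR): run batch norm forward and
# backward for a given shape. With this fix, shapes above the cuDNN
# limits dispatch to the native implementation instead of erroring out.
def check_large_batch_bn(shape, training=True, dtype=torch.float32):
    x = torch.rand(*shape, device="cuda", dtype=dtype, requires_grad=True)
    bn_cls = {2: nn.BatchNorm1d, 3: nn.BatchNorm1d, 4: nn.BatchNorm2d}[len(shape)]
    bn = bn_cls(shape[1]).to("cuda", dtype)
    bn.train(training)
    bn(x).sum().backward()  # exercise both forward and backward

check_large_batch_bn((880801, 64, 4, 4), training=True)  # spatial, training
check_large_batch_bn((131070, 2048), training=True)      # per-activation, training
```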

I've increased the shape used in our current smoke test, but I can also add all use cases of the support matrix, if wanted.

CC ngimel
Pull Request resolved: pytorch#32763

Differential Revision: D19644328

Pulled By: ngimel

fbshipit-source-id: c2151bf9fe6bac79b8cbc69cff517a4b0b3867aa
ptrblck authored and facebook-github-bot committed Jan 31, 2020
1 parent bcb7c22 commit 0f09720
Showing 2 changed files with 6 additions and 3 deletions.
aten/src/ATen/native/Normalization.cpp (4 additions, 1 deletion)
```diff
@@ -464,7 +464,10 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     && weight.defined() && bias.defined()
     && ((running_mean.defined() && running_var.defined())
       || (!running_mean.defined() && !running_var.defined() && training))
-    && input.size(0) <= 131070
+    && ((input.dim() == 2 && input.size(0) <= 131070 && training) // per-activation, training
+      || (input.dim() == 2 && input.size(0) <= 262136 && !training) // per-activation, eval
+      || (input.dim() >= 3 && input.size(0) <= 880801 && training) // spatial, training
+      || (input.dim() >= 3 && input.size(0) <= 65535 && !training)) // spatial, eval
     && detail::getCUDAHooks().compiledWithCuDNN()
     && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L);
```
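For reference, the new eligibility check boils down to the batch-size limits below. This is a hypothetical Python transliteration (the `within_cudnn_batch_limit` helper is invented for illustration; the authoritative check is the C++ condition above):

```python
# Hypothetical transliteration of the new batch-size clauses
# (illustrative only; the real check lives in Normalization.cpp).
def within_cudnn_batch_limit(input, training):
    if input.dim() == 2:   # per-activation batch norm
        return input.size(0) <= (131070 if training else 262136)
    if input.dim() >= 3:   # spatial batch norm
        return input.size(0) <= (880801 if training else 65535)
    return False
```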

test/test_nn.py (2 additions, 2 deletions)
```diff
@@ -6262,8 +6262,8 @@ def test_batchnorm_nonaffine_cuda_half_input(self):
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @repeat_test_for_types([torch.float, torch.half])
     def test_batchnorm_large_batch(self, dtype=torch.float):
-        bn = nn.BatchNorm1d(1).to('cuda', dtype)
-        data = torch.rand(131072, 1, device="cuda", dtype=dtype)
+        bn = nn.BatchNorm2d(1).to('cuda', dtype)
+        data = torch.rand(880801, 1, 1, 1, device="cuda", dtype=dtype)
         out = bn(data).sum().backward()

     def test_batchnorm_raises_error_if_less_than_one_value_per_channel(self):
```
