Support 0s in out_size of FractionalMaxPoolNd
Fixes pytorch#73624

The CUDA implementation was already correct :); only the CPU implementation had an out-of-bounds memory access.
Pull Request resolved: pytorch#73634
Approved by: jbschlosser
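
As a quick illustration (a minimal sketch mirroring the tests added below; the tensor shapes are only examples), a zero entry in output_size now yields an empty output instead of hitting the out-of-bounds write on CPU:

    import torch
    import torch.nn as nn

    # A zero in output_size previously caused an out-of-bounds write in the
    # CPU interval generation (pytorch#73624); it now returns an empty output.
    mod = nn.FractionalMaxPool2d(kernel_size=[2, 2], output_size=[0, 1])
    inp = torch.rand(16, 50, 32, 32)
    out = mod(inp)
    print(out.shape)  # torch.Size([16, 50, 0, 1])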
Emilio Castillo authored and pytorchmergebot committed Mar 3, 2022
1 parent bf896a2 commit 3186e36
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 3 additions & 2 deletions aten/src/ATen/native/FractionalMaxPool2d.cpp
@@ -134,8 +134,9 @@ static std::vector<int> fractional_max_pool2d_generate_intervals(
         static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
     }
   }
-  sequence[outputSize - 1] = inputSize - poolSize;
-
+  if (outputSize > 0) {
+    sequence[outputSize - 1] = inputSize - poolSize;
+  }
   return sequence;
 }
 
6 changes: 3 additions & 3 deletions aten/src/ATen/native/FractionalMaxPool3d.cpp
@@ -106,8 +106,9 @@ static std::vector<int> generate_intervals(
         static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
     }
   }
-  sequence[outputSize - 1] = inputSize - poolSize;
-
+  if (outputSize > 0) {
+    sequence[outputSize - 1] = inputSize - poolSize;
+  }
   return sequence;
 }
 
@@ -238,7 +239,6 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
   int64_t inputW,
   const at::Tensor& output,
   const at::Tensor& indices) {
-
   /* get contiguous input */
   auto input = input_.contiguous();
 
14 changes: 14 additions & 0 deletions test/test_nn.py
@@ -14432,6 +14432,20 @@ def test_FractionalMaxPool3d_zero_batch(self, device):
         inp = torch.randn(1, 0, 50, 32, 32, device=device)
         mod(inp)
 
+    @onlyNativeDeviceTypes
+    def test_FractionalMaxPool2d_zero_out_size(self, device):
+        mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
+        inp = torch.rand([16, 50, 32, 32], device=device)
+        out = mod(inp)
+        self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))
+
+    @onlyNativeDeviceTypes
+    def test_FractionalMaxPool3d_zero_out_size(self, device):
+        mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
+        inp = torch.rand([16, 50, 32, 32], device=device)
+        out = mod(inp)
+        self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))
+
     @onlyNativeDeviceTypes
     def test_Unfold_empty(self, device):
         inp = torch.randn(0, 3, 3, 4, device=device)
