diff --git a/aten/src/ATen/native/FractionalMaxPool2d.cpp b/aten/src/ATen/native/FractionalMaxPool2d.cpp
index b4ea2ec186f245..bb25be4a02e522 100644
--- a/aten/src/ATen/native/FractionalMaxPool2d.cpp
+++ b/aten/src/ATen/native/FractionalMaxPool2d.cpp
@@ -134,8 +134,9 @@ static std::vector<int> fractional_max_pool2d_generate_intervals(
         static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
     }
   }
-  sequence[outputSize - 1] = inputSize - poolSize;
-
+  if (outputSize > 0) {
+    sequence[outputSize - 1] = inputSize - poolSize;
+  }
   return sequence;
 }
 
diff --git a/aten/src/ATen/native/FractionalMaxPool3d.cpp b/aten/src/ATen/native/FractionalMaxPool3d.cpp
index 757ce7c056913a..8bcb53847271ae 100644
--- a/aten/src/ATen/native/FractionalMaxPool3d.cpp
+++ b/aten/src/ATen/native/FractionalMaxPool3d.cpp
@@ -106,8 +106,9 @@ static std::vector<int> generate_intervals(
         static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
     }
   }
-  sequence[outputSize - 1] = inputSize - poolSize;
-
+  if (outputSize > 0) {
+    sequence[outputSize - 1] = inputSize - poolSize;
+  }
   return sequence;
 }
 
@@ -238,7 +239,6 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
   int64_t inputW,
   const at::Tensor& output,
   const at::Tensor& indices) {
-
   /* get contiguous input */
   auto input = input_.contiguous();
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 4b3198777cdceb..30c4d136e1b907 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -14432,6 +14432,20 @@ def test_FractionalMaxPool3d_zero_batch(self, device):
         inp = torch.randn(1, 0, 50, 32, 32, device=device)
         mod(inp)
 
+    @onlyNativeDeviceTypes
+    def test_FractionalMaxPool2d_zero_out_size(self, device):
+        mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
+        inp = torch.rand([16, 50, 32, 32], device=device)
+        out = mod(inp)
+        self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))
+
+    @onlyNativeDeviceTypes
+    def test_FractionalMaxPool3d_zero_out_size(self, device):
+        mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
+        inp = torch.rand([16, 50, 32, 32], device=device)
+        out = mod(inp)
+        self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))
+
     @onlyNativeDeviceTypes
     def test_Unfold_empty(self, device):
         inp = torch.randn(0, 3, 3, 4, device=device)
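
Note: when outputSize == 0, the unguarded write sequence[outputSize - 1] indexed
an empty std::vector out of bounds, which is undefined behavior, so requesting a
zero-sized output dimension could corrupt memory or crash. A minimal repro sketch
of the now-supported case, mirroring the new tests (the shapes are illustrative,
not required):

    import torch
    import torch.nn as nn

    # A zero along one pooled dimension used to hit the out-of-bounds write;
    # with the guard in place it simply yields an empty output of matching shape.
    pool = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
    out = pool(torch.rand(16, 50, 32, 32))
    print(out.shape)  # torch.Size([16, 50, 0, 1])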