diff --git a/opacus/tests/grad_samples/conv2d_test.py b/opacus/tests/grad_samples/conv2d_test.py
index 60661bf2..f99ce563 100644
--- a/opacus/tests/grad_samples/conv2d_test.py
+++ b/opacus/tests/grad_samples/conv2d_test.py
@@ -19,6 +19,8 @@
 import torch
 import torch.nn as nn
 from hypothesis import given, settings
+from opacus.utils.tensor_utils import unfold2d
+from torch.testing import assert_allclose
 
 from .common import GradSampleHooks_test, expander, shrinker
 
@@ -68,3 +70,54 @@ def test_conv2d(
             groups=groups,
         )
         self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
+
+    @given(
+        B=st.integers(1, 4),
+        C=st.sampled_from([1, 3, 32]),
+        H=st.integers(11, 17),
+        W=st.integers(11, 17),
+        k_w=st.integers(2, 3),
+        k_h=st.integers(2, 3),
+        stride_w=st.integers(1, 2),
+        stride_h=st.integers(1, 2),
+        pad_h=st.sampled_from([0, 2]),
+        pad_w=st.sampled_from([0, 2]),
+        dilation_w=st.integers(1, 3),
+        dilation_h=st.integers(1, 3),
+    )
+    @settings(deadline=10000)
+    def test_unfold2d(
+        self,
+        B: int,
+        C: int,
+        H: int,
+        W: int,
+        k_w: int,
+        k_h: int,
+        pad_w: int,
+        pad_h: int,
+        stride_w: int,
+        stride_h: int,
+        dilation_w: int,
+        dilation_h: int,
+    ):
+        """Check that unfold2d reproduces torch.nn.functional.unfold exactly."""
+        X = torch.randn(B, C, H, W)
+        # Every 2-tuple below is (height, width), the order F.unfold expects.
+        X_unfold_torch = torch.nn.functional.unfold(
+            X,
+            kernel_size=(k_h, k_w),
+            padding=(pad_h, pad_w),
+            stride=(stride_h, stride_w),
+            dilation=(dilation_h, dilation_w),
+        )
+
+        X_unfold_opacus = unfold2d(
+            X,
+            kernel_size=(k_h, k_w),
+            padding=(pad_h, pad_w),
+            stride=(stride_h, stride_w),
+            dilation=(dilation_h, dilation_w),
+        )
+
+        assert_allclose(X_unfold_torch, X_unfold_opacus, atol=0, rtol=0)
diff --git a/opacus/utils/tensor_utils.py b/opacus/utils/tensor_utils.py
index 192343c7..f04a3e15 100644
--- a/opacus/utils/tensor_utils.py
+++ b/opacus/utils/tensor_utils.py
@@ -130,7 +130,8 @@ def unfold2d(
     W_effective = (
         W + 2 * padding[1] - (kernel_size[1] + (kernel_size[1] - 1) * (dilation[1] - 1))
     ) // stride[1] + 1
-    input = F.pad(input, (padding[0], padding[0], padding[1], padding[1]))
+    # F.pad's pad tuple starts with the *last* dimension: (W_left, W_right, H_top, H_bottom)
+    input = F.pad(input, (padding[1], padding[1], padding[0], padding[0]))
     *shape_pad, H_pad, W_pad = input.shape
     strides = list(input.stride())
     strides = strides[:-2] + [