From 624407077f11a7e2898ebf4ab9ee29b2cce0ca06 Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Tue, 6 Oct 2020 12:44:56 +0800 Subject: [PATCH] Fix wrappers version comparison (#602) * add version check in wrappers * fix assersion * use digital version for version comparison * fix unit tests * reformat * fall back to compare the first two version * fix unittest * fix unittest * fix unit test * clean unnecessary change --- mmcv/cnn/bricks/wrappers.py | 13 +++++++++---- tests/test_cnn/test_wrappers.py | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mmcv/cnn/bricks/wrappers.py b/mmcv/cnn/bricks/wrappers.py index e525b00c6f..6f9b694d55 100644 --- a/mmcv/cnn/bricks/wrappers.py +++ b/mmcv/cnn/bricks/wrappers.py @@ -12,6 +12,10 @@ from .registry import CONV_LAYERS, UPSAMPLE_LAYERS +# torch.__version__ could be 1.3.1+cu92, we only need the first two +# for comparison +TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) + class NewEmptyTensorOp(torch.autograd.Function): @@ -30,7 +34,7 @@ def backward(ctx, grad): class Conv2d(nn.Conv2d): def forward(self, x): - if x.numel() == 0 and torch.__version__ <= '1.4.0': + if x.numel() == 0 and TORCH_VERSION <= (1, 4): out_shape = [x.shape[0], self.out_channels] for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, self.padding, self.stride, self.dilation): @@ -53,7 +57,7 @@ def forward(self, x): class ConvTranspose2d(nn.ConvTranspose2d): def forward(self, x): - if x.numel() == 0 and torch.__version__ <= '1.4.0': + if x.numel() == 0 and TORCH_VERSION <= (1, 4): out_shape = [x.shape[0], self.out_channels] for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, self.padding, self.stride, @@ -74,7 +78,7 @@ class MaxPool2d(nn.MaxPool2d): def forward(self, x): # PyTorch 1.6 does not support empty tensor inference yet - if x.numel() == 0 and torch.__version__ <= '1.6.0': + if x.numel() == 0 and TORCH_VERSION <= (1, 6): out_shape = list(x.shape[:2]) for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), _pair(self.padding), _pair(self.stride), @@ -91,7 +95,8 @@ def forward(self, x): class Linear(torch.nn.Linear): def forward(self, x): - if x.numel() == 0: + # empty tensor forward of Linear layer is supported in Pytorch 1.6 + if x.numel() == 0 and TORCH_VERSION <= (1, 5): out_shape = [x.shape[0], self.out_features] empty = NewEmptyTensorOp.apply(x, out_shape) if self.training: diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 067cb6465b..755970c6ad 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -169,7 +169,7 @@ def test_linear(): wrapper(x_empty) -@patch('torch.__version__', '1.6.1') +@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 7)) def test_nn_op_forward_called(): for m in ['Conv2d', 'ConvTranspose2d', 'MaxPool2d']: @@ -191,7 +191,7 @@ def test_nn_op_forward_called(): x_empty = torch.randn(0, 3) wrapper = Linear(3, 3) wrapper(x_empty) - nn_module_forward.assert_not_called() + nn_module_forward.assert_called_with(x_empty) # non-randn input x_normal = torch.randn(1, 3)