From 7dca27dd572176085ce2edac67bd94c8d7c6663b Mon Sep 17 00:00:00 2001
From: Philipp Allgeuer <5592992+pallgeuer@users.noreply.github.com>
Date: Fri, 30 Sep 2022 09:01:36 +0200
Subject: [PATCH] [Fix] Fix warning with `torch.meshgrid`. (#860)

* Fix warning with torch.meshgrid

* Add torch_meshgrid_ij wrapper

* Use `digit_version` instead of packaging package.

Co-authored-by: mzr1996
---
 tests/test_models/test_utils/test_attention.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/test_models/test_utils/test_attention.py b/tests/test_models/test_utils/test_attention.py
index 9626f66fecd..cc37d13415d 100644
--- a/tests/test_models/test_utils/test_attention.py
+++ b/tests/test_models/test_utils/test_attention.py
@@ -1,18 +1,25 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
 from unittest import TestCase
 from unittest.mock import ANY, MagicMock
 
 import pytest
 import torch
+from mmcv.utils import TORCH_VERSION, digit_version
 
 from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA
 
+if digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'):
+    torch_meshgrid_ij = partial(torch.meshgrid, indexing='ij')
+else:
+    torch_meshgrid_ij = torch.meshgrid  # Uses indexing='ij' by default
+
 
 def get_relative_position_index(window_size):
     """Method from original code of Swin-Transformer."""
     coords_h = torch.arange(window_size[0])
     coords_w = torch.arange(window_size[1])
-    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+    coords = torch.stack(torch_meshgrid_ij([coords_h, coords_w]))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     # 2, Wh*Ww, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
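
The sketch below is a standalone, illustrative version of the version-gated wrapper this patch adds; it is not part of the patch itself. It substitutes `packaging.version.parse` for `mmcv.utils.digit_version` purely so the example has no mmcv dependency, and keeps the same `1.10.0a0` threshold as the patch, since `torch.meshgrid` only gained the `indexing` argument (and the warning when it is omitted) in PyTorch 1.10.

# Standalone sketch of the torch_meshgrid_ij wrapper pattern (illustrative only).
# Assumption: `packaging` stands in for mmcv's digit_version to stay self-contained.
from functools import partial

import torch
from packaging.version import parse as parse_version

if parse_version(torch.__version__.split('+')[0]) >= parse_version('1.10.0a0'):
    # PyTorch >= 1.10 warns unless `indexing` is passed explicitly.
    torch_meshgrid_ij = partial(torch.meshgrid, indexing='ij')
else:
    # Older PyTorch has no `indexing` argument and always behaves as 'ij'.
    torch_meshgrid_ij = torch.meshgrid

# Usage mirrors torch.meshgrid but is warning-free on every supported version.
ys, xs = torch_meshgrid_ij(torch.arange(3), torch.arange(4))
assert ys.shape == xs.shape == (3, 4)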