From bd6791382f80ecfc6afc3c5d5544e55ecaaba804 Mon Sep 17 00:00:00 2001 From: Qian Zhao <112053249+C1rN09@users.noreply.github.com> Date: Mon, 5 Dec 2022 18:04:47 +0800 Subject: [PATCH] [Fix] BaseModel & BaseDataPreprocessor `to` method to be consistent with torch.nn.Module (#783) * fix BaseModel `to` method to be consistent with torch.nn.Module * fix data_preprocessor as well * fix docstring alignment * fix docstring alignment --- mmengine/model/base_model/base_model.py | 20 +++------ .../model/base_model/data_preprocessor.py | 13 +++--- .../test_base_model/test_base_model.py | 43 +++++++++++++++++++ 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index 525a2f0664..f9316506d8 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -155,9 +155,9 @@ def parse_losses( Returns: tuple[Tensor, dict]: There are two elements. The first is the - loss tensor passed to optim_wrapper which may be a weighted sum of - all losses, and the second is log_vars which will be sent to the - logger. + loss tensor passed to optim_wrapper which may be a weighted sum + of all losses, and the second is log_vars which will be sent to + the logger. """ log_vars = [] for loss_name, loss_value in losses.items(): @@ -177,23 +177,17 @@ def parse_losses( return loss, log_vars # type: ignore - def to(self, - device: Optional[Union[int, str, torch.device]] = None, - *args, - **kwargs) -> nn.Module: + def to(self, *args, **kwargs) -> nn.Module: """Overrides this method to call :meth:`BaseDataPreprocessor.to` additionally. - Args: - device (int, str or torch.device, optional): the desired device - of the parameters and buffers in this module. - Returns: nn.Module: The model itself. """ + device = torch._C._nn._parse_to(*args, **kwargs)[0] if device is not None: self._set_device(torch.device(device)) - return super().to(device) + return super().to(*args, **kwargs) def cuda( self, @@ -244,7 +238,7 @@ def _set_device(self, device: torch.device) -> None: Args: device (torch.device): the desired device of the parameters and - buffers in this module. + buffers in this module. """ def apply_fn(module): diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index d7d27ec907..17d70e067e 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -84,19 +84,16 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]: def device(self): return self._device - def to(self, device: Optional[Union[int, torch.device]], *args, - **kwargs) -> nn.Module: + def to(self, *args, **kwargs) -> nn.Module: """Overrides this method to set the :attr:`device` - Args: - device (int or torch.device, optional): The desired device of the - parameters and buffers in this module. - Returns: nn.Module: The model itself. """ - self._device = torch.device(device) - return super().to(device) + device = torch._C._nn._parse_to(*args, **kwargs)[0] + if device is not None: + self._device = torch.device(device) + return super().to(*args, **kwargs) def cuda(self, *args, **kwargs) -> nn.Module: """Overrides this method to set the :attr:`device` diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 22df50ec0c..21356af771 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import itertools import unittest from unittest import TestCase import torch import torch.nn as nn +from parameterized import parameterized from torch.optim import SGD from mmengine.model import BaseDataPreprocessor, BaseModel @@ -11,6 +13,18 @@ from mmengine.registry import MODELS from mmengine.testing import assert_allclose +dtypes_to_test = [torch.float16, torch.float32, torch.float64, torch.half] + +cpu_devices = ['cpu', torch.device('cpu')] +cuda_devices = ['cuda', 0, torch.device('cuda')] +devices_to_test = cpu_devices +if torch.cuda.is_available(): + devices_to_test += cuda_devices + + +def list_product(*args): + return list(itertools.product(*args)) + @MODELS.register_module() class CustomDataPreprocessor(BaseDataPreprocessor): @@ -158,3 +172,32 @@ def test_to(self): self.assertEqual(model.data_preprocessor._device, torch.device('cuda')) self.assertEqual(model.toy_model.data_preprocessor._device, torch.device('cuda')) + + @parameterized.expand(list_product(devices_to_test)) + def test_to_device(self, device): + model = ToyModel().to(device) + self.assertTrue( + all(p.device.type == torch.device(device).type + for p in model.parameters()) + and model.data_preprocessor._device == torch.device(device)) + + @parameterized.expand(list_product(dtypes_to_test)) + def test_to_dtype(self, dtype): + model = ToyModel().to(dtype) + self.assertTrue(all(p.dtype == dtype for p in model.parameters())) + + @parameterized.expand( + list_product(devices_to_test, dtypes_to_test, + ['args', 'kwargs', 'hybrid'])) + def test_to_device_and_dtype(self, device, dtype, mode): + if mode == 'args': + model = ToyModel().to(device, dtype) + elif mode == 'kwargs': + model = ToyModel().to(device=device, dtype=dtype) + elif mode == 'hybrid': + model = ToyModel().to(device, dtype=dtype) + self.assertTrue( + all(p.dtype == dtype for p in model.parameters()) + and model.data_preprocessor._device == torch.device(device) + and all(p.device.type == torch.device(device).type + for p in model.parameters()))