Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] BaseModel & BaseDataPreprocessor to method to be consistent with torch.nn.Module #783

Merged
merged 4 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix BaseModel to method to be consistent with torch.nn.Module
  • Loading branch information
C1rN09 committed Dec 2, 2022
commit 6eeb85f0d95227816411ee00a4181208a8c29230
20 changes: 11 additions & 9 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,25 @@ def parse_losses(

return loss, log_vars # type: ignore

def to(self,
device: Optional[Union[int, str, torch.device]] = None,
*args,
**kwargs) -> nn.Module:
def to(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
additionally.

Args:
device (int, str or torch.device, optional): the desired device
of the parameters and buffers in this module.

Returns:
nn.Module: The model itself.
"""
if 'device' in kwargs:
device = kwargs['device']
elif args:
try:
device = torch.device(args[0])
except TypeError:
device = None
else:
device = None
if device is not None:
self._set_device(torch.device(device))
return super().to(device)
return super().to(*args, **kwargs)

def cuda(
self,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_model/test_base_model/test_base_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import unittest
from unittest import TestCase

import torch
import torch.nn as nn
from parameterized import parameterized
from torch.optim import SGD

from mmengine.model import BaseDataPreprocessor, BaseModel
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.testing import assert_allclose

dtypes_to_test = [torch.float16, torch.float32, torch.float64, torch.half]

cpu_devices = ['cpu', torch.device('cpu')]
cuda_devices = ['cuda', 0, torch.device('cuda')]
devices_to_test = cpu_devices
if torch.cuda.is_available():
devices_to_test += cuda_devices


def list_product(*args):
return list(itertools.product(*args))


@MODELS.register_module()
class CustomDataPreprocessor(BaseDataPreprocessor):
Expand Down Expand Up @@ -158,3 +172,32 @@ def test_to(self):
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))

@parameterized.expand(list_product(devices_to_test))
def test_to_device(self, device):
model = ToyModel().to(device)
self.assertTrue(
all(p.device.type == torch.device(device).type
for p in model.parameters())
and model.data_preprocessor._device == torch.device(device))

@parameterized.expand(list_product(dtypes_to_test))
def test_to_dtype(self, dtype):
model = ToyModel().to(dtype)
self.assertTrue(all(p.dtype == dtype for p in model.parameters()))

@parameterized.expand(
list_product(devices_to_test, dtypes_to_test,
['args', 'kwargs', 'hybrid']))
def test_to_device_and_dtype(self, device, dtype, mode):
if mode == 'args':
model = ToyModel().to(device, dtype)
elif mode == 'kwargs':
model = ToyModel().to(device=device, dtype=dtype)
elif mode == 'hybrid':
model = ToyModel().to(device, dtype=dtype)
self.assertTrue(
all(p.dtype == dtype for p in model.parameters())
and model.data_preprocessor._device == torch.device(device)
and all(p.device.type == torch.device(device).type
for p in model.parameters()))