Skip to content

Commit d1dd240

Browse files
HAOCHENYE and zhouzaida
authored
[Fix] Fix BaseDataPreprocessor.cast_data could not handle string data (#602)
* [Fix] Fix could not handle string data * Minor refine * Refine type hint * Fix as comment * Minor refine * Update mmengine/model/base_model/data_preprocessor.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent 1bf5c0c commit d1dd240

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

mmengine/model/base_model/data_preprocessor.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from mmengine.utils import is_list_of
1212
from ..utils import stack_batch
1313

14-
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
14+
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
15+
None]
1516

1617

1718
@MODELS.register_module()
@@ -48,17 +49,20 @@ def cast_data(self, data: CastData) -> CastData:
4849
"""
4950
if isinstance(data, Mapping):
5051
return {key: self.cast_data(data[key]) for key in data}
52+
elif isinstance(data, (str, bytes)) or data is None:
53+
return data
5154
elif isinstance(data, tuple) and hasattr(data, '_fields'):
5255
# namedtuple
53-
return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable
56+
return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable
5457
elif isinstance(data, Sequence):
55-
return [self.cast_data(sample) for sample in data]
56-
elif isinstance(data, torch.Tensor):
57-
return data.to(self.device, non_blocking=self._non_blocking)
58-
elif isinstance(data, BaseDataElement):
58+
return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable
59+
elif isinstance(data, (torch.Tensor, BaseDataElement)):
5960
return data.to(self.device, non_blocking=self._non_blocking)
6061
else:
61-
return data
62+
raise TypeError(
63+
'`BaseDataPreprocessor.cast_data`: batch data must contain '
64+
'tensors, numpy arrays, numbers, dicts or lists, but '
65+
f'found {type(data)}')
6266

6367
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
6468
"""Preprocesses the data into the model input format.

tests/test_model/test_base_model/test_data_preprocessor.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def test_forward(self):
2828
label1 = torch.randn(1)
2929
label2 = torch.randn(1)
3030

31+
# Test with dict of batch inputs and batch data samples
3132
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
32-
3333
output = base_data_preprocessor(data)
3434
batch_inputs, batch_labels = output['inputs'], output['data_sample']
3535
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
@@ -41,40 +41,54 @@ def test_forward(self):
4141
assert_allclose(label2, batch_labels[1])
4242

4343
# Test with tuple of batch inputs and batch data samples
44-
data = dict(
45-
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
46-
output = base_data_preprocessor(data)['inputs']
44+
data = (torch.stack([input1, input2]), (label1, label2))
45+
batch_inputs, batch_labels = base_data_preprocessor(data)
46+
self.assertTrue(torch.is_floating_point(batch_inputs))
47+
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
48+
self.assertEqual(batch_inputs[1].shape, (1, 3, 5))
4749
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
4850

4951
# Test cuda forward
5052
if torch.cuda.is_available():
5153
# Test with list of data samples.
54+
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
5255
base_data_preprocessor = base_data_preprocessor.cuda()
5356
output = base_data_preprocessor(data)
5457
batch_inputs, batch_labels = output['inputs'], output[
5558
'data_sample']
56-
self.assertTrue(torch.is_floating_point(batch_inputs))
57-
self.assertEqual(batch_inputs.device.type, 'cuda')
59+
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
60+
self.assertEqual(batch_inputs[0].device.type, 'cuda')
5861

62+
# Fallback to test with cpu.
5963
base_data_preprocessor = base_data_preprocessor.cpu()
6064
output = base_data_preprocessor(data)
6165
batch_inputs, batch_labels = output['inputs'], output[
6266
'data_sample']
63-
self.assertTrue(torch.is_floating_point(batch_inputs))
64-
self.assertEqual(batch_inputs.device.type, 'cpu')
67+
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
68+
self.assertEqual(batch_inputs[0].device.type, 'cpu')
6569

70+
# Test `base_data_preprocessor` can be moved to cuda again.
6671
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
6772
output = base_data_preprocessor(data)
6873
batch_inputs, batch_labels = output['inputs'], output[
6974
'data_sample']
70-
self.assertTrue(torch.is_floating_point(batch_inputs))
71-
self.assertEqual(batch_inputs.device.type, 'cuda')
75+
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
76+
self.assertEqual(batch_inputs[0].device.type, 'cuda')
7277

7378
# device of `base_data_preprocessor` is cuda, output should be
7479
# cuda tensor.
75-
self.assertEqual(batch_inputs.device.type, 'cuda')
80+
self.assertEqual(batch_inputs[0].device.type, 'cuda')
7681
self.assertEqual(batch_labels[0].device.type, 'cuda')
7782

83+
# Test forward with string value
84+
data = dict(string='abc')
85+
base_data_preprocessor(data)
86+
87+
with self.assertRaisesRegex(TypeError,
88+
'`BaseDataPreprocessor.cast_data`:'):
89+
data = dict(string=object())
90+
base_data_preprocessor(data)
91+
7892

7993
class TestImgDataPreprocessor(TestBaseDataPreprocessor):
8094

0 commit comments

Comments
 (0)