@@ -28,8 +28,8 @@ def test_forward(self):
28
28
label1 = torch .randn (1 )
29
29
label2 = torch .randn (1 )
30
30
31
+ # Test with dict of batch inputs and batch data samples
31
32
data = dict (inputs = [input1 , input2 ], data_sample = [label1 , label2 ])
32
-
33
33
output = base_data_preprocessor (data )
34
34
batch_inputs , batch_labels = output ['inputs' ], output ['data_sample' ]
35
35
self .assertTrue (torch .is_floating_point (batch_inputs [0 ]))
@@ -41,40 +41,54 @@ def test_forward(self):
41
41
assert_allclose (label2 , batch_labels [1 ])
42
42
43
43
# Test with tuple of batch inputs and batch data samples
44
- data = dict (
45
- inputs = torch .stack ([input1 , input2 ]), data_sample = [label1 , label2 ])
46
- output = base_data_preprocessor (data )['inputs' ]
44
+ data = (torch .stack ([input1 , input2 ]), (label1 , label2 ))
45
+ batch_inputs , batch_labels = base_data_preprocessor (data )
46
+ self .assertTrue (torch .is_floating_point (batch_inputs ))
47
+ self .assertEqual (batch_inputs [0 ].shape , (1 , 3 , 5 ))
48
+ self .assertEqual (batch_inputs [1 ].shape , (1 , 3 , 5 ))
47
49
self .assertTrue (torch .is_floating_point (batch_inputs [0 ]))
48
50
49
51
# Test cuda forward
50
52
if torch .cuda .is_available ():
51
53
# Test with list of data samples.
54
+ data = dict (inputs = [input1 , input2 ], data_sample = [label1 , label2 ])
52
55
base_data_preprocessor = base_data_preprocessor .cuda ()
53
56
output = base_data_preprocessor (data )
54
57
batch_inputs , batch_labels = output ['inputs' ], output [
55
58
'data_sample' ]
56
- self .assertTrue (torch .is_floating_point (batch_inputs ))
57
- self .assertEqual (batch_inputs .device .type , 'cuda' )
59
+ self .assertTrue (torch .is_floating_point (batch_inputs [ 0 ] ))
60
+ self .assertEqual (batch_inputs [ 0 ] .device .type , 'cuda' )
58
61
62
+ # Fallback to test with cpu.
59
63
base_data_preprocessor = base_data_preprocessor .cpu ()
60
64
output = base_data_preprocessor (data )
61
65
batch_inputs , batch_labels = output ['inputs' ], output [
62
66
'data_sample' ]
63
- self .assertTrue (torch .is_floating_point (batch_inputs ))
64
- self .assertEqual (batch_inputs .device .type , 'cpu' )
67
+ self .assertTrue (torch .is_floating_point (batch_inputs [ 0 ] ))
68
+ self .assertEqual (batch_inputs [ 0 ] .device .type , 'cpu' )
65
69
70
+ # Test `base_data_preprocessor` can be moved to cuda again.
66
71
base_data_preprocessor = base_data_preprocessor .to ('cuda:0' )
67
72
output = base_data_preprocessor (data )
68
73
batch_inputs , batch_labels = output ['inputs' ], output [
69
74
'data_sample' ]
70
- self .assertTrue (torch .is_floating_point (batch_inputs ))
71
- self .assertEqual (batch_inputs .device .type , 'cuda' )
75
+ self .assertTrue (torch .is_floating_point (batch_inputs [ 0 ] ))
76
+ self .assertEqual (batch_inputs [ 0 ] .device .type , 'cuda' )
72
77
73
78
# device of `base_data_preprocessor` is cuda, output should be
74
79
# cuda tensor.
75
- self .assertEqual (batch_inputs .device .type , 'cuda' )
80
+ self .assertEqual (batch_inputs [ 0 ] .device .type , 'cuda' )
76
81
self .assertEqual (batch_labels [0 ].device .type , 'cuda' )
77
82
83
+ # Test forward with string value
84
+ data = dict (string = 'abc' )
85
+ base_data_preprocessor (data )
86
+
87
+ with self .assertRaisesRegex (TypeError ,
88
+ '`BaseDataPreprocessor.cast_data`:' ):
89
+ data = dict (string = object ())
90
+ base_data_preprocessor (data )
91
+
78
92
79
93
class TestImgDataPreprocessor (TestBaseDataPreprocessor ):
80
94
0 commit comments