9
9
10
10
11
11
class TestStructUtils (TestCaseMixin , unittest .TestCase ):
12
+ def setUp (self ) -> None :
13
+ super ().setUp ()
14
+ torch .manual_seed (43 )
15
+
16
+ def _check_list_to_padded_slices (self , x , x_padded , ndim ):
17
+ N = len (x )
18
+ for i in range (N ):
19
+ slices = [i ]
20
+ for dim in range (ndim ):
21
+ if x [i ].nelement () == 0 and x [i ].ndim == 1 :
22
+ slice_ = slice (0 , 0 , 1 )
23
+ else :
24
+ slice_ = slice (0 , x [i ].shape [dim ], 1 )
25
+ slices .append (slice_ )
26
+ if x [i ].nelement () == 0 and x [i ].ndim == 1 :
27
+ x_correct = x [i ].new_zeros (* [[0 ] * ndim ])
28
+ else :
29
+ x_correct = x [i ]
30
+ self .assertClose (x_padded [slices ], x_correct )
31
+
12
32
def test_list_to_padded (self ):
13
33
device = torch .device ("cuda:0" )
14
34
N = 5
15
35
K = 20
16
- ndim = 2
17
- x = []
18
- for _ in range (N ):
19
- dims = torch .randint (K , size = (ndim ,)).tolist ()
20
- x .append (torch .rand (dims , device = device ))
21
- pad_size = [K ] * ndim
22
- x_padded = struct_utils .list_to_padded (
23
- x , pad_size = pad_size , pad_value = 0.0 , equisized = False
24
- )
36
+ for ndim in [1 , 2 , 3 , 4 ]:
37
+ x = []
38
+ for _ in range (N ):
39
+ dims = torch .randint (K , size = (ndim ,)).tolist ()
40
+ x .append (torch .rand (dims , device = device ))
25
41
26
- self . assertEqual ( x_padded . shape [ 1 ], K )
27
- self . assertEqual ( x_padded . shape [ 2 ], K )
28
- for i in range ( N ):
29
- self . assertClose ( x_padded [ i , : x [ i ]. shape [ 0 ], : x [ i ]. shape [ 1 ]], x [ i ])
30
-
31
- # check for no pad size (defaults to max dimension)
32
- x_padded = struct_utils . list_to_padded ( x , pad_value = 0.0 , equisized = False )
33
- max_size0 = max ( y . shape [ 0 ] for y in x )
34
- max_size1 = max ( y . shape [ 1 ] for y in x )
35
- self . assertEqual ( x_padded . shape [ 1 ], max_size0 )
36
- self . assertEqual ( x_padded . shape [ 2 ], max_size1 )
37
- for i in range (N ):
38
- self .assertClose (x_padded [ i , : x [ i ] .shape [0 ], : x [ i ]. shape [ 1 ]], x [ i ] )
42
+ # set 0th element to an empty 1D tensor
43
+ x [ 0 ] = torch . tensor ([ ], dtype = x [ 0 ]. dtype , device = device )
44
+
45
+ # set 1st element to an empty tensor with correct number of dims
46
+ x [ 1 ] = x [ 1 ]. new_zeros ( * [[ 0 ] * ndim ])
47
+
48
+ pad_size = [ K ] * ndim
49
+ x_padded = struct_utils . list_to_padded (
50
+ x , pad_size = pad_size , pad_value = 0.0 , equisized = False
51
+ )
52
+
53
+ for dim in range (ndim ):
54
+ self .assertEqual (x_padded .shape [dim + 1 ], K )
39
55
40
- # check for equisized
41
- x = [torch .rand ((K , 10 ), device = device ) for _ in range (N )]
42
- x_padded = struct_utils .list_to_padded (x , equisized = True )
43
- self .assertClose (x_padded , torch .stack (x , 0 ))
56
+ self ._check_list_to_padded_slices (x , x_padded , ndim )
57
+
58
+ # check for no pad size (defaults to max dimension)
59
+ x_padded = struct_utils .list_to_padded (x , pad_value = 0.0 , equisized = False )
60
+ max_sizes = (
61
+ max (
62
+ (0 if (y .nelement () == 0 and y .ndim == 1 ) else y .shape [dim ])
63
+ for y in x
64
+ )
65
+ for dim in range (ndim )
66
+ )
67
+ for dim , max_size in enumerate (max_sizes ):
68
+ self .assertEqual (x_padded .shape [dim + 1 ], max_size )
69
+
70
+ self ._check_list_to_padded_slices (x , x_padded , ndim )
71
+
72
+ # check for equisized
73
+ x = [torch .rand ((K , * ([10 ] * (ndim - 1 ))), device = device ) for _ in range (N )]
74
+ x_padded = struct_utils .list_to_padded (x , equisized = True )
75
+ self .assertClose (x_padded , torch .stack (x , 0 ))
44
76
45
77
# catch ValueError for invalid dimensions
46
78
with self .assertRaisesRegex (ValueError , "Pad size must" ):
47
- pad_size = [K ] * 4
79
+ pad_size = [K ] * ( ndim + 1 )
48
80
struct_utils .list_to_padded (
49
81
x , pad_size = pad_size , pad_value = 0.0 , equisized = False
50
82
)
@@ -56,7 +88,7 @@ def test_list_to_padded(self):
56
88
dims = torch .randint (K , size = (ndim ,)).tolist ()
57
89
x .append (torch .rand (dims , device = device ))
58
90
pad_size = [K ] * 2
59
- with self .assertRaisesRegex (ValueError , "Supports only " ):
91
+ with self .assertRaisesRegex (ValueError , "Pad size must " ):
60
92
x_padded = struct_utils .list_to_padded (
61
93
x , pad_size = pad_size , pad_value = 0.0 , equisized = False
62
94
)
@@ -66,27 +98,29 @@ def test_padded_to_list(self):
66
98
N = 5
67
99
K = 20
68
100
ndim = 2
69
- dims = [K ] * ndim
70
- x = torch .rand ([N ] + dims , device = device )
71
101
72
- x_list = struct_utils .padded_to_list (x )
73
- for i in range (N ):
74
- self .assertClose (x_list [i ], x [i ])
102
+ for ndim in (2 , 3 , 4 ):
75
103
76
- split_size = torch .randint (1 , K , size = (N ,)).tolist ()
77
- x_list = struct_utils .padded_to_list (x , split_size )
78
- for i in range (N ):
79
- self .assertClose (x_list [i ], x [i , : split_size [i ]])
104
+ dims = [K ] * ndim
105
+ x = torch .rand ([N ] + dims , device = device )
80
106
81
- split_size = torch .randint (1 , K , size = (2 * N ,)).view (N , 2 ).unbind (0 )
82
- x_list = struct_utils .padded_to_list (x , split_size )
83
- for i in range (N ):
84
- self .assertClose (x_list [i ], x [i , : split_size [i ][0 ], : split_size [i ][1 ]])
107
+ x_list = struct_utils .padded_to_list (x )
108
+ for i in range (N ):
109
+ self .assertClose (x_list [i ], x [i ])
85
110
86
- with self .assertRaisesRegex (ValueError , "Supports only" ):
87
- x = torch .rand ((N , K , K , K , K ), device = device )
88
- split_size = torch .randint (1 , K , size = (N ,)).tolist ()
89
- struct_utils .padded_to_list (x , split_size )
111
+ split_size = torch .randint (1 , K , size = (N , ndim )).unbind (0 )
112
+ x_list = struct_utils .padded_to_list (x , split_size )
113
+ for i in range (N ):
114
+ slices = [i ]
115
+ for dim in range (ndim ):
116
+ slices .append (slice (0 , split_size [i ][dim ], 1 ))
117
+ self .assertClose (x_list [i ], x [slices ])
118
+
119
+ # split size is a list of ints
120
+ split_size = [int (z ) for z in torch .randint (1 , K , size = (N ,)).unbind (0 )]
121
+ x_list = struct_utils .padded_to_list (x , split_size )
122
+ for i in range (N ):
123
+ self .assertClose (x_list [i ], x [i ][: split_size [i ]])
90
124
91
125
def test_padded_to_packed (self ):
92
126
device = torch .device ("cuda:0" )
@@ -160,7 +194,7 @@ def test_padded_to_packed(self):
160
194
with self .assertRaisesRegex (ValueError , "Supports only" ):
161
195
x = torch .rand ((N , K , K , K , K ), device = device )
162
196
split_size = torch .randint (1 , K , size = (N ,)).tolist ()
163
- struct_utils .padded_to_list (x , split_size )
197
+ struct_utils .padded_to_packed (x , split_size = split_size )
164
198
165
199
def test_list_to_packed (self ):
166
200
device = torch .device ("cuda:0" )
0 commit comments