@@ -97,6 +97,92 @@ def test_padded_to_list(self):
         split_size = torch.randint(1, K, size=(N,)).tolist()
         struct_utils.padded_to_list(x, split_size)
 
+    def test_padded_to_packed(self):
+        device = torch.device("cuda:0")
+        N = 5
+        K = 20
+        ndim = 2
+        dims = [K] * ndim
+        x = torch.rand([N] + dims, device=device)
+
+        # Case 1: no split_size or pad_value provided.
+        # Check output is just the flattened input.
+        x_packed = struct_utils.padded_to_packed(x)
+        self.assertTrue(x_packed.shape == (x.shape[0] * x.shape[1], x.shape[2]))
+        self.assertClose(x_packed, x.reshape(-1, K))
+
+        # Case 2: pad_value is provided.
+        # Check each section of the packed tensor matches the
+        # corresponding unpadded elements of the padded tensor.
+        # Check that only rows where all the values are padded
+        # are removed in the conversion to packed.
+        pad_value = -1
+        x_list = []
+        split_size = []
+        for _ in range(N):
+            dim = torch.randint(K, size=(1,)).item()
+            # Add some random values in the input which are the same as the pad_value.
+            # These should not be filtered out.
+            x_list.append(
+                torch.randint(
+                    low=pad_value, high=10, size=(dim, K), device=device
+                )
+            )
+            split_size.append(dim)
+        x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value)
+        x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value)
+        curr = 0
+        for i in range(N):
+            self.assertClose(
+                x_packed[curr : curr + split_size[i], ...], x_list[i]
+            )
+            self.assertClose(torch.cat(x_list), x_packed)
+            curr += split_size[i]
+
+        # Case 3: split_size is provided.
+        # Check each section of the packed tensor matches the corresponding
+        # unpadded elements.
+        x_packed = struct_utils.padded_to_packed(
+            x_padded, split_size=split_size
+        )
+        curr = 0
+        for i in range(N):
+            self.assertClose(
+                x_packed[curr : curr + split_size[i], ...], x_list[i]
+            )
+            self.assertClose(torch.cat(x_list), x_packed)
+            curr += split_size[i]
+
+        # Case 4: split_size of the wrong shape is provided.
+        # Raise an error.
+        split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
+        with self.assertRaisesRegex(ValueError, "1-dimensional"):
+            x_packed = struct_utils.padded_to_packed(
+                x_padded, split_size=split_size
+            )
+
+        split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist()
+        with self.assertRaisesRegex(
+            ValueError, "same length as inputs first dimension"
+        ):
+            x_packed = struct_utils.padded_to_packed(
+                x_padded, split_size=split_size
+            )
+
+        # Case 5: both pad_value and split_size are provided.
+        # Raise an error.
+        with self.assertRaisesRegex(ValueError, "Only one of"):
+            x_packed = struct_utils.padded_to_packed(
+                x_padded, split_size=split_size, pad_value=-1
+            )
+
+        # Case 6: Input has more than 3 dims.
+        # Raise an error.
+        with self.assertRaisesRegex(ValueError, "Supports only"):
+            x = torch.rand((N, K, K, K, K), device=device)
+            split_size = torch.randint(1, K, size=(N,)).tolist()
+            struct_utils.padded_to_packed(x, split_size=split_size)
+
     def test_list_to_packed(self):
         device = torch.device("cuda:0")
         N = 5
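For readers skimming the diff, the semantics these cases pin down can be summarized with a short reference sketch. This is an illustrative reimplementation of the behavior asserted above, not PyTorch3D's actual padded_to_packed: the name padded_to_packed_sketch and its error messages are invented for the example, and only the behavior exercised by Cases 1-3, 5 and 6 is mirrored (Case 4's split_size shape validation is elided).

import torch


def padded_to_packed_sketch(x, split_size=None, pad_value=None):
    # Illustrative reference only: mirrors the behavior the tests assert.
    if x.ndim != 3:
        raise ValueError("Supports only 3-dimensional input tensors")  # Case 6
    if split_size is not None and pad_value is not None:
        raise ValueError("Only one of split_size or pad_value may be given")  # Case 5
    if pad_value is not None:
        # Case 2: drop only rows made up entirely of pad_value; rows that
        # merely contain some pad_value entries are kept. Boolean indexing
        # is row-major, so each batch element's rows stay contiguous.
        keep = (x != pad_value).any(dim=2)  # (N, max_rows) mask
        return x[keep]
    if split_size is not None:
        # Case 3: keep the first split_size[i] (unpadded) rows of element i.
        return torch.cat([x[i, :s] for i, s in enumerate(split_size)], dim=0)
    # Case 1: neither given, flatten the batch and padding dimensions.
    return x.reshape(-1, x.shape[2])


x = torch.tensor([[[1.0, 2.0], [-1.0, -1.0]], [[3.0, 4.0], [5.0, 6.0]]])
print(padded_to_packed_sketch(x, pad_value=-1))       # all-padding row dropped
print(padded_to_packed_sketch(x, split_size=[1, 2]))  # rows (1, 2) kept per element

Note the Case 2 subtlety the test comments call out: the mask keeps any row containing at least one non-pad value, so pad_value entries that occur inside real data survive the conversion.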