@@ -1225,6 +1225,101 @@ def cudnn_batch_norm_backward(
    )


+@register_decomposition(aten._adaptive_avg_pool2d, disable_meta=True)
+@pw_cast_for_opmath
+def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
+    # Preconditions
+    device = input.device
+    shape = input.shape
+    ndim = len(shape)
+    utils.check(
+        ndim in (3, 4),
+        lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
+    )
+    for d in input.shape[-2:]:
+        utils.check(
+            d != 0,
+            lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
+            f"non-batch dimensions, but input has shape {tuple(shape)}.",
+        )
+
+    # Optimisation (we should also do this in the kernel implementation)
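+    # e.g. (illustrative) for an input of shape (1, 3, 8, 8) and output_size = (4, 4):
+    # stride = (8 // 4, 8 // 4) = (2, 2) and kernel = (8 - (4 - 1) * 2, 8 - (4 - 1) * 2) = (2, 2),
+    # so this branch reduces to avg_pool2d(input, (2, 2), (2, 2)).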
+    if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
+        stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
+        kernel = tuple(
+            i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
+        )
+        return torch.nn.functional.avg_pool2d(input, kernel, stride)
+
+    def start_index(a, b, c):
+        return (a * c) // b
+
+    def end_index(a, b, c):
+        return (((a + 1) * c) / b).ceil().to(a.dtype)
+
+    # Let's assume the reduction we want to apply is to sum all the elements (averaging from this is easy).
+    # Moreover, let's assume that we just want to do the 1d case;
+    # the 2d case is recovered by applying the 1d case along two dimensions.
+    # The issue here is that we may want to sum segments of different sizes.
+    # What we do is take the largest segment and, for each window, select all the elements from its initial
+    # point up to the max length. Then we zero out the elements that we picked up but did not need, if any.
+    # If all the windows have the same length, we compute the average right away; otherwise, we return
+    # the sizes of each window, to compute the areas of the rectangles at the end.
+    # This function should recover the efficiency of avg_pool2d if the shape does not need the dynamic window shape.
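+    # Worked example (illustrative): for in_size = 5 and out_size = 3,
+    # start_index gives [0, 1, 3] and end_index gives [2, 4, 5],
+    # i.e. windows [0, 2), [1, 4), [3, 5) with lengths [2, 3, 2].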
+
+    def adaptive_avg_pool1d(x, dim, out_size):
+        assert dim == -2 or dim == -1
+        in_size = x.size(dim)
+
+        orange = torch.arange(out_size, device=device)
+        i0 = start_index(orange, out_size, in_size)
+        # Let length = end_index - start_index, i.e. the length of the pooling kernels
+        # length.max() can be computed analytically as follows:
+        maxlength = in_size // out_size + 1
+        in_size_mod = in_size % out_size
+        # adaptive = True iff there are kernels with different lengths
+        adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
+        if adaptive:
+            maxlength += 1
+        elif in_size_mod == 0:
+            maxlength -= 1
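+        # Continuing the example above (illustrative): in_size = 5, out_size = 3 gives
+        # maxlength = 5 // 3 + 1 = 2, in_size_mod = 2, adaptive = True, so maxlength becomes 3,
+        # matching the largest window length in [2, 3, 2]. For in_size = 6, out_size = 3,
+        # in_size_mod = 0, adaptive = False, and maxlength is corrected down to 2.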
+
+        range_max = torch.arange(maxlength, device=device)
+        idx = i0.unsqueeze(-1) + range_max
+        if adaptive:
+            # Need to clamp to avoid accessing out-of-bounds memory
+            idx = idx.clamp(max=in_size - 1)
+        adv_idx_pad = tuple(slice(None) for _ in range(dim + ndim))
+        vals = x[(*adv_idx_pad, idx)]
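+        # idx has shape (out_size, maxlength), so this advanced indexing gathers, along
+        # dimension `dim`, all candidate elements of every window at once; e.g. for a 4D
+        # input and dim = -1, vals has shape (N, C, H, out_size, maxlength).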
+
+        if adaptive:
+            i1 = end_index(orange, out_size, in_size)
+            length = i1 - i0
+            # Zero out the elements we did not actually want to select
+            assert dim < 0
+            mask = _unsqueeze_to_dim(range_max >= length.unsqueeze(-1), -dim + 1)
+            vals = torch.masked_fill(vals, mask, 0.0)
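+            # e.g. with lengths [2, 3, 2] and maxlength = 3, the mask is True at the third
+            # gathered position of the first and last windows, which were only picked up
+            # because of the clamped over-read above.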
+
+            # Compute the length of each window
+            div = _unsqueeze_to_dim(length, -dim)
+            return vals.sum(dim), div
+        else:
+            # No need to return div, as mean() has already divided by the window length
+            return vals.mean(dim), None
+
+    out, div1 = adaptive_avg_pool1d(input, -1, output_size[-1])
+    out, div2 = adaptive_avg_pool1d(out, -2, output_size[-2])
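+    # Assuming _unsqueeze_to_dim pads with trailing singleton dimensions, div1 (if not None)
+    # has shape (out_W,) and div2 has shape (out_H, 1), so their product broadcasts to the
+    # (out_H, out_W) grid of per-window areas used as divisors below.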
+    # Filter the Nones
+    divs = tuple(div for div in (div1, div2) if div is not None)
+    # prod(divs) does not work because it accumulates with *=
+    if len(divs) == 0:
+        return out
+    elif len(divs) == 1:
+        return out / divs[0]
+    else:  # len(divs) == 2
+        return out / (divs[0] * divs[1])
+
+
def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    ndim = self.dim()
    wrapped_dims = utils.canonicalize_dims(ndim, dims)