Commit 9fc02f6

lezcano authored and pytorchmergebot committed
Decomposition for adaptive_avg_pool2d (pytorch#84062)
This was already implemented as a lowering in pytorch/torchdynamo#962. I'm putting the idea up here as a decomposition. The tests now pass, and I have corrected the issues that the first implementation had.

Pull Request resolved: pytorch#84062
Approved by: https://github.com/jansel
1 parent 3aae6ff commit 9fc02f6
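
As a sanity check (not part of this commit), the decomposition can be compared against the eager ATen kernel. A minimal sketch, assuming the torch._decomp.get_decompositions helper and the overload key aten._adaptive_avg_pool2d.default:

import torch
from torch._decomp import get_decompositions

# Hypothetical check, not from the PR: fetch the registered decomposition and
# compare it against the reference ATen kernel on a couple of shapes.
table = get_decompositions([torch.ops.aten._adaptive_avg_pool2d])
decomp = table[torch.ops.aten._adaptive_avg_pool2d.default]

for in_hw, out_hw in [((10, 7), (3, 5)), ((8, 8), (4, 4))]:
    x = torch.randn(2, 3, *in_hw)
    ref = torch.ops.aten._adaptive_avg_pool2d(x, out_hw)
    out = decomp(x, out_hw)
    torch.testing.assert_close(out, ref)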

File tree

1 file changed: +95 -0 lines changed

torch/_decomp/decompositions.py

Lines changed: 95 additions & 0 deletions
@@ -1225,6 +1225,101 @@ def cudnn_batch_norm_backward(
     )
 
 
+@register_decomposition(aten._adaptive_avg_pool2d, disable_meta=True)
+@pw_cast_for_opmath
+def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
+    # Preconditions
+    device = input.device
+    shape = input.shape
+    ndim = len(shape)
+    utils.check(
+        ndim in (3, 4),
+        lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
+    )
+    for d in input.shape[-2:]:
+        utils.check(
+            d != 0,
+            lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
+            f"non-batch dimensions, but input has shape {tuple(shape)}.",
+        )
+
+    # Optimisation (we should also do this in the kernel implementation)
+    if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
+        stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
+        kernel = tuple(
+            i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
+        )
+        return torch.nn.functional.avg_pool2d(input, kernel, stride)
+
+    def start_index(a, b, c):
+        return (a * c) // b
+
+    def end_index(a, b, c):
+        return (((a + 1) * c) / b).ceil().to(a.dtype)
+
+    # Let's assume the reduction we want to apply is to sum all the elements (averaging from this is easy)
+    # Even more, let's assume that we just want to do the 1d case.
+    # The 2d case is recovered by applying the 1d case along two dimensions.
+    # The issue here is that we may want to sum segments of different sizes.
+    # What we do is take the largest segment and select all the elements from each initial point
+    # up to the max length. Then we zero out the elements that we picked up but did not need, if there were any.
+    # If all the segments have the same length, we compute the average right away; otherwise, we return
+    # the size of each window, so that the sizes of the rectangles can be computed at the end.
+    # This function should recover the efficiency of avg_pool2d when the shapes do not require dynamic window sizes.
+
+    def adaptive_avg_pool1d(x, dim, out_size):
+        assert dim == -2 or dim == -1
+        in_size = x.size(dim)
+
+        orange = torch.arange(out_size, device=device)
+        i0 = start_index(orange, out_size, in_size)
+        # Let length = end_index - start_index, i.e. the length of the pooling kernels.
+        # length.max() can be computed analytically as follows:
+        maxlength = in_size // out_size + 1
+        in_size_mod = in_size % out_size
+        # adaptive = True iff there are kernels with different lengths
+        adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
+        if adaptive:
+            maxlength += 1
+        elif in_size_mod == 0:
+            maxlength -= 1
+
+        range_max = torch.arange(maxlength, device=device)
+        idx = i0.unsqueeze(-1) + range_max
+        if adaptive:
+            # Need to clamp to avoid accessing out-of-bounds memory
+            idx = idx.clamp(max=in_size - 1)
+        adv_idx_pad = tuple(slice(None) for _ in range(dim + ndim))
+        vals = x[(*adv_idx_pad, idx)]
+
+        if adaptive:
+            i1 = end_index(orange, out_size, in_size)
+            length = i1 - i0
+            # Zero out the elements we didn't really want to select
+            assert dim < 0
+            mask = _unsqueeze_to_dim(range_max >= length.unsqueeze(-1), -dim + 1)
+            vals = torch.masked_fill(vals, mask, 0.0)
+
+            # Compute the length of each window
+            div = _unsqueeze_to_dim(length, -dim)
+            return vals.sum(dim), div
+        else:
+            # No need to return div, as mean already divides by the (constant) window length
+            return vals.mean(dim), None
+
+    out, div1 = adaptive_avg_pool1d(input, -1, output_size[-1])
+    out, div2 = adaptive_avg_pool1d(out, -2, output_size[-2])
+    # Filter out the Nones
+    divs = tuple(div for div in (div1, div2) if div is not None)
+    # prod(divs) does not work because it accumulates with *=
+    if len(divs) == 0:
+        return out
+    elif len(divs) == 1:
+        return out / divs[0]
+    else:  # len(divs) == 2
+        return out / (divs[0] * divs[1])
+
+
 def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
     ndim = self.dim()
     wrapped_dims = utils.canonicalize_dims(ndim, dims)
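
For intuition on the index arithmetic in adaptive_avg_pool1d above (illustration only, not part of the commit): with in_size = 5 and out_size = 3 the pooling windows are [0, 2), [1, 4) and [3, 5), i.e. lengths 2, 3 and 2. The analytic bound maxlength = in_size // out_size + 1 = 2 is bumped to 3 because in_size % out_size = 2 does not divide out_size, so the adaptive (masked) path is taken. A small plain-Python sketch of the same start/end formulas:

import math

def start_index(i, out_size, in_size):
    # Same formula as in the decomposition, on plain ints
    return (i * in_size) // out_size

def end_index(i, out_size, in_size):
    return math.ceil((i + 1) * in_size / out_size)

in_size, out_size = 5, 3
windows = [(start_index(i, out_size, in_size), end_index(i, out_size, in_size))
           for i in range(out_size)]
print(windows)  # [(0, 2), (1, 4), (3, 5)] -> window lengths 2, 3, 2

The decomposition gathers maxlength columns starting at each start index, masks out the positions past each window's true end, sums, and divides by the per-window length.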
