Skip to content

Commit e6bbbb0

Browse files
Elias Ellisonfacebook-github-bot
Elias Ellison
authored andcommitted
Fix interpolate trace (pytorch#18875)
Summary: Fixes pytorch#10654 The issue is that in tracing `.size` returns an int tensor, and when an int tensor is multiplied by a scalar the int dominates and the scalar gets casted 0. Pull Request resolved: pytorch#18875 Differential Revision: D14814441 Pulled By: eellison fbshipit-source-id: a4e96a2698f2fcbf3ec4b2bb4c43a30250f30ad9
1 parent 6084908 commit e6bbbb0

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

test/test_jit.py

+18
Original file line numberDiff line numberDiff line change
@@ -7518,6 +7518,24 @@ def forward(self, x):
75187518
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
75197519
example_outputs=outputs))
75207520

7521+
def test_interpolate_trace(self):
7522+
class test(nn.Module):
7523+
def __init__(self):
7524+
super(test, self).__init__()
7525+
self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
7526+
7527+
def forward(self, x):
7528+
y = self.conv(x)
7529+
w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=0.5)
7530+
return w
7531+
7532+
f = test()
7533+
# no failure
7534+
g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
7535+
x = torch.zeros(1, 1, 14, 14)
7536+
# constants not baked in
7537+
self.assertEqual(g(x), f(x))
7538+
75217539
def test_trace_nested_datatypes(self):
75227540
@torch.jit.script
75237541
def foo(x):

torch/nn/functional.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2492,7 +2492,12 @@ def _output_size(dim):
24922492
return size
24932493
scale_factors = _ntuple(dim)(scale_factor)
24942494
# math.floor might return float in py2.7
2495-
return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)]
2495+
2496+
# make scale_factor a tensor in tracing so constant doesn't get baked in
2497+
if torch._C._get_tracing_state():
2498+
return [(torch.floor(input.size(i + 2) * torch.tensor(scale_factors[i]))) for i in range(dim)]
2499+
else:
2500+
return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
24962501

24972502
if mode in ('nearest', 'area'):
24982503
if align_corners is not None:

0 commit comments

Comments
 (0)