@@ -92,34 +92,36 @@ def convert_interpolate_trt7(ctx):
9292
9393
9494class Interpolate (torch .nn .Module ):
95- def __init__ (self , size , mode , align_corners ):
95+ def __init__ (self , size = None , scale_factor = None , mode = None , align_corners = None ):
9696 super (Interpolate , self ).__init__ ()
97+ ## Use either size or scale factor.
9798 self .size = size
99+ self .scale_factor = scale_factor
98100 self .mode = mode
99101 self .align_corners = align_corners
100102
101103 def forward (self , x ):
102- return F .interpolate (x , self .size , mode = self .mode , align_corners = self .align_corners )
104+ return F .interpolate (x , size = self .size , scale_factor = self . scale_factor , mode = self .mode , align_corners = self .align_corners )
103105
104106
105107@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 10 , 112 , 112 )], enabled = trt_version () < '7.1' and has_interpolate_plugin ())
106108def test_interpolate_nearest ():
107- return Interpolate ((224 , 224 ), 'nearest' , None )
109+ return Interpolate (size = (224 , 224 ), mode = 'nearest' , align_corners = None )
108110
109111
110112@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 10 , 112 , 112 )], enabled = trt_version () < '7.1' and has_interpolate_plugin ())
111113def test_interpolate_bilinear ():
112- return Interpolate ((224 , 224 ), 'bilinear' , False )
114+ return Interpolate (size = (224 , 224 ), mode = 'bilinear' , align_corners = False )
113115
114116
115117@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 10 , 112 , 112 )], enabled = trt_version () < '7.1' and has_interpolate_plugin ())
116118def test_interpolate_bicubic ():
117- return Interpolate ((224 , 224 ), 'bicubic' , False )
119+ return Interpolate (size = (224 , 224 ), mode = 'bicubic' ,align_corners = False )
118120
119121
120122@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 10 , 112 , 112 )], enabled = trt_version () < '7.1' and has_interpolate_plugin ())
121123def test_interpolate_area ():
122- return Interpolate ((56 , 56 ), 'area' , None )
124+ return Interpolate (size = (56 , 56 ), mode = 'area' ,align_corners = None )
123125
124126@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 , 10 , 112 , 112 )], enabled = trt_version () < '7.1' and has_interpolate_plugin ())
125127def test_upsample_scale_factor2 ():
@@ -135,7 +137,11 @@ def test_bilinear_mode():
135137
136138@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 ,3 ,12 ,12 )], enabled = trt_version () >= '7.1' )
137139def test_align_corner ():
138- return torch .nn .Upsample (scale_factor = 2 , mode = "bilinear" , align_corners = True )
140+ return torch .nn .Upsample (scale_factor = 2.0 , mode = "bilinear" , align_corners = True )
141+
142+ @add_module_test (torch .float32 , torch .device ('cuda' ), [(1 ,3 ,12 ,12 )], enabled = trt_version () >= '7.1' )
143+ def test_align_corner_functional ():
144+ return Interpolate (scale_factor = 2.0 , mode = "bilinear" , align_corners = True )
139145
140146@add_module_test (torch .float32 , torch .device ('cuda' ), [(1 ,5 ,13 ,13 )], enabled = trt_version () >= '7.1' )
141147def test_bilinear_mode_odd_input_shape ():
0 commit comments