Skip to content

Commit be79bf7

Browse files
authored
Merge pull request #57 from mt1871/fix_maxpool_ceilmode
consider ceil_mode of torch.nn.MaxPool2d
2 parents 13ae1fc + 31a4d2d commit be79bf7

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

torch2trt/converters/MaxPool2d.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from torch2trt.torch2trt import *
2+
from torch2trt.module_test import add_module_test
23

34

45
@tensorrt_converter('torch.nn.MaxPool2d.forward')
@@ -23,5 +24,19 @@ def convert_MaxPool2d(ctx):
2324
input=input._trt, type=trt.PoolingType.MAX, window_size=kernel_size)
2425
layer.stride = stride
2526
layer.padding = padding
27+
if module.ceil_mode:
28+
layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
2629

27-
output._trt = layer.get_output(0)
30+
output._trt = layer.get_output(0)
31+
32+
33+
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)])
34+
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)])
35+
def test_MaxPool2d_without_ceil_mode():
36+
return torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
37+
38+
39+
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)])
40+
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)])
41+
def test_MaxPool2d_with_ceil_mode():
42+
return torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)

0 commit comments

Comments
 (0)