Skip to content

Commit 725c2f6

Browse files
cccclaifacebook-github-bot
authored andcommitted
Patch the _is_conv_node function (pytorch#153749)
Summary: X-link: pytorch/ao#2223 torch.ops.aten.conv2d.padding is also conv2d node Test Plan: CI Reviewed By: andrewor14, jerryzh168 Differential Revision: D74898941
1 parent 4277907 commit 725c2f6

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

torch/ao/quantization/pt2e/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,11 @@ def _is_conv_node(n: Node):
167167
"""
168168
return n.op == "call_function" and n.target in [
169169
torch.ops.aten.conv1d.default,
170+
torch.ops.aten.conv1d.padding,
170171
torch.ops.aten.conv2d.default,
172+
torch.ops.aten.conv2d.padding,
173+
torch.ops.aten.conv3d.default,
174+
torch.ops.aten.conv3d.padding,
171175
]
172176

173177

torch/testing/_internal/common_quantization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3184,11 +3184,11 @@ def forward(self, x):
31843184
return x
31853185

31863186
class ConvWithBNRelu(torch.nn.Module):
3187-
def __init__(self, relu, dim=2, bn=True, bias=True):
3187+
def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
31883188
super().__init__()
3189-
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
3190-
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
3191-
self.conv = convs[dim](3, 3, 3, bias=bias)
3189+
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
3190+
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
3191+
self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
31923192

31933193
if bn:
31943194
self.bn = bns[dim](3)

0 commit comments

Comments
 (0)