torch.aten.conv_tbc, conv1d and conv3d #345
Seems pretty straightforward: the inputs just need to be transposed. According to this issue, this would be as simple as:
```python
@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("conv_tbc", input, weight, bias, pad_i=pad)
    else:
        # input must have 3 dimensions, see:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
        # input = (time, batch, in_channels)
        # weight = (kernel_width, in_channels, out_channels)
        # bias = (out_channels,)
        input = g.op("Transpose", input, perm_i=[1, 2, 0])
        weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
        conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
        return g.op("Transpose", conv, perm_i=[2, 0, 1])
```
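As a sanity check on the permutations in the snippet, here is a minimal pure-Python sketch (not the project's code) that computes conv_tbc two ways: directly from its definition, and by doing Transpose, then conv1d, then Transpose. The helper names `transpose3`, `conv1d_ncw`, `conv_tbc_reference` and `conv_tbc_via_conv1d` are hypothetical. It assumes stride 1, dilation 1, a single group and symmetric zero padding of `pad` on the time axis, which matches the `[1], [pad], [1], 1` arguments passed to `conv1d` above.

```python
import math
from typing import List

# A 3-D tensor represented as nested Python lists.
Tensor3 = List[List[List[float]]]


def transpose3(x: Tensor3, perm: List[int]) -> Tensor3:
    """Permute the axes of a 3-D nested list with ONNX Transpose semantics:
    axis i of the output is axis perm[i] of the input."""
    shape = [len(x), len(x[0]), len(x[0][0])]
    out_shape = [shape[p] for p in perm]
    out = [[[0.0] * out_shape[2] for _ in range(out_shape[1])] for _ in range(out_shape[0])]
    for i in range(out_shape[0]):
        for j in range(out_shape[1]):
            for k in range(out_shape[2]):
                src = [0, 0, 0]
                src[perm[0]], src[perm[1]], src[perm[2]] = i, j, k
                out[i][j][k] = x[src[0]][src[1]][src[2]]
    return out


def conv1d_ncw(x: Tensor3, w: Tensor3, bias: List[float], pad: int) -> Tensor3:
    """Stride-1, dilation-1, single-group conv1d (cross-correlation, no kernel flip).
    x: (batch, in_channels, width); w: (out_channels, in_channels, kernel)."""
    batch, c_in, width = len(x), len(x[0]), len(x[0][0])
    c_out, kernel = len(w), len(w[0][0])
    out_w = width + 2 * pad - kernel + 1
    out = []
    for b in range(batch):
        rows = []
        for o in range(c_out):
            row = []
            for t in range(out_w):
                acc = bias[o]
                for c in range(c_in):
                    for k in range(kernel):
                        src = t + k - pad  # index into the unpadded input
                        if 0 <= src < width:
                            acc += x[b][c][src] * w[o][c][k]
                row.append(acc)
            rows.append(row)
        out.append(rows)
    return out


def conv_tbc_reference(x: Tensor3, w: Tensor3, bias: List[float], pad: int) -> Tensor3:
    """conv_tbc straight from its definition (not ATen's implementation).
    x: (time, batch, in_channels); w: (kernel, in_channels, out_channels)."""
    time, batch, c_in = len(x), len(x[0]), len(x[0][0])
    kernel, c_out = len(w), len(w[0][0])
    out_t = time + 2 * pad - kernel + 1
    out = [[[bias[o] for o in range(c_out)] for _ in range(batch)] for _ in range(out_t)]
    for t in range(out_t):
        for b in range(batch):
            for o in range(c_out):
                for k in range(kernel):
                    src = t + k - pad  # zero padding: skip out-of-range time steps
                    if 0 <= src < time:
                        for c in range(c_in):
                            out[t][b][o] += x[src][b][c] * w[k][c][o]
    return out


def conv_tbc_via_conv1d(x: Tensor3, w: Tensor3, bias: List[float], pad: int) -> Tensor3:
    """The lowering used in the snippet above: transpose, conv1d, transpose back."""
    x_ncw = transpose3(x, [1, 2, 0])        # (T, B, C_in) -> (B, C_in, T)
    w_oik = transpose3(w, [2, 1, 0])        # (K, C_in, C_out) -> (C_out, C_in, K)
    y = conv1d_ncw(x_ncw, w_oik, bias, pad)  # (B, C_out, T')
    return transpose3(y, [2, 0, 1])         # (B, C_out, T') -> (T', B, C_out)


if __name__ == "__main__":
    # T=5, B=2, C_in=2; K=3, C_out=3; pad=1 keeps the time length at 5.
    x = [[[float(t * 10 + b * 3 + c) for c in range(2)] for b in range(2)] for t in range(5)]
    w = [[[float(k - c + o) for o in range(3)] for c in range(2)] for k in range(3)]
    bias = [0.5, -1.0, 2.0]
    direct = conv_tbc_reference(x, w, bias, pad=1)
    lowered = conv_tbc_via_conv1d(x, w, bias, pad=1)
    same = all(
        math.isclose(a, b)
        for ta, tb in zip(direct, lowered)
        for ra, rb in zip(ta, tb)
        for a, b in zip(ra, rb)
    )
    print(len(direct), len(direct[0]), len(direct[0][0]), same)  # prints: 5 2 3 True
```

The final permutation `[2, 0, 1]` is the inverse of `[1, 2, 0]`, which is why the output comes back in (time, batch, channels) order.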
It looks like to implement conv_tbc I need to implement conv1d first, and conv3d seems pretty trivial to tack on.
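One reason conv3d is cheap to add once conv1d exists: both apply the same per-axis arithmetic, over one spatial axis versus three. The sketch below uses a hypothetical helper `conv_output_shape` (not from the project) to compute output spatial sizes with the formula given in the PyTorch `Conv1d`/`Conv3d` documentation, assuming symmetric padding.

```python
from typing import Sequence, Tuple


def conv_output_shape(
    spatial: Sequence[int],
    kernel: Sequence[int],
    stride: Sequence[int],
    padding: Sequence[int],
    dilation: Sequence[int],
) -> Tuple[int, ...]:
    """Output spatial shape of an N-d convolution (N=1 for conv1d, N=3 for conv3d),
    assuming the same padding on both sides of each axis."""
    if not (len(spatial) == len(kernel) == len(stride) == len(padding) == len(dilation)):
        raise ValueError("all per-axis arguments must have the same length")
    # Same formula on every axis: floor((L + 2p - d*(k - 1) - 1) / s) + 1
    return tuple(
        (size + 2 * pad - dil * (k - 1) - 1) // s + 1
        for size, k, s, pad, dil in zip(spatial, kernel, stride, padding, dilation)
    )


if __name__ == "__main__":
    print(conv_output_shape([10], [3], [1], [1], [1]))  # conv1d: (10,)
    print(conv_output_shape([8, 16, 16], [3, 3, 3], [2, 2, 2], [1, 1, 1], [1, 1, 1]))  # conv3d: (4, 8, 8)
```

With stride 1 and dilation 1 the formula reduces to `T + 2*pad - K + 1`, the output time length of conv_tbc.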
Support added for these ops.