Issue:
Converting `aten.expand` does not support symbolic shapes.
Steps to reproduce:
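import torch
import torch.nn as nn

class MatMul(nn.Module):
    def __init__(self):
        super().__init__()
        # Parameter with batch dims (3, 1), each holding a 3x2 matrix
        self.input = nn.Parameter(torch.randn(3, 1, 3, 2))

    def forward(self, other):
        # Broadcasting batched matmul (traced as expand -> reshape -> bmm -> reshape)
        return torch.matmul(self.input, other)

if __name__ == "__main__":
    # other's second-to-last dim must equal self.input's last dim (2)
    inputs = torch.randn(2, 2, 3)
    matmul = MatMul()
    y = matmul(inputs)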
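The `expand` nodes in the trace come from broadcasting the batch dimensions of the two operands. A quick eager-mode check, using an example `other` of shape (2, 2, 3) chosen (as an assumption) so the inner matrix dimensions agree with the parameter's 3x2 matrices:

```python
import torch

param = torch.randn(3, 1, 3, 2)  # same shape as self.input in the repro
other = torch.randn(2, 2, 3)     # example input: batch dims (2,), one 2x3 matrix each

# Batch dims (3, 1) and (2,) broadcast to (3, 2); both operands are expanded to these
batch = torch.broadcast_shapes(param.shape[:-2], other.shape[:-2])
print(batch)  # torch.Size([3, 2])

# (3x2) @ (2x3) -> 3x3 per batch entry
out = torch.matmul(param, other)
print(out.shape)  # torch.Size([3, 2, 3, 3])
```

When a batch dimension is dynamic, the corresponding entry of the expand target can no longer be a plain int, which is where `sym_size` enters the graph.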
Notes:
Tracing the model above with symbolic shapes produces an FX graph like this:
%arg0 : [#users=4] = placeholder[target=arg0]
%sym_size : [#users=4] = call_function[target=torch.ops.aten.sym_size](args = (%arg0, 0), kwargs = {})
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%expand_default : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_param_constant0, [%sym_size, 2, 2, 3]), kwargs = {})
%mul : [#users=2] = call_function[target=operator.mul](args = (%sym_size, 2), kwargs = {})
%reshape : [#users=1] = call_function[target=torch.ops.aten.reshape](args = (%expand_default, [%mul, 2, 3]), kwargs = {})
%sym_size_1 : [#users=2] = call_function[target=torch.ops.aten.sym_size](args = (%arg0, 2), kwargs = {})
%sym_size_2 : [#users=3] = call_function[target=torch.ops.aten.sym_size](args = (%arg0, 3), kwargs = {})
%expand_default_1 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg0, [%sym_size, 2, %sym_size_1, %sym_size_2]), kwargs = {})
%reshape_1 : [#users=1] = call_function[target=torch.ops.aten.reshape](args = (%expand_default_1, [%mul, %sym_size_1, %sym_size_2]), kwargs = {})
%bmm_default : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape, %reshape_1), kwargs = {})
%reshape_2 : [#users=1] = call_function[target=torch.ops.aten.reshape](args = (%bmm_default, [%sym_size, 2, 2, %sym_size_2]), kwargs = {})
return [reshape_2]
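The trace shows `torch.matmul` lowered to `expand`, `reshape`, `bmm`, and a final `reshape`. The sketch below reproduces that pattern in eager PyTorch so the role of each `sym_size` and of `%mul` is visible. `matmul_via_bmm` is a hypothetical helper for illustration, not PyTorch's actual decomposition, and the example shapes are assumptions chosen to mirror the trace (2x3 parameter matrices, a leading batch dimension that would be dynamic).

```python
import math

import torch


def matmul_via_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Broadcasting matmul for inputs with >= 2 dims, written as expand -> reshape -> bmm -> reshape."""
    *a_batch, n, k = a.shape
    *b_batch, k2, m = b.shape
    if k != k2:
        raise ValueError(f"inner dims differ: {k} vs {k2}")
    # Broadcast batch dims (in the trace: [%sym_size, 2])
    batch = torch.broadcast_shapes(tuple(a_batch), tuple(b_batch))
    # Flattened batch size (in the trace: %mul = %sym_size * 2)
    flat = math.prod(batch)
    # Expand each operand to the full batch shape, then flatten batch dims for bmm
    a3 = a.expand(*batch, n, k).reshape(flat, n, k)
    b3 = b.expand(*batch, k, m).reshape(flat, k, m)
    out = torch.bmm(a3, b3)
    # Restore the broadcast batch dims (in the trace: %reshape_2)
    return out.reshape(*batch, n, m)


a = torch.randn(2, 2, 3)     # plays the role of the 2x3 parameter matrices
b = torch.randn(4, 2, 3, 5)  # leading dim 4 stands in for the dynamic %sym_size
result = matmul_via_bmm(a, b)
print(result.shape)  # torch.Size([4, 2, 2, 5])
```

Every expand target here mixes values taken from runtime shapes with fixed ints, which is exactly the form the converter has to handle once those runtime values are symbolic.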
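The `expand` node above receives its target shape as a mix of a TensorRT `ITensor` (the converted `sym_size`) and plain Python ints, for example `(<tensorrt.tensorrt.ITensor object at 0x7fadb39eeb30>, 2, 2, 3)`, and this mixed form is what triggers the error.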
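One possible direction (an assumption, not a confirmed fix) is for the converter to normalize every entry of the target shape into a 1-element integer tensor and concatenate them into a single shape tensor, instead of passing the mixed tuple through. In a TensorRT network that would presumably mean constant layers for the static ints concatenated with the dynamic size tensors. The sketch below only mimics the idea in plain PyTorch, with a hypothetical `build_shape_tensor` helper and a 0-d tensor standing in for the `ITensor`.

```python
from typing import Sequence, Union

import torch

Dim = Union[int, torch.Tensor]


def build_shape_tensor(shape: Sequence[Dim]) -> torch.Tensor:
    """Turn a mixed list of static ints and dynamic size tensors into one 1-D int64 tensor."""
    parts = []
    for dim in shape:
        if isinstance(dim, torch.Tensor):
            # Dynamic entry (stands in for the ITensor produced by sym_size)
            parts.append(dim.reshape(1).to(torch.int64))
        else:
            # Static entry: wrap the Python int as a 1-element constant
            parts.append(torch.tensor([dim], dtype=torch.int64))
    return torch.cat(parts)


n = torch.tensor(4)                        # runtime batch size, known only at execution
target = build_shape_tensor([n, 2, 2, 3])  # same mixed form as in the error message
param = torch.randn(2, 2, 3)
print(param.expand(*target.tolist()).shape)  # torch.Size([4, 2, 2, 3])
```

The design choice is to make the shape argument uniformly "a tensor of sizes", so downstream code never has to branch on whether an individual entry is static or dynamic.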