
🐛 [Bug- FX] Matrix multiplication FX dynamo test failures #1827

Closed
@apbose

Description

Issue:
The `aten.expand` conversion does not support symbolic shapes.

Steps to reproduce:

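```python
import torch
import torch.nn as nn


class MatMul(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.Parameter(torch.randn(3, 1, 3, 2))

    def forward(self, other):
        # Broadcasted batched matmul: (3, 1, 3, 2) @ (B, 2, 3) -> (3, B, 3, 3)
        return torch.matmul(self.input, other)


if __name__ == "__main__":
    # other.shape[-2] must equal self.input.shape[-1] (2) for matmul to be valid
    inputs = torch.randn(2, 2, 3)
    matmul = MatMul()
    y = matmul(inputs)
```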


Notes:
Tracing the module above produces a graph like this:

```
    %arg0 : [#users=4] = placeholder[target=arg0]
    %sym_size : [#users=4] = call_function[target=torch.ops.aten.sym_size](args = (%arg0, 0), kwargs = {})
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %expand_default : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_param_constant0, [%sym_size, 2, 2, 3]), kwargs = {})
    %mul : [#users=2] = call_function[target=operator.mul](args = (%sym_size, 2), kwargs = {})
    %reshape : [#users=1] = call_function[target=torch.ops.aten.reshape](args = (%expand_default, [%mul, 2, 3]), kwargs = {})
    %sym_size_1 : [#users=2] = call_function[target=torch.ops.aten.sym_size](args = (%arg0, 2), kwargs = {})
    %sym_size_2 : [#users=3] = call_function[target=torch.ops.aten.sym_size](args = (%arg0, 3), kwargs = {})
    %expand_default_1 : [#users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg0, [%sym_size, 2, %sym_size_1, %sym_size_2]), kwargs = {})
    %reshape_1 : [#users=1] = call_function[target=torch.ops.aten.reshape](args = (%expand_default_1, [%mul, %sym_size_1, %sym_size_2]), kwargs = {})
    %bmm_default : [#users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape, %reshape_1), kwargs = {})
    %reshape_2 : [#users=1] = call_function[target=torch.ops.aten.reshape](args = (%bmm_default, [%sym_size, 2, 2, %sym_size_2]), kwargs = {})
    return [reshape_2]
```
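The trace follows the usual decomposition of a batched `matmul`: both operands are expanded to a common broadcast batch shape, flattened to 3-D, multiplied with `bmm`, and reshaped back. The pure-Python sketch below computes only the shapes of those intermediate nodes, which shows why the batch dimension `%sym_size` ends up inside the `expand` size list. The function names `broadcast_shapes` and `matmul_decomposition_shapes` are hypothetical, and the example shapes are an assumption chosen to reproduce the node shapes in the trace (the issue does not state the traced input shapes).

```python
from itertools import zip_longest
from math import prod


def broadcast_shapes(a, b):
    """NumPy-style broadcasting of two batch shapes, aligned from the right."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} and {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_decomposition_shapes(a_shape, b_shape):
    """Shapes of the expand -> reshape -> bmm -> reshape nodes for a @ b.

    Sketch only: assumes both shapes are tuples of rank >= 2.
    """
    (n, k1), (k2, m) = a_shape[-2:], b_shape[-2:]
    if k1 != k2:
        raise ValueError(f"inner dims differ: {k1} vs {k2}")
    batch = broadcast_shapes(a_shape[:-2], b_shape[:-2])
    flat = prod(batch)  # plays the role of %mul in the trace
    return {
        "expand_a": batch + (n, k1),   # %expand_default
        "expand_b": batch + (k2, m),   # %expand_default_1
        "reshape_a": (flat, n, k1),    # %reshape
        "reshape_b": (flat, k2, m),    # %reshape_1
        "bmm": (flat, n, m),           # %bmm_default
        "output": batch + (n, m),      # %reshape_2
    }


if __name__ == "__main__":
    # Assumed shapes: batch size 4 stands in for %sym_size, 5 for %sym_size_2.
    shapes = matmul_decomposition_shapes((2, 2, 3), (4, 1, 3, 5))
    print(shapes["expand_a"])  # (4, 2, 2, 3), i.e. [%sym_size, 2, 2, 3]
```

When the batch size is symbolic, every entry derived from `batch` (the expand sizes and the final reshape) carries that symbol, which is exactly where the converter sees a non-constant value.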

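When `%expand_default` is converted, its size argument arrives as `(<tensorrt.tensorrt.ITensor object at 0x7fadb39eeb30>, 2, 2, 3)`: the symbolic batch dimension `%sym_size` has already been converted to a TensorRT `ITensor`, while the remaining sizes are plain Python ints. This mixed size list is what leads to the error.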
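Before a converter can handle this case, it has to detect which entries of the size list are only known at runtime. The pure-Python sketch below shows one way to split such a list into static sizes and dynamic positions. `FakeITensor` and `split_expand_shape` are hypothetical stand-ins for illustration, not Torch-TensorRT or TensorRT APIs; an actual fix would presumably still need to assemble the dynamic entries into a shape tensor inside the TensorRT network.

```python
from dataclasses import dataclass


@dataclass
class FakeITensor:
    """Hypothetical stand-in for tensorrt.ITensor holding a runtime dim."""
    name: str


def split_expand_shape(shape):
    """Return (static_dims, dynamic_positions) for an expand size list.

    static_dims keeps the ints and puts None where the size is only known
    at runtime; dynamic_positions lists the indices of those runtime sizes.
    """
    static_dims, dynamic_positions = [], []
    for i, dim in enumerate(shape):
        if isinstance(dim, int):
            static_dims.append(dim)
        else:
            # Anything that is not a Python int (e.g. an ITensor) is dynamic.
            static_dims.append(None)
            dynamic_positions.append(i)
    return static_dims, dynamic_positions


if __name__ == "__main__":
    # Mirrors the failing argument: (<ITensor>, 2, 2, 3)
    shape = (FakeITensor("sym_size"), 2, 2, 3)
    static, dynamic = split_expand_shape(shape)
    print(static, dynamic)  # [None, 2, 2, 3] [0]
```

A fully static size list yields an empty `dynamic_positions`, so the existing static path could be kept for that case and only the mixed case routed to a dynamic-shape implementation.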
