Description
When revamping the shape library, I noticed that the shape function for torch.aten.native_layer_norm computes incorrect shapes for its 2nd and 3rd results. I tried fixing the buggy shape function, but doing so seems to break the TorchToLinalg lowering (the current shape function actually computes the wrong rank for those results).
I left some comments (including a correct version of the code) in shape_lib_gen.py about the issue and expected behavior.
torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py, line 766 at commit a5fe0cf:
# TODO: Fix shape function (see body).
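For reference, PyTorch's native_layer_norm returns three tensors (output, mean, rstd). The output has the input's shape, while mean and rstd keep the input's rank: the leading, non-normalized dimensions are preserved and every normalized dimension is reduced to size 1. Below is a minimal plain-Python sketch of that rule, assuming this is the intended behavior. The helper name `native_layer_norm_shapes` is hypothetical; this is not the code in shape_lib_gen.py.

```python
from typing import List, Tuple


def native_layer_norm_shapes(
    input_shape: List[int], normalized_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: result shapes (output, mean, rstd) of native_layer_norm."""
    # Index of the first dimension that is normalized over.
    axis = len(input_shape) - len(normalized_shape)
    assert axis >= 0, "normalized_shape has more dims than the input"
    # normalized_shape must match the trailing dimensions of the input.
    assert list(input_shape[axis:]) == list(normalized_shape)
    # The normalized output has exactly the input's shape.
    output_shape = list(input_shape)
    # mean/rstd: keep dims [0, axis), then a size-1 dim for each normalized dim,
    # so their rank equals the input's rank.
    stat_shape = list(input_shape[:axis]) + [1] * len(normalized_shape)
    return output_shape, stat_shape, list(stat_shape)
```

For example, an input of shape [2, 5, 2, 2, 3] normalized over [2, 2, 3] gives an output of shape [2, 5, 2, 2, 3] and mean/rstd of shape [2, 5, 1, 1, 1], all rank 5.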
Our testing didn't catch this because, for some reason, the test only checked the first result rather than all three:
class NativeLayerNormModule(torch.nn.Module):
To fix this bug, you should be able to uncomment the correct code linked above in shape_lib_gen.py, and then debug what TorchToLinalg is doing wrong.