
aten::native_layer_norm 2nd and 3rd results not implemented correctly #665

@silvasean


When revamping the shape library, I noticed that the shape function for torch.aten.native_layer_norm computes the 2nd and 3rd results incorrectly. I tried fixing the buggy shape function, but doing so seemed to break the TorchToLinalg lowering (the existing shape function actually calculates the wrong rank).

I left comments in shape_lib_gen.py describing the issue and the expected behavior, including a correct version of the code.
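To illustrate what "wrong rank" means here, the following is a minimal plain-Python sketch of the expected result shapes. It assumes the PyTorch eager convention that native_layer_norm returns (output, mean, rstd), where mean and rstd keep the leading (un-normalized) dimensions of the input and have size 1 in every normalized dimension, so all three results have the same rank as the input. The helper name `native_layer_norm_shapes` is hypothetical and is not the actual code in shape_lib_gen.py.

```python
from typing import List, Tuple


def native_layer_norm_shapes(
    input_shape: List[int], normalized_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: result shapes of aten.native_layer_norm.

    Assumption: mean and rstd keep the leading (un-normalized) dims and have
    size 1 in each normalized dim, so every result has the input's rank.
    """
    num_unreduced = len(input_shape) - len(normalized_shape)
    if num_unreduced < 0:
        raise ValueError("normalized_shape has more dims than the input")
    # normalized_shape is expected to match the trailing dims of the input.
    if list(input_shape[num_unreduced:]) != list(normalized_shape):
        raise ValueError("normalized_shape must match the trailing input dims")
    # Leading dims are preserved; each normalized dim collapses to size 1.
    stats_shape = list(input_shape[:num_unreduced]) + [1] * len(normalized_shape)
    return list(input_shape), stats_shape, list(stats_shape)


print(native_layer_norm_shapes([4, 5, 2, 3], [2, 3]))
# ([4, 5, 2, 3], [4, 5, 1, 1], [4, 5, 1, 1])
```

The key property is that the 2nd and 3rd shapes have the same length (rank) as the input rather than dropping the normalized dimensions.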

Our tests didn't catch this because, for some reason, the test only checked the first result rather than all three:

class NativeLayerNormModule(torch.nn.Module):
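For comparison, here is a minimal eager-PyTorch sketch of a module whose forward returns all three results of aten.native_layer_norm, together with reference checks on mean and rstd. The class name `NativeLayerNormAllResults`, the input shapes, and the reference formulas are illustrative assumptions; this is not the actual NativeLayerNormModule from the test suite.

```python
import torch


class NativeLayerNormAllResults(torch.nn.Module):
    """Hypothetical module: returns (output, mean, rstd), not just output."""

    def forward(self, x, weight, bias):
        # Normalize over the last two dims; returning the whole tuple means a
        # harness that compares outputs would check mean and rstd as well.
        return torch.ops.aten.native_layer_norm(x, [2, 3], weight, bias, 1e-5)


torch.manual_seed(0)
x = torch.randn(4, 5, 2, 3)
weight = torch.randn(2, 3)
bias = torch.randn(2, 3)
out, mean, rstd = NativeLayerNormAllResults()(x, weight, bias)

# Reference statistics computed directly (biased variance, as layer norm uses).
ref_mean = x.mean(dim=(2, 3), keepdim=True)
ref_var = ((x - ref_mean) ** 2).mean(dim=(2, 3), keepdim=True)
ref_rstd = torch.rsqrt(ref_var + 1e-5)

print(out.shape, mean.shape, rstd.shape)
```

Returning the whole tuple from forward is the main point: a test that compares every returned tensor would be expected to flag a mean/rstd shape or value mismatch that an output-only check misses.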

To fix this bug, you should be able to uncomment the correct code linked above in shape_lib_gen.py and then debug what TorchToLinalg is doing wrong.
