As discussed in pytorch/pytorch#73050 (comment), there are a few ops whose schemas don't correctly annotate that they mutate their operands. It seems those are `aten::batch_norm` and `aten::layer_norm`.
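For reference, JIT schemas mark mutation with alias annotations like `(a!)`, and the `aten::batch_norm` schema carries none on `running_mean`/`running_var` even though those tensors are updated in place during training. Compare (schemas quoted from the upstream op registry, shown here for illustration):

```text
# In-place add: mutation of `self` is annotated with (a!)
aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)

# batch_norm: no (a!) on running_mean/running_var, despite in-place
# updates to them when training == True
aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
```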
When I revamped our ODS generator code, I tried correcting those exceptions w.r.t. the `HasValueSemantics` and `ReadOnly` traits, but it seems we were relying on the old, incorrect annotation (which I think was okay, since it only matters in the training case, which we haven't implemented yet):
`torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py`, line 258 at commit `a5fe0cf`:

```python
# TODO: Handle some exceptions of incorrectly annotated ops.
```
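To make the shape of the fix concrete, here is a minimal sketch of what that exception handling could look like. Everything here (`compute_traits`, its parameters, the set name) is an assumption for illustration, not the actual `registry.py` code:

```python
# Hypothetical sketch: patch in mutation knowledge that the upstream JIT
# schemas omit, so these ops don't get value-semantic traits in the
# generated ODS. Names and structure are illustrative only.
_INCORRECTLY_ANNOTATED_MUTATING_OPS = {
    "aten::batch_norm",
    "aten::layer_norm",
}

def compute_traits(op_name: str, schema_marks_mutation: bool) -> list[str]:
    """Decide which ODS traits to emit for an op."""
    mutates = (
        schema_marks_mutation or op_name in _INCORRECTLY_ANNOTATED_MUTATING_OPS
    )
    if mutates:
        # A mutating op must not be treated as a pure value computation.
        return []
    return ["HasValueSemantics", "ReadOnly"]
```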
To work on this, you just have to uncomment the code linked above, regenerate the ODS for `torch.aten.batch_norm`, and see what breaks in the tests. I dug into it a little bit, and it seems we will need some special handling in `ReduceOpVariants` to convert `torch.aten.batch_norm` to value semantics when `training == false`; see the demonstration below.
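For context on why `training == false` is the safe case: in eager PyTorch, `batch_norm` only mutates its running-stat operands in training mode. A small self-contained demonstration (plain PyTorch, independent of torch-mlir):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

# training=False: running stats are read but left untouched, so the op
# behaves as a pure function of its inputs (value semantics).
F.batch_norm(x, running_mean, running_var, training=False)
assert running_mean.eq(0).all() and running_var.eq(1).all()

# training=True: running_mean/running_var are updated in place, so the
# op genuinely mutates its operands.
F.batch_norm(x, running_mean, running_var, training=True)
assert not running_mean.eq(0).all()
```

So when `ReduceOpVariants` can prove that `training` is the constant `false`, rewriting `torch.aten.batch_norm` to a value-semantic form should be sound.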