fix overload ambiguity with functional ops; fix _foreach op grouping (pytorch#80556)

This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today, if you call an `OpOverloadPacket` from Python with some arguments, we use the types of those arguments to perform overload resolution. With some functional variants of ops, this is ambiguous. Today this affects just one op, `_fused_moving_avg_obs_fq_helper`, although it could also affect e.g. `native_batch_norm` in the future. Example:

```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 outputs, mutates 4 of its inputs in place)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 outputs, mutates none of its inputs)
# We pick the wrong one - there is no way to know, just from the call site, that we should pick the functional one.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# Raises an error - indexing output 5 needs the 6-output functional overload, but the call resolved to the 2-output default.
return _fused_moving_avg_obs_fq_helper_functional[5]
```

Specifically, functionalization bakes `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to strip the overload name (TS doesn't know how to parse overload names directly, so we remove the overload name and let it infer the right overload at runtime later - and it picks the wrong one).

The situation is pretty similar to inplace: `ops.aten.add` and `ops.aten.add_` are two different `OverloadPacket` objects. They can't be overloads of the same op, because their schemas would be ambiguous (the alias annotations differ, but that isn't enough to disambiguate them).

In this PR, I fix the situation in much the same way that we handle `inplace` in the data model: `inplace` ops get their own base operator name, but they are represented as a flag inside `BaseOperatorName`, and the functional variants now work the same way. A rough sketch of the idea is included below.

Two other important changes in this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` suffix is only necessary for operators that **also** have a `SchemaKind.mutable` variant, and `_fused_moving_avg_obs_fq_helper` is the only op fitting that description today. So I removed the unnecessary notion of "functional" from the other ops, and added a bunch of assertions to enforce this restriction. I think that makes more sense in the long run, because it eliminates an unnecessary asymmetry in the model: we don't have `add_.Tensor` and `add.Tensor_functional`; we just have `add_.Tensor` and `add.Tensor`.

(2) I noticed that we still weren't pairing up a bunch of `_foreach` operators correctly, because their input argument names were different (`self` vs. `tensors`). Since they are private APIs, I went ahead and changed the argument names directly so that they get matched up. Before this PR, in a bunch of cases we generated a separate `_foreach_add` and `_foreach_add.functional` variant that really did the same thing (but happened to use a different name for the first argument). See the second sketch below for why the argument names mattered.
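Below is a minimal sketch of the data-model idea described above. It is not the actual torchgen source; the class shape and the `functional` field name are assumptions made for illustration, mirroring how the existing `inplace` flag works.

```
# Hypothetical sketch (not the real torchgen.model.BaseOperatorName): the
# functional-vs-mutable distinction is a flag on the base operator name,
# mirroring how `inplace` is already modeled, instead of ~100 separate
# "*.functional" overload names.
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseOperatorName:
    base: str                 # e.g. "_fused_moving_avg_obs_fq_helper" or "add"
    inplace: bool = False     # True for ops like add_.Tensor
    functional: bool = False  # True only when a mutable variant of the same op also exists

    def __str__(self) -> str:
        # Only the rare ops with a mutable counterpart carry a "_functional"
        # suffix; everything else keeps its plain name (e.g. "add" / "add_").
        if self.inplace:
            return f"{self.base}_"
        if self.functional:
            return f"{self.base}_functional"
        return self.base

# Example: the two base names stay unambiguous.
assert str(BaseOperatorName("add", inplace=True)) == "add_"
assert str(BaseOperatorName("_fused_moving_avg_obs_fq_helper", functional=True)) \
    == "_fused_moving_avg_obs_fq_helper_functional"
```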
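And a rough sketch of the `_foreach` grouping issue from change (2). The schemas and the pairing logic here are simplified assumptions, not torchgen's real grouping code; the point is only that a key which keeps argument names treats `self` vs. `tensors` as different ops.

```
# Simplified illustration (assumed schemas and pairing logic, not actual
# torchgen code): variant grouping compares signatures, so a different
# first-argument name keeps otherwise-identical _foreach variants apart.
def pairing_key(schema: str) -> str:
    # Drop the trailing in-place underscore and the alias annotation,
    # but keep argument names.
    return schema.replace("(a!)", "").replace("_.", ".", 1)

inplace_variant = "_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar)"
functional_variant = "_foreach_add.Scalar(Tensor[] tensors, Scalar scalar)"

# Before this PR these did not pair up, so codegen emitted a redundant
# `_foreach_add.functional` even though both ops do the same computation.
assert pairing_key(inplace_variant) != pairing_key(functional_variant)

# Renaming the argument (illustrative direction only) makes the variants match.
renamed = functional_variant.replace("tensors", "self")
assert pairing_key(inplace_variant) == pairing_key(renamed)
```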
Pull Request resolved: pytorch#80556
Approved by: https://github.com/ezyang, https://github.com/albanD