fix overload ambiguity with functional ops; fix _foreach op grouping (pytorch#80556)

This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today, if you call an `OpOverloadPacket` from Python with some arguments, we use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.
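
For context, here's the distinction at play (using `aten.add`, where resolution is unambiguous; a quick illustrative snippet, not part of the PR):

```
import torch

packet = torch.ops.aten.add            # OpOverloadPacket: overload resolved from argument types
overload = torch.ops.aten.add.Tensor   # OpOverload: one fully-specified overload

a, b = torch.ones(2), torch.ones(2)
assert torch.equal(packet(a, b), overload(a, b))
```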

Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it could potentially affect e.g. `native_batch_norm` in the future.

Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 outputs, mutates 4 of its inputs in place)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 outputs, mutates none of its inputs)

# We pick the wrong one - there's no way to know, just from the call site, that we want the functional one.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# raises an error - we resolved to the overload with only 2 returns, so indexing output 5 fails
return outs[5]
```

Specifically, functionalization bakes `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile the graph with TorchScript, it has to strip the overload name (TorchScript can't parse overload names directly, so it infers the right overload from the argument types at runtime), and at runtime it picks the wrong one.

The situation is pretty similar to inplace: `ops.aten.add` and `ops.aten.add_` are two different `OpOverloadPacket` objects. They can't be overloads of the same op, because their schemas would be ambiguous: the alias annotations differ, but that isn't enough to disambiguate.

In this PR, I fix the situation much the same way the data model already handles `inplace`: inplace ops get their own base operator name, but inplace-ness is represented as a flag inside `BaseOperatorName`. Functional variants now get the same treatment.
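
A minimal sketch of that data model (field names are illustrative assumptions loosely based on `torchgen/model.py`, not the verbatim definition):

```
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseOperatorName:
    # Sketch only: field names are assumptions, not the real torchgen class.
    base: str                          # e.g. "add" or "_fused_moving_avg_obs_fq_helper"
    inplace: bool                      # True for add_.Tensor
    functional_overload: bool = False  # True only when a mutable variant of the same op exists

    def __str__(self) -> str:
        # The trailing "_" encodes inplace-ness; ".functional" is spelled out
        # later, as part of the overload name.
        return f"{self.base}_" if self.inplace else self.base
```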

Two other important changes that I made as part of this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only needed for operators that **also** have a `SchemaKind.mutable` variant, and `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops, and added a bunch of assertions to enforce this restriction.

I think that makes more sense in the long run, because it eliminates an unnecessary asymmetry in the model: we don't have `add_.Tensor` and `add.Tensor_functional`; we just have `add_.Tensor` and `add.Tensor`.
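
For reference, the kinds of schema variants being discussed here (a sketch; the member set mirrors torchgen's `SchemaKind` as I understand it, and is an assumption):

```
from enum import Enum, auto

class SchemaKind(Enum):
    # Sketch of the torchgen-style classification of op variants.
    functional = auto()  # mutates nothing:               add.Tensor
    inplace = auto()     # mutates self:                  add_.Tensor
    out = auto()         # writes into an out= argument:  add.out
    mutable = auto()     # mutates non-self arguments:    _fused_moving_avg_obs_fq_helper.default
```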

(2) I noticed that we still weren't pairing up a bunch of `_foreach` operators correctly, because their input argument names differed (`self` vs. `tensors`). Since they're private APIs, I went ahead and changed the argument names directly so they get matched up. Before this PR, in a bunch of cases we were generating separate `_foreach_add` and `_foreach_add.functional` variants that really did the same thing (but happened to use a different name for the first argument).
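
The effect is visible from the Python surface, where the in-place and out-of-place variants are now just two flavors of the same base op. For example:

```
import torch

xs = [torch.ones(2), torch.ones(3)]

ys = torch._foreach_add(xs, 1.0)   # out-of-place: returns new tensors
torch._foreach_add_(xs, 1.0)       # in-place: mutates xs

assert all(torch.equal(x, y) for x, y in zip(xs, ys))
```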

Pull Request resolved: pytorch#80556
Approved by: https://github.com/ezyang, https://github.com/albanD
bdhirsh authored and pytorchmergebot committed Jul 6, 2022
1 parent ce0786a commit 960758b
Showing 13 changed files with 216 additions and 150 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -131,7 +131,7 @@ const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet,
   at::Tensor tmp_output;
   {
     at::AutoDispatchSkipFunctionalize guard;
-    tmp_output = at::resize_functional(self_, size, memory_format);
+    tmp_output = at::resize(self_, size, memory_format);
   }
 
   auto itemsize = self.dtype().itemsize();
5 changes: 5 additions & 0 deletions aten/src/ATen/native/Distributions.cpp
@@ -328,6 +328,11 @@ Tensor normal_meta(const Tensor& mean, const Tensor& std, c10::optional<Generato
   return at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, gen);
 }
 
+// functional variant, only used by the functionalization pass.
+Tensor normal_functional(const Tensor& self, double mean, double std, c10::optional<at::Generator> generator) {
+  return self.clone().normal_(mean, std, generator);
+}
+
 // ==================================================== Random ========================================================
 
 template<typename RNG>
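
For readers following along, a hypothetical Python sketch of what the new kernel computes (`normal_functional` itself is a C++ kernel dispatched by the functionalization pass, not a public Python API):

```
import torch

def normal_functional_reference(self, mean, std, generator=None):
    # Out-of-place equivalent of Tensor.normal_: clone, then mutate the clone.
    return self.clone().normal_(mean, std, generator=generator)

x = torch.zeros(3)
y = normal_functional_reference(x, 0.0, 1.0)
assert torch.equal(x, torch.zeros(3))  # the input is left untouched
```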