Commit 7ec12a5
Revert "Enable aten-aten decomps (pytorch#85921)"
This reverts commit 62e4f51. Reverted pytorch#85921 on behalf of https://github.com/huydhn due to Sorry for reverting your PR. I think it breaks a dynamo test in trunk https://hud.pytorch.org/pytorch/pytorch/commit/62e4f51efdf98a3a91d29efa55e5665d5398b464
1 parent b0ceb8e commit 7ec12a5

File tree: 8 files changed (+35, -64 lines)


functorch/test/test_aotdispatch.py

Lines changed: 2 additions & 0 deletions
@@ -880,6 +880,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('mvlgamma', 'mvlgamma_p_3'),  # aten.digamma_.default - couldn't find symbolic meta function/decom...
     xfail('mvlgamma', 'mvlgamma_p_5'),  # aten.digamma_.default - couldn't find symbolic meta function/decom...
     xfail('nanmedian', ''),  # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
+    xfail('native_layer_norm', ''),  # could not find kernel
     xfail('nn.functional._scaled_dot_product_attention', ''),  # Cannot call sizes() on tensor with symbolic ...
     xfail('nn.functional.adaptive_avg_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.adaptive_avg_pool2d', ''),  # aten._adaptive_avg_pool2d_backward.default - couldn't ...
@@ -922,6 +923,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('nn.functional.interpolate', 'trilinear'),  # Cannot call sizes() on tensor with symbolic sizes/st...
     xfail('nn.functional.kl_div', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.l1_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('nn.functional.layer_norm', ''),  # could not find kernel
     xfail('nn.functional.linear', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.local_response_norm', ''),  # aten.fill.Scalar - couldn't find symbolic meta functio...
     xfail('nn.functional.max_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
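The two new xfail entries mark the symbolic-shapes exhaustive tests as expected failures for native_layer_norm and nn.functional.layer_norm, since the revert removes their decompositions. A minimal repro sketch, assuming the harness reduces to make_fx with symbolic tracing (the actual test plumbing is more involved):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.layer_norm(x, [4])

# At this commit, symbolic tracing is expected to fail here with
# "could not find kernel", matching the xfail comments above.
make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))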

test/test_fake_tensor.py

Lines changed: 0 additions & 9 deletions
@@ -464,15 +464,6 @@ def fn(tensors):
         inputs = [a, b]
         ref = fn(inputs)
 
-    def test_fake_tensor_batch_norm_cpu(self):
-        with torch._subclasses.CrossRefFakeMode():
-            m = torch.nn.Sequential(
-                torch.nn.BatchNorm2d(10),
-                torch.nn.ReLU(),
-            )
-            m.eval()
-            out = m(torch.randn([2, 10, 8, 8]))
-
     def test_shared_storage_invalidation(self):
         with FakeTensorMode():
             x = torch.tensor([1.])

test/test_ops.py

Lines changed: 4 additions & 0 deletions
@@ -1811,6 +1811,10 @@ def test_refs_are_in_decomp_table(self, op):
     "linalg.norm",
     "linalg.svd",
     "linalg.svdvals",
+    "nn.functional.binary_cross_entropy_with_logits",
+    "nn.functional.huber_loss",
+    "nn.functional.logsigmoid",
+    "nn.functional.multilabel_soft_margin_loss",
     "pca_lowrank",
     "roll",
     "svd_lowrank",

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 1 deletion
@@ -843,8 +843,9 @@ def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModul
 def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
 def _set_conj(x: Tensor, conj: _bool) -> None: ...
 def _set_neg(x: Tensor, neg: _bool) -> None: ...
-def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
+def _add_meta_to_tls_dispatch_include() -> None: ...
 def _meta_in_tls_dispatch_include() -> _bool: ...
+def _remove_meta_from_tls_dispatch_include() -> None: ...
 def _has_storage(x: Tensor) -> _bool: ...
 def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
 # NB: There is no Capsule type in typing, see
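A minimal usage sketch for the reverted-to stubs (private torch._C API): pair the add/remove calls and guard on the current state, mirroring in_kernel_invocation_manager in fake_tensor.py below.

import torch

already = torch._C._meta_in_tls_dispatch_include()
if not already:
    torch._C._add_meta_to_tls_dispatch_include()
try:
    pass  # code that should dispatch to Meta kernels goes here
finally:
    if not already:
        torch._C._remove_meta_from_tls_dispatch_include()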

torch/_decomp/__init__.py

Lines changed: 0 additions & 9 deletions
@@ -18,10 +18,6 @@
 
 meta_lib = torch.library.Library("aten", "IMPL", "Meta")
 
-# decompositions which have been disabled as meta kernel implementations,
-# usually due to mismatching strides, aliasing, or other inconsistent property
-_disabled_meta_decomps = set()
-
 
 def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False):
     """
@@ -109,11 +105,6 @@ def add_op_to_table(aten_op):
        name = op_overload._schema.name
        if op_overload._schema.overload_name:
            name += "." + op_overload._schema.overload_name
-
-        if disable_meta:
-            global _disabled_meta_decomps
-            _disabled_meta_decomps.add(op_overload)
-
        if (
            not disable_meta
            # TorchScript dumps a bunch of extra nonsense overloads
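With _disabled_meta_decomps gone, disable_meta only skips the Meta-kernel registration at decoration time. A hedged sketch of the decorator's use, based on the signature above; the op and the private registry are illustrative, not from this diff:

import torch
from torch._decomp import register_decomposition

aten = torch.ops.aten
my_table = {}  # hypothetical private registry; avoids touching the global table

@register_decomposition(aten.silu, registry=my_table, disable_meta=True)
def silu_decomp(x):
    # Decompose silu into primitive aten ops.
    return x * torch.sigmoid(x)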

torch/_decomp/decompositions.py

Lines changed: 6 additions & 2 deletions
@@ -1284,8 +1284,12 @@ def native_layer_norm_backward(
     if M <= 0 or N <= 0:
         return (
             input.new_zeros(input_shape) if output_mask[0] else None,
-            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
-            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
+            input.new_zeros(input_shape[axis:])
+            if output_mask[1] and weight_cast
+            else None,
+            input.new_zeros(input_shape[axis:])
+            if output_mask[2] and bias_cast
+            else None,
         )
 
     x_hat = (input_cast - mean) * rstd
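The extra `and weight_cast` / `and bias_cast` guards make the empty-input fast path return None for grad_weight/grad_bias when layer norm was called without affine parameters, matching eager. A small repro of that edge case (eager shown; the decomposition must agree):

import torch

x = torch.randn(0, 4, requires_grad=True)      # M == 0: empty batch
out = torch.nn.functional.layer_norm(x, [4])   # weight=None, bias=None
out.sum().backward()
print(x.grad.shape)  # torch.Size([0, 4]); no weight/bias grads exist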

torch/_subclasses/fake_tensor.py

Lines changed: 9 additions & 35 deletions
@@ -115,20 +115,6 @@ def get_schema_info(func):
     return torch._C._SchemaInfo(func._schema)  # type: ignore[attr-defined]
 
 
-# many of the decompositions registered to torch/_prims do not at the moment model
-# aliasing or strides, so as an incremental step, just enable the decompositions in
-# torch/_decomp/decompositions.py.
-# decomps are used for aot autograd tracing so we would like to unify on their
-# implementation and add additional testing to them
-@functools.lru_cache(None)
-def torch_decomp_decompositions(func):
-    from torch._decomp import decomposition_table
-
-    decompositions = torch._decomp.decompositions
-    decomp_attrs = [getattr(decompositions, attr) for attr in dir(decompositions)]
-    return decomposition_table[func] in decomp_attrs
-
-
 def tree_flatten_only(ty: Type[T], pytree: PyTree):
     flat_vals, _ = tree_flatten(pytree)
     return [elem for elem in flat_vals if isinstance(elem, ty)]
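For reference, the removed helper answered "is this op's registered decomposition defined in torch/_decomp/decompositions.py?" by scanning dir(). An equivalent hedged sketch using inspect instead (not from this diff):

import inspect
import torch._decomp.decompositions as decompositions
from torch._decomp import decomposition_table

def is_torch_decomp(func):
    # True iff func's registered decomposition is defined in
    # torch/_decomp/decompositions.py (assumes __module__ survives wrapping).
    decomp = decomposition_table.get(func)
    return decomp is not None and inspect.getmodule(decomp) is decompositions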
@@ -316,8 +302,7 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
     input_device = new_kwargs["device"]
     out_device = input_device if input_device else new_kwargs["input"].device
     new_kwargs["device"] = torch.device("meta")
-    inp = new_kwargs.pop("input")
-    r = func(inp, **new_kwargs)
+    r = func(*args, **new_kwargs)
     return fake_mode.fake_tensor_converter(fake_mode, r, out_device)
 
 
@@ -344,7 +329,7 @@ def to_copy(fake_mode, func, *args, **kwargs):
 
     input_device = new_kwargs.pop("device", None)
     out_device = input_device if input_device else new_kwargs["input"].device
-    with in_kernel_invocation_manager(fake_mode):
+    with no_dispatch(), in_kernel_invocation_manager(fake_mode):
         input = new_kwargs.pop("input").to("meta")
     return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device)
 
@@ -432,19 +417,18 @@ def nyi(fake_mode, func, *args, **kwargs):
 @contextlib.contextmanager
 def in_kernel_invocation_manager(fake_mode):
     # See: note [Fake Tensor Dispatch Keys]
-    prev_in_kernel = fake_mode.in_kernel_invocation
     meta_in_tls = torch._C._meta_in_tls_dispatch_include()
-    assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
+    prev = fake_mode.in_kernel_invocation
 
-    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
     fake_mode.in_kernel_invocation = True
-    torch._C._set_meta_in_tls_dispatch_include(True)
+    if not meta_in_tls:
+        torch._C._add_meta_to_tls_dispatch_include()
     try:
         yield
     finally:
-        fake_mode.in_kernel_invocation = prev_in_kernel
-        torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
-        del guard
+        fake_mode.in_kernel_invocation = prev
+        if not meta_in_tls:
+            torch._C._remove_meta_from_tls_dispatch_include()
 
 
 class FakeTensor(torch.Tensor):
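The `if not meta_in_tls` guards make the manager reentrant: only the outermost entry mutates the TLS include set. A sketch of the nesting behavior (assumes Meta starts out excluded, the default):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, in_kernel_invocation_manager

fake_mode = FakeTensorMode()
with in_kernel_invocation_manager(fake_mode):        # Meta absent -> added
    with in_kernel_invocation_manager(fake_mode):    # Meta present -> no-op
        assert torch._C._meta_in_tls_dispatch_include()
    assert torch._C._meta_in_tls_dispatch_include()  # inner exit left it alone
assert not torch._C._meta_in_tls_dispatch_include()  # outer exit removed it once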
@@ -744,15 +728,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         # is written to must be invalidated
         self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
 
-        from torch._decomp import _disabled_meta_decomps, decomposition_table
-
         # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise
         if (
             has_symbolic_sizes
             and func not in self.functions_with_cpp_meta_impl_that_support_symint
         ):
             # TODO: Find better approach for this
             # Avoid circular import
+            from torch._decomp import decomposition_table
             from torch._meta_registrations import meta_table
 
             with no_dispatch():
@@ -776,15 +759,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
             if r is not NotImplemented:
                 return r
 
-        if (
-            func in decomposition_table
-            and torch_decomp_decompositions(func)
-            and func not in _disabled_meta_decomps
-            and all(not e.is_sparse for e in flat_arg_fake_tensors)
-        ):
-            with self:
-                return decomposition_table[func](*args, **kwargs)
-
         # prims already wrap FakeTensor inputs to FakeTensor outputs
         # and do device logic, we dont need do anything but run them
         # and ensure that Meta kernels are dispatched to (see)

torch/csrc/Module.cpp

Lines changed: 12 additions & 8 deletions
@@ -1386,21 +1386,25 @@ Call this whenever a new thread is created in order to propagate values from
   py_module.def(
       "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
 
-  py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
+  py_module.def("_add_meta_to_tls_dispatch_include", []() {
     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
     c10::DispatchKeySet key_set({at::DispatchKey::Meta});
-    if (meta_in_tls) {
-      local_keyset.included_ = local_keyset.included_ | key_set;
-    } else {
-      local_keyset.included_ =
-          local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
-    }
+    local_keyset.included_ = local_keyset.included_ | key_set;
+    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
+  });
+  py_module.def("_remove_meta_from_tls_dispatch_include", []() {
+    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
+    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+    auto k = key_set.highestBackendKey();
+    local_keyset.included_ = local_keyset.included_.remove_backend(k);
     c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
   });
 
   py_module.def("_meta_in_tls_dispatch_include", []() {
     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-    return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
+    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
+    auto k = key_set.highestBackendKey();
+    return local_keyset.included_.has_backend(k);
   });
 
   py_module.def("_dump_local_tls_set", []() {
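For a set containing only at::DispatchKey::Meta, highestBackendKey() resolves to the Meta backend bit, so the add/remove/query trio stays consistent with the old hardcoded BackendComponent::MetaBit check. A quick round-trip over the three bindings from Python (private API; assumes Meta starts out excluded from the TLS include set):

import torch

assert not torch._C._meta_in_tls_dispatch_include()
torch._C._add_meta_to_tls_dispatch_include()
assert torch._C._meta_in_tls_dispatch_include()
torch._C._remove_meta_from_tls_dispatch_include()
assert not torch._C._meta_in_tls_dispatch_include()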
