
Commit d3f7c34

eellison authored and pytorchmergebot committed
Enable aten-aten decomps (pytorch#85921)
Invokes aten-aten decomps with a re-entrant FakeMode. These decomps are already used in other places, so it is good to unify the path that (static-shape) fake tensor takes and to get additional testing on them. This also fixes an instance where we returned different devices between CPU and CUDA ([batch_norm](https://github.com/pytorch/pytorch/blob/master/torch/_decomp/decompositions.py#L1374)).

Pull Request resolved: pytorch#85921
Approved by: https://github.com/ezyang
1 parent af9c6bc commit d3f7c34
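
The batch_norm device fix can be exercised directly with the CrossRefFakeMode check added in test/test_fake_tensor.py; the following is a standalone version of that test (only the imports are added here):

    import torch
    import torch._subclasses

    # CrossRefFakeMode runs real kernels and cross-checks them against fake tensors;
    # before this change the batch_norm decomposition could report a mismatched
    # device between CPU and CUDA.
    with torch._subclasses.CrossRefFakeMode():
        m = torch.nn.Sequential(
            torch.nn.BatchNorm2d(10),
            torch.nn.ReLU(),
        )
        m.eval()
        out = m(torch.randn([2, 10, 8, 8]))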

File tree

8 files changed, +64 −35 lines changed

functorch/test/test_aotdispatch.py

Lines changed: 0 additions & 2 deletions
@@ -880,7 +880,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('mvlgamma', 'mvlgamma_p_3'), # aten.digamma_.default - couldn't find symbolic meta function/decom...
     xfail('mvlgamma', 'mvlgamma_p_5'), # aten.digamma_.default - couldn't find symbolic meta function/decom...
     xfail('nanmedian', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition
-    xfail('native_layer_norm', ''), # could not find kernel
     xfail('nn.functional._scaled_dot_product_attention', ''), # Cannot call sizes() on tensor with symbolic ...
     xfail('nn.functional.adaptive_avg_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.adaptive_avg_pool2d', ''), # aten._adaptive_avg_pool2d_backward.default - couldn't ...
@@ -923,7 +922,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
     xfail('nn.functional.kl_div', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.l1_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.layer_norm', ''), # could not find kernel
     xfail('nn.functional.linear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.local_response_norm', ''), # aten.fill.Scalar - couldn't find symbolic meta functio...
     xfail('nn.functional.max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides

test/test_fake_tensor.py

Lines changed: 9 additions & 0 deletions
@@ -464,6 +464,15 @@ def fn(tensors):
         inputs = [a, b]
         ref = fn(inputs)

+    def test_fake_tensor_batch_norm_cpu(self):
+        with torch._subclasses.CrossRefFakeMode():
+            m = torch.nn.Sequential(
+                torch.nn.BatchNorm2d(10),
+                torch.nn.ReLU(),
+            )
+            m.eval()
+            out = m(torch.randn([2, 10, 8, 8]))
+
     def test_shared_storage_invalidation(self):
         with FakeTensorMode():
             x = torch.tensor([1.])

test/test_ops.py

Lines changed: 0 additions & 4 deletions
@@ -1811,10 +1811,6 @@ def test_refs_are_in_decomp_table(self, op):
     "linalg.norm",
     "linalg.svd",
     "linalg.svdvals",
-    "nn.functional.binary_cross_entropy_with_logits",
-    "nn.functional.huber_loss",
-    "nn.functional.logsigmoid",
-    "nn.functional.multilabel_soft_margin_loss",
     "pca_lowrank",
     "roll",
     "svd_lowrank",

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 2 deletions
@@ -843,9 +843,8 @@ def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModul
 def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
 def _set_conj(x: Tensor, conj: _bool) -> None: ...
 def _set_neg(x: Tensor, neg: _bool) -> None: ...
-def _add_meta_to_tls_dispatch_include() -> None: ...
+def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
 def _meta_in_tls_dispatch_include() -> _bool: ...
-def _remove_meta_from_tls_dispatch_include() -> None: ...
 def _has_storage(x: Tensor) -> _bool: ...
 def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
 # NB: There is no Capsule type in typing, see
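
The add/remove pair collapses into one boolean setter. A hedged usage sketch of the new interface (the two `torch._C` functions appear in the diff above; the surrounding save/restore pattern mirrors `in_kernel_invocation_manager` further below and is illustrative only):

    import torch

    # Save, set, and restore the Meta key in the thread-local dispatch include set.
    prev = torch._C._meta_in_tls_dispatch_include()
    torch._C._set_meta_in_tls_dispatch_include(True)
    try:
        pass  # code that should dispatch to Meta kernels would run here
    finally:
        torch._C._set_meta_in_tls_dispatch_include(prev)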

torch/_decomp/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,10 @@

 meta_lib = torch.library.Library("aten", "IMPL", "Meta")

+# decompositions which have been disabled as meta kernel implementations,
+# usually due to mismatching strides, aliasing, or other inconsistent property
+_disabled_meta_decomps = set()
+

 def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False):
     """
@@ -105,6 +109,11 @@ def add_op_to_table(aten_op):
         name = op_overload._schema.name
         if op_overload._schema.overload_name:
             name += "." + op_overload._schema.overload_name
+
+        if disable_meta:
+            global _disabled_meta_decomps
+            _disabled_meta_decomps.add(op_overload)
+
         if (
             not disable_meta
             # TorchScript dumps a bunch of extra nonsense overloads
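
A hedged sketch of inspecting the new bookkeeping from Python (aten.native_batch_norm is used only as an example of an op with a Python decomposition; whether it appears in either container depends on how it was registered):

    import torch
    from torch._decomp import decomposition_table, _disabled_meta_decomps

    op = torch.ops.aten.native_batch_norm.default
    # does this overload have a registered Python decomposition?
    print(op in decomposition_table)
    # was it registered with disable_meta=True, i.e. kept out of the Meta kernels
    # and, with this change, out of the fake tensor decomposition fallback?
    print(op in _disabled_meta_decomps)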

torch/_decomp/decompositions.py

Lines changed: 2 additions & 6 deletions
@@ -1284,12 +1284,8 @@ def native_layer_norm_backward(
     if M <= 0 or N <= 0:
         return (
             input.new_zeros(input_shape) if output_mask[0] else None,
-            input.new_zeros(input_shape[axis:])
-            if output_mask[1] and weight_cast
-            else None,
-            input.new_zeros(input_shape[axis:])
-            if output_mask[2] and bias_cast
-            else None,
+            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
+            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
         )

     x_hat = (input_cast - mean) * rstd

torch/_subclasses/fake_tensor.py

Lines changed: 35 additions & 9 deletions
@@ -115,6 +115,20 @@ def get_schema_info(func):
     return torch._C._SchemaInfo(func._schema) # type: ignore[attr-defined]


+# many of the decompositions registered to torch/_prims do not at the moment model
+# aliasing or strides, so as an incremental step, just enable the decompositions in
+# torch/_decomp/decompositions.py.
+# decomps are used for aot autograd tracing so we would like to unify on their
+# implementation and add additional testing to them
+@functools.lru_cache(None)
+def torch_decomp_decompositions(func):
+    from torch._decomp import decomposition_table
+
+    decompositions = torch._decomp.decompositions
+    decomp_attrs = [getattr(decompositions, attr) for attr in dir(decompositions)]
+    return decomposition_table[func] in decomp_attrs
+
+
 def tree_flatten_only(ty: Type[T], pytree: PyTree):
     flat_vals, _ = tree_flatten(pytree)
     return [elem for elem in flat_vals if isinstance(elem, ty)]
@@ -302,7 +316,8 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs):
     input_device = new_kwargs["device"]
     out_device = input_device if input_device else new_kwargs["input"].device
     new_kwargs["device"] = torch.device("meta")
-    r = func(*args, **new_kwargs)
+    inp = new_kwargs.pop("input")
+    r = func(inp, **new_kwargs)
     return fake_mode.fake_tensor_converter(fake_mode, r, out_device)


@@ -329,7 +344,7 @@ def to_copy(fake_mode, func, *args, **kwargs):

     input_device = new_kwargs.pop("device", None)
     out_device = input_device if input_device else new_kwargs["input"].device
-    with no_dispatch(), in_kernel_invocation_manager(fake_mode):
+    with in_kernel_invocation_manager(fake_mode):
         input = new_kwargs.pop("input").to("meta")
         return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device)

@@ -417,18 +432,19 @@ def nyi(fake_mode, func, *args, **kwargs):
 @contextlib.contextmanager
 def in_kernel_invocation_manager(fake_mode):
     # See: note [Fake Tensor Dispatch Keys]
+    prev_in_kernel = fake_mode.in_kernel_invocation
     meta_in_tls = torch._C._meta_in_tls_dispatch_include()
-    prev = fake_mode.in_kernel_invocation
+    assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"

+    guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
     fake_mode.in_kernel_invocation = True
-    if not meta_in_tls:
-        torch._C._add_meta_to_tls_dispatch_include()
+    torch._C._set_meta_in_tls_dispatch_include(True)
     try:
         yield
     finally:
-        fake_mode.in_kernel_invocation = prev
-        if not meta_in_tls:
-            torch._C._remove_meta_from_tls_dispatch_include()
+        fake_mode.in_kernel_invocation = prev_in_kernel
+        torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
+        del guard


 class FakeTensor(torch.Tensor):
@@ -728,14 +744,15 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         # is written to must be invalidated
         self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

+        from torch._decomp import _disabled_meta_decomps, decomposition_table
+
         # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise
         if (
             has_symbolic_sizes
             and func not in self.functions_with_cpp_meta_impl_that_support_symint
         ):
             # TODO: Find better approach for this
             # Avoid circular import
-            from torch._decomp import decomposition_table
             from torch._meta_registrations import meta_table

             with no_dispatch():
@@ -759,6 +776,15 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                 if r is not NotImplemented:
                     return r

+        if (
+            func in decomposition_table
+            and torch_decomp_decompositions(func)
+            and func not in _disabled_meta_decomps
+            and all(not e.is_sparse for e in flat_arg_fake_tensors)
+        ):
+            with self:
+                return decomposition_table[func](*args, **kwargs)
+
         # prims already wrap FakeTensor inputs to FakeTensor outputs
         # and do device logic, we dont need do anything but run them
         # and ensure that Meta kernels are dispatched to (see)
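
For illustration, a hedged sketch of the filter that torch_decomp_decompositions applies, using an op from this commit's own diff as the example (the membership check is guarded in case its registration changes):

    import torch
    from torch._decomp import decomposition_table
    from torch._subclasses.fake_tensor import torch_decomp_decompositions

    # An overload qualifies for the new fallback only when its table entry is a
    # function defined in torch/_decomp/decompositions.py rather than a torch/_prims ref.
    op = torch.ops.aten.native_layer_norm_backward.default
    if op in decomposition_table:
        print(torch_decomp_decompositions(op))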

torch/csrc/Module.cpp

Lines changed: 8 additions & 12 deletions
@@ -1386,25 +1386,21 @@ Call this whenever a new thread is created in order to propagate values from
   py_module.def(
       "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });

-  py_module.def("_add_meta_to_tls_dispatch_include", []() {
+  py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
     c10::DispatchKeySet key_set({at::DispatchKey::Meta});
-    local_keyset.included_ = local_keyset.included_ | key_set;
-    c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
-  });
-  py_module.def("_remove_meta_from_tls_dispatch_include", []() {
-    auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
-    auto k = key_set.highestBackendKey();
-    local_keyset.included_ = local_keyset.included_.remove_backend(k);
+    if (meta_in_tls) {
+      local_keyset.included_ = local_keyset.included_ | key_set;
+    } else {
+      local_keyset.included_ =
+          local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
+    }
     c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
   });

   py_module.def("_meta_in_tls_dispatch_include", []() {
     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-    c10::DispatchKeySet key_set({at::DispatchKey::Meta});
-    auto k = key_set.highestBackendKey();
-    return local_keyset.included_.has_backend(k);
+    return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
   });

   py_module.def("_dump_local_tls_set", []() {
