diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index f3eb3236676053..44ec2b7f6f5af5 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -16,7 +16,7 @@
 from torch.fx.experimental.symbolic_shapes import sym_float, eval_guards, fx_placeholder_vals
 from torch.testing._internal.common_device_type import ops
 from torch._C import _disabled_torch_function_impl
-from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
 from torch.utils._pytree import tree_map
 from torch import nn
 import re
@@ -368,6 +368,12 @@ def f_backward(x):
         for f in [f_grad, f_backward]:
             self._test(f, [torch.randn(3, requires_grad=True)])
 
+    def test_pickle_issue89626(self):
+        import pickle
+        x = torch.randn(2)
+        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
+        pickle.dumps(x)
+
     def test_inplace_metadata(self):
         def f(x):
             x = x.clone()
@@ -659,19 +665,6 @@ def f(x, w):
         for a, b in zip(out_graph.nodes, out_graph2.nodes):
             self.assertEqual(a.op, b.op)
 
-    def test_has_proxy(self):
-        foo = torch.randn(5)
-
-        def f(x):
-            self.assertFalse(has_proxy(foo))
-            self.assertTrue(has_proxy(x))
-            y = x.cos()
-            self.assertTrue(has_proxy(y))
-            return y
-
-        self.assertFalse(has_proxy(torch.randn(5)))
-        make_fx(f)(torch.randn(5))
-
     def test_strides(self):
         def f(x):
             self.assertTrue(x.is_contiguous())
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index ddeb8c97825f04..ec7f92ad469442 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -5,6 +5,7 @@
 import torch
 
 from torch.multiprocessing.reductions import StorageWeakRef
+from torch.utils.weak import WeakTensorRefKey
 
 
 def safe_is_leaf(t):
@@ -58,39 +59,6 @@ def go(m1, m2):
     return go(m1, m2)
 
 
-# torch.Tensors cannot be used as a key in a dictionary
-# because they define a custom __eq__ function which when used
-# to resolve hash collisions will throw when comparing tensors:
-# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
-# To avoid that, we use an object which will hold a Tensor and use
-# its id for both hashing and equality.
-# In order to use this as a weak key reference, we cannot
-# simply use weakref.WeakKeyDictionary because the newly constructed
-# WeakTensorRefKey only use would be a dictionary so it would have no strong
-# references.
-# To get around this issue, we can use it as a normal key, and then set
-# `weakref.finalize` to delete the key when its contained tensor dies.
-
-
-class WeakTensorRefKey(object):
-    def __init__(self, ten):
-        self.ten = weakref.ref(ten)
-        # store id since as soon as ten is deallocated
-        # the old id will no longer be recoverable, and
-        # we need to be able to remove the WeakTensorRefKey
-        # from the dictionary by hashing it to the same
-        # value it had when ten was alive
-        self.id = id(self.ten())
-
-    def __hash__(self):
-        return self.id
-
-    def __eq__(self, other):
-        if id(self) == id(other):
-            return True
-        return self.id == other.id
-
-
 # This is a class for converting multiple tensors into meta tensors which
 # share the same view/storage structure.  The operation model is you allocate
 # one of these, and then call it repeatedly on all the tensors you want to
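A minimal sketch (illustrative only, not part of the patch) of the failure mode described in the comment block being moved out of meta_utils.py, i.e. why raw torch.Tensors make bad dict keys:

import torch

t = torch.randn(3)
u = torch.randn(3)
# Tensor.__eq__ is elementwise and returns a bool *tensor*, not a bool:
print(t == u)  # tensor([False, False, False])
# A plain dict resolving a hash collision would coerce that result to
# bool, which raises the RuntimeError quoted in the comment above:
bool(t == u)   # RuntimeError: Boolean value of Tensor ... is ambiguous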
diff --git a/torch/distributed/_spmd/comm_tensor.py b/torch/distributed/_spmd/comm_tensor.py
index b02a0b63d7c2fa..703e547dbaaee1 100644
--- a/torch/distributed/_spmd/comm_tensor.py
+++ b/torch/distributed/_spmd/comm_tensor.py
@@ -7,9 +7,9 @@
 from torch._C import _disabled_torch_function_impl
 from torch.fx.experimental.proxy_tensor import (
     _ProxyTensor,
+    get_innermost_proxy_mode,
     fetch_tensor_proxy,
-    get_proxy,
-    get_proxy_slots,
+    get_proxy_slot,
     set_proxy_slot,
     track_tensor_tree,
 )
@@ -51,13 +51,11 @@ def wrap(work, e):
         return (tree_map(partial(wrap, work), result[0]), work)
 
 
-def _get_tracer(obj: Any) -> Optional[torch.fx.Tracer]:
-    slots = get_proxy_slots(obj)
-    if slots is None:
+def _get_tracer() -> Optional[torch.fx.Tracer]:
+    mode = get_innermost_proxy_mode()
+    if mode is None:
         return None
-    keys = tuple(slots.keys())
-    assert len(keys) == 1
-    return keys[0]
+    return mode.tracer
 
 
 class CommTensor(torch.Tensor):
@@ -105,13 +103,14 @@ class CommTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, tensor: torch.Tensor):
         t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
-        if _get_tracer(t) is None:
+        if get_innermost_proxy_mode() is None:
             # noop for eager mode
             return tensor
 
         # Use non-CommTensor to avoid nested CommTensor Wrapping
         r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
         # The tensor object wrapped by this CommTensor
+        # NB: THIS CAN BE A CommTensor; see test_nested_comm_tensor_wrapping
         r._tensor = tensor  # type: ignore[attr-defined]
         # Record the LAST `work` object returned by collective communication
         # operations. If this is None, it means no collectives have called
@@ -143,7 +142,14 @@ def unwrap(e: Any):
                 nonlocal tracer, work
 
                 work = e._work
-                tracer = _get_tracer(e._tensor)
+                # TODO(ezyang): I don't really understand what's going on
+                # here, but it seems that tracer doesn't reflect whether or
+                # not there is ambient tracing going on, but rather, whether
+                # or not we will trace THIS particular invocation.  If we
+                # have a nested CommTensor, the outer layer doesn't actually
+                # trace and we only trace the inner layer
+                if not isinstance(e._tensor, CommTensor):
+                    tracer = _get_tracer()
 
                 if work is not None:
                     if tracer is not None:
@@ -151,7 +157,7 @@
                         proxy_res = tracer.create_proxy(  # type: ignore[union-attr]
                             'call_function',
                             _wait_comm,
-                            (get_proxy(e._tensor).proxy,),
+                            (get_proxy_slot(e._tensor, tracer).proxy,),
                             {},
                             name="wait_comm"
                         )
@@ -198,6 +204,7 @@ def set_work(work: torch.distributed._Work, e: Any):
 
             # get proxy for output tuple
             proxy_res = func(*proxy_args, **proxy_kwargs)
+            assert isinstance(proxy_res, torch.fx.Proxy)
             # insert a node that wraps the output tuple into
             # _CommResult(tensor, work)
             comm_result_proxy = tracer.create_proxy(  # type: ignore[union-attr]
@@ -227,7 +234,7 @@ def set_work(work: torch.distributed._Work, e: Any):
                 flat_args, args_spec = tree_flatten(unwrapped_args[0])
                 flat_out, out_spec = tree_flatten(out[0])
                 for a, o in zip(flat_args, flat_out):
-                    set_proxy_slot(a, tracer, get_proxy(o))
+                    set_proxy_slot(a, tracer, get_proxy_slot(o, tracer))
 
                 return out
             else:
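A sketch of the behavioral shift in _get_tracer above: the tracer is now derived from the ambient ProxyTorchDispatchMode (via get_innermost_proxy_mode, added below in proxy_tensor.py) rather than from proxy slots stored on one particular tensor. This assumes the patched modules:

import torch
from torch.fx.experimental.proxy_tensor import make_fx, get_innermost_proxy_mode

# In eager mode there is no ambient proxy mode, so CommTensor.__new__
# takes the no-op path and returns the plain tensor:
assert get_innermost_proxy_mode() is None

def f(x):
    # Under make_fx, a ProxyTorchDispatchMode sits on the dispatch mode
    # stack, so the tracer is discoverable without inspecting any tensor:
    assert get_innermost_proxy_mode() is not None
    return x + 1

make_fx(f)(torch.randn(2))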
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c2c60eaee0838b..56e0d92292ebbf 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -24,8 +24,9 @@
 from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
 from torch.fx import Proxy
 from torch import SymInt, SymFloat
+from torch.utils.weak import WeakTensorKeyDictionary
 
-__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"]
+__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
 
 aten = torch.ops.aten
 prim = torch.ops.prim
@@ -62,21 +63,20 @@ def is_sym_node(node):
     return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)
 
 def set_proxy_slot(obj, tracer, proxy):
-    assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
-    d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary())  # type: ignore[call-overload]
-    assert isinstance(d, weakref.WeakKeyDictionary)
-    # NB: Never clobber pre-existing proxy. Although the proxies
-    # are in principle equivalent, when we do graph partitioning
-    # we need there not to be spurious dependencies on tangent inputs.
-    # This works because primals get their SymInts set first, and
-    # THEN later we allocate tangent inputs. Make sure if a SymInt
-    # is derivable from a primal that we use that.
-    #
-    # However, we DO want to clobber proxies whenever we run an inplace operation
-    # on a tensor, and it affects the metadata on the proxy.
-    # This doesn't really apply to SymInts/SymFloats though, which are immutable.
-    if tracer not in d or isinstance(obj, torch.Tensor):
-        d[tracer] = proxy
+    if isinstance(obj, torch.Tensor):
+        # We DO want to clobber proxies whenever we run an inplace operation
+        # on a tensor, and it affects the metadata on the proxy.
+        tracer.tensor_tracker[obj] = proxy
+    else:
+        # NB: Never clobber pre-existing proxy. Although the proxies
+        # are in principle equivalent, when we do graph partitioning
+        # we need there not to be spurious dependencies on tangent inputs.
+        # This works because primals get their SymInts set first, and
+        # THEN later we allocate tangent inputs. Make sure if a SymInt
+        # is derivable from a primal that we use that.
+        assert isinstance(obj, SymNode), type(obj)
+        if obj not in tracer.symnode_tracker:
+            tracer.symnode_tracker[obj] = proxy
 
 def has_proxy_slot(obj, tracer):
     assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
@@ -86,36 +86,17 @@
 # the transform argument is handy if you need to extract a subfield from
 # the successfully looked up result (but NOT the default.)
 def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
-    assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
-    d = obj.__dict__.get(proxy_slot)  # type: ignore[call-overload]
-    if not d:
+    if isinstance(obj, torch.Tensor):
+        tracker = tracer.tensor_tracker
+    else:
+        assert isinstance(obj, SymNode), type(obj)
+        tracker = tracer.symnode_tracker
+
+    if obj not in tracker:
         if default is no_default:
-            raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
+            raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
         return default
-    assert isinstance(d, weakref.WeakKeyDictionary)
-    if tracer not in d:
-        if default is no_default:
-            raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
-        else:
-            return default
-    return transform(d[tracer])
-
-
-def get_proxy_slots(obj):
-    return obj.__dict__.get(proxy_slot)
-
-
-# Gets the proxy for a tensor, if it exists.
-def get_proxy(obj):
-    res = get_proxy_slots(obj)
-    if res is None:
-        return None
-    vals = tuple(res.values())
-    assert len(vals) == 1
-    return vals[0]
-
-def has_proxy(obj):
-    return get_proxy(obj) is not None
+    return transform(tracker[obj])
 
 def snapshot_fake(val):
     return val.detach()
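The asymmetry that the rewritten set_proxy_slot/get_proxy_slot pair preserves, as a runnable sketch against the patched API; StubProxy is a hypothetical stand-in for a real torch.fx.Proxy:

import torch
from torch.fx.experimental.proxy_tensor import (
    PythonKeyTracer,
    get_proxy_slot,
    set_proxy_slot,
)

class StubProxy:  # hypothetical stand-in, for illustration only
    def __init__(self, name):
        self.name = name

tracer = PythonKeyTracer()
t = torch.randn(2)

# Tensor proxies clobber: last write wins, because an in-place op may
# have changed the metadata the proxy needs to reflect.
set_proxy_slot(t, tracer, StubProxy("a"))
set_proxy_slot(t, tracer, StubProxy("b"))
assert get_proxy_slot(t, tracer).name == "b"

# SymNodes take the other branch: `if obj not in tracer.symnode_tracker`
# keeps the first proxy ever set, so a SymInt derivable from a primal
# never picks up a spurious dependency on a tangent input.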
@@ -404,6 +385,8 @@ def can_handle_tensor(x):
 class PythonKeyTracer(Tracer):
     def __init__(self):
         super().__init__()
+        self.tensor_tracker = WeakTensorKeyDictionary()
+        self.symnode_tracker = weakref.WeakKeyDictionary()  # type: ignore[var-annotated]
 
     # In general, we don't want to make modules leaves.  In principle, users of
     # this tracer might want to override this in order to turn a couple specific
@@ -580,6 +563,9 @@ def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, deco
         super().__init__(module, **kwargs)
         self.new_graph = new_graph
         self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
+        # Blegh
+        self.tracer.tensor_tracker = WeakTensorKeyDictionary()  # type: ignore[attr-defined]
+        self.tracer.symnode_tracker = weakref.WeakKeyDictionary()  # type: ignore[attr-defined]
         self.decomposition_table = decomposition_table
         if self.decomposition_table is None:
             self.decomposition_table = {}
@@ -715,6 +701,13 @@ def get_torch_dispatch_modes():
     return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
 
 
+def get_innermost_proxy_mode():
+    for m in reversed(torch.utils._python_dispatch._get_current_dispatch_mode_stack()):
+        if isinstance(m, ProxyTorchDispatchMode):
+            return m
+    return None
+
+
 @contextlib.contextmanager
 def disable_proxy_modes_tracing():
     # TODO: This probably doesn't correctly also disable ProxySymDispatchMode
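Moving the trackers onto the tracer is what the new test_pickle_issue89626 exercises: proxies no longer live in the tensor's __dict__, where pickle used to trip over them. A sketch of the user-visible effect:

import pickle

import torch
from torch.fx.experimental.proxy_tensor import make_fx

x = torch.randn(2)
make_fx(lambda x: x * 2)(x)
# Before this patch, tracing left a proxy_slot entry (a WeakKeyDictionary
# holding a Proxy) in x.__dict__, so pickling the input tensor failed
# after tracing. With proxies keyed by tensor in tracer.tensor_tracker,
# x stays clean:
pickle.dumps(x)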
diff --git a/torch/utils/weak.py b/torch/utils/weak.py
new file mode 100644
index 00000000000000..205ca50679c81e
--- /dev/null
+++ b/torch/utils/weak.py
@@ -0,0 +1,79 @@
+import weakref
+from collections.abc import MutableMapping
+from typing import Dict
+
+
+__all__ = ['WeakTensorRefKey', 'WeakTensorKeyDictionary']
+
+
+# Utility classes for working with weak references to tensors
+
+# torch.Tensors cannot be used as a key in a dictionary
+# because they define a custom __eq__ function which when used
+# to resolve hash collisions will throw when comparing tensors:
+# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
+# To avoid that, we use an object which will hold a Tensor and use
+# its id for both hashing and equality.
+# In order to use this as a weak key reference, we cannot
+# simply use weakref.WeakKeyDictionary because the newly constructed
+# WeakTensorRefKey only use would be a dictionary so it would have no strong
+# references.
+# To get around this issue, we can use it as a normal key, and then set
+# `weakref.finalize` to delete the key when its contained tensor dies.
+
+
+class WeakTensorRefKey(object):
+    def __init__(self, ten):
+        self.ten = weakref.ref(ten)
+        # store id since as soon as ten is deallocated
+        # the old id will no longer be recoverable, and
+        # we need to be able to remove the WeakTensorRefKey
+        # from the dictionary by hashing it to the same
+        # value it had when ten was alive
+        self.id = id(self.ten())
+
+    def __hash__(self):
+        return self.id
+
+    def __eq__(self, other):
+        if id(self) == id(other):
+            return True
+        return self.id == other.id
+
+class WeakTensorKeyDictionary(MutableMapping):
+    data: Dict[WeakTensorRefKey, object]
+
+    def __init__(self):
+        self.data = {}
+
+    def __contains__(self, k):
+        return WeakTensorRefKey(k) in self.data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __iter__(self):
+        def generator():
+            for wk in self.data:
+                k = wk.ten()
+                if k is not None:
+                    yield k
+        return generator()
+
+    def __getitem__(self, k):
+        return self.data[WeakTensorRefKey(k)]
+
+    def __setitem__(self, k, v):
+        wk = WeakTensorRefKey(k)
+        self_weak_ref = weakref.ref(self)
+
+        def del_ten():
+            self_ref = self_weak_ref()
+            if self_ref is None:
+                return
+            self_ref.data.pop(wk, None)
+        weakref.finalize(k, del_ten)
+        self.data[wk] = v
+
+    def __delitem__(self, k):
+        del self.data[WeakTensorRefKey(k)]
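Basic usage of the new WeakTensorKeyDictionary, as a sketch that assumes CPython's eager refcounting (so the weakref.finalize callback runs as soon as the last strong reference to the key tensor dies):

import torch
from torch.utils.weak import WeakTensorKeyDictionary

d = WeakTensorKeyDictionary()
t = torch.randn(3)
d[t] = "metadata"
assert t in d and d[t] == "metadata"
assert next(iter(d)) is t  # iteration yields the live tensors themselves

del t  # last strong reference dies; the finalizer evicts the entry
assert len(d) == 0  # the dictionary does not keep tensors alive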