Don't put tracing state on Tensor (pytorch#90628)
Fixes pytorch#89626

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: pytorch#90628
Approved by: https://github.com/voznesenskym
ezyang authored and pytorchmergebot committed Dec 15, 2022
1 parent 103029e commit 54563e6
Showing 5 changed files with 143 additions and 103 deletions.
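
For context on the fix: before this change, make_fx recorded proxy bookkeeping directly in the traced tensor's __dict__ (the proxy_slot entry removed below), which left unpicklable tracing state on plain tensors (pytorch#89626). After this change the state lives on the tracer itself. A minimal repro sketch (not part of the PR), mirroring the test_pickle_issue89626 test added below:

    import pickle
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    x = torch.randn(2)
    make_fx(lambda x: x * 2)(x)  # tracing touches x
    pickle.dumps(x)              # failed before this change; now x carries no tracing state
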
21 changes: 7 additions & 14 deletions test/test_proxy_tensor.py
@@ -16,7 +16,7 @@
from torch.fx.experimental.symbolic_shapes import sym_float, eval_guards, fx_placeholder_vals
from torch.testing._internal.common_device_type import ops
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch import nn
import re
@@ -368,6 +368,12 @@ def f_backward(x):
for f in [f_grad, f_backward]:
self._test(f, [torch.randn(3, requires_grad=True)])

def test_pickle_issue89626(self):
import pickle
x = torch.randn(2)
make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
pickle.dumps(x)

def test_inplace_metadata(self):
def f(x):
x = x.clone()
@@ -659,19 +665,6 @@ def f(x, w):
for a, b in zip(out_graph.nodes, out_graph2.nodes):
self.assertEqual(a.op, b.op)

def test_has_proxy(self):
foo = torch.randn(5)

def f(x):
self.assertFalse(has_proxy(foo))
self.assertTrue(has_proxy(x))
y = x.cos()
self.assertTrue(has_proxy(y))
return y

self.assertFalse(has_proxy(torch.randn(5)))
make_fx(f)(torch.randn(5))

def test_strides(self):
def f(x):
self.assertTrue(x.is_contiguous())
34 changes: 1 addition & 33 deletions torch/_subclasses/meta_utils.py
@@ -5,6 +5,7 @@

import torch
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils.weak import WeakTensorRefKey


def safe_is_leaf(t):
@@ -58,39 +59,6 @@ def go(m1, m2):
return go(m1, m2)


# torch.Tensors cannot be used as a key in a dictionary
# because they define a custom __eq__ function which when used
# to resolve hash collisions will throw when comparing tensors:
# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
# To avoid that, we use an object which will hold a Tensor and use
# its id for both hashing and equality.
# In order to use this as a weak key reference, we cannot simply use
# weakref.WeakKeyDictionary: the dictionary would be the only thing holding
# the newly constructed WeakTensorRefKey, so nothing would keep a strong
# reference to it and it would be collected immediately.
# To get around this issue, we use it as a normal (strong) key, and register
# `weakref.finalize` to delete the key when its contained tensor dies.


class WeakTensorRefKey(object):
def __init__(self, ten):
self.ten = weakref.ref(ten)
# store id since as soon as ten is deallocated
# the old id will no longer be recoverable, and
# we need to be able to remove the WeakTensorRefKey
# from the dictionary by hashing it to the same
# value it had when ten was alive
self.id = id(self.ten())

def __hash__(self):
return self.id

def __eq__(self, other):
if id(self) == id(other):
return True
return self.id == other.id


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure. The operation model is you allocate
# one of these, and then call it repeatedly on all the tensors you want to
31 changes: 19 additions & 12 deletions torch/distributed/_spmd/comm_tensor.py
@@ -7,9 +7,9 @@
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
_ProxyTensor,
get_innermost_proxy_mode,
fetch_tensor_proxy,
get_proxy,
get_proxy_slots,
get_proxy_slot,
set_proxy_slot,
track_tensor_tree,
)
@@ -51,13 +51,11 @@ def wrap(work, e):
return (tree_map(partial(wrap, work), result[0]), work)


def _get_tracer(obj: Any) -> Optional[torch.fx.Tracer]:
slots = get_proxy_slots(obj)
if slots is None:
def _get_tracer() -> Optional[torch.fx.Tracer]:
mode = get_innermost_proxy_mode()
if mode is None:
return None
keys = tuple(slots.keys())
assert len(keys) == 1
return keys[0]
return mode.tracer


class CommTensor(torch.Tensor):
@@ -105,13 +103,14 @@ class CommTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor):
t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
if _get_tracer(t) is None:
if get_innermost_proxy_mode() is None:
# noop for eager mode
return tensor

# Use non-CommTensor to avoid nested CommTensor Wrapping
r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
# The tensor object wrapped by this CommTensor
# NB: THIS CAN BE A CommTensor; see test_nested_comm_tensor_wrapping
r._tensor = tensor # type: ignore[attr-defined]
# Record the LAST `work` object returned by collective communication
# operations. If this is None, it means no collectives have called
@@ -143,15 +142,22 @@ def unwrap(e: Any):
nonlocal tracer, work

work = e._work
tracer = _get_tracer(e._tensor)
# TODO(ezyang): I don't really understand what's going on
# here, but it seems that tracer doesn't reflect whether or
# not there is ambient tracing going on, but rather, whether
# or not we will trace THIS particular invocation. If we
# have a nested CommTensor, the outer layer doesn't actually
# trace and we only trace the inner layer
if not isinstance(e._tensor, CommTensor):
tracer = _get_tracer()

if work is not None:
if tracer is not None:
# insert a node to the traced graph.
proxy_res = tracer.create_proxy( # type: ignore[union-attr]
'call_function',
_wait_comm,
(get_proxy(e._tensor).proxy,),
(get_proxy_slot(e._tensor, tracer).proxy,),
{},
name="wait_comm"
)
@@ -198,6 +204,7 @@ def set_work(work: torch.distributed._Work, e: Any):

# get proxy for output tuple
proxy_res = func(*proxy_args, **proxy_kwargs)
assert isinstance(proxy_res, torch.fx.Proxy)
# insert a node that wraps the output tuple into
# _CommResult(tensor, work)
comm_result_proxy = tracer.create_proxy( # type: ignore[union-attr]
@@ -227,7 +234,7 @@ def set_work(work: torch.distributed._Work, e: Any):
flat_args, args_spec = tree_flatten(unwrapped_args[0])
flat_out, out_spec = tree_flatten(out[0])
for a, o in zip(flat_args, flat_out):
set_proxy_slot(a, tracer, get_proxy(o))
set_proxy_slot(a, tracer, get_proxy_slot(o, tracer))

return out
else:
81 changes: 37 additions & 44 deletions torch/fx/experimental/proxy_tensor.py
@@ -24,8 +24,9 @@
from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
from torch.fx import Proxy
from torch import SymInt, SymFloat
from torch.utils.weak import WeakTensorKeyDictionary

__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"]
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
aten = torch.ops.aten
prim = torch.ops.prim

@@ -62,21 +63,20 @@ def is_sym_node(node):
return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)

def set_proxy_slot(obj, tracer, proxy):
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) # type: ignore[call-overload]
assert isinstance(d, weakref.WeakKeyDictionary)
# NB: Never clobber pre-existing proxy. Although the proxies
# are in principle equivalent, when we do graph partitioning
# we need there not to be spurious dependencies on tangent inputs.
# This works because primals get their SymInts set first, and
# THEN later we allocate tangent inputs. Make sure if a SymInt
# is derivable from a primal that we use that.
#
# However, we DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
# This doesn't really apply to SymInts/SymFloats though, which are immutable.
if tracer not in d or isinstance(obj, torch.Tensor):
d[tracer] = proxy
if isinstance(obj, torch.Tensor):
# We DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
tracer.tensor_tracker[obj] = proxy
else:
# NB: Never clobber pre-existing proxy. Although the proxies
# are in principle equivalent, when we do graph partitioning
# we need there not to be spurious dependencies on tangent inputs.
# This works because primals get their SymInts set first, and
# THEN later we allocate tangent inputs. Make sure if a SymInt
# is derivable from a primal that we use that.
assert isinstance(obj, SymNode), type(obj)
if obj not in tracer.symnode_tracker:
tracer.symnode_tracker[obj] = proxy

def has_proxy_slot(obj, tracer):
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
@@ -86,36 +86,17 @@ def has_proxy_slot(obj, tracer):
# the transform argument is handy if you need to extract a subfield from
# the successfully looked up result (but NOT the default.)
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
d = obj.__dict__.get(proxy_slot) # type: ignore[call-overload]
if not d:
if isinstance(obj, torch.Tensor):
tracker = tracer.tensor_tracker
else:
assert isinstance(obj, SymNode), type(obj)
tracker = tracer.symnode_tracker

if obj not in tracker:
if default is no_default:
raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
return default
assert isinstance(d, weakref.WeakKeyDictionary)
if tracer not in d:
if default is no_default:
raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
else:
return default
return transform(d[tracer])


def get_proxy_slots(obj):
return obj.__dict__.get(proxy_slot)


# Gets the proxy for a tensor, if it exists.
def get_proxy(obj):
res = get_proxy_slots(obj)
if res is None:
return None
vals = tuple(res.values())
assert len(vals) == 1
return vals[0]

def has_proxy(obj):
return get_proxy(obj) is not None
return transform(tracker[obj])
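
A rough sketch (not from the PR) of what the new storage model means in practice: the proxy mapping now lives on the tracer, keyed weakly by tensor, so the tensor object itself is left untouched. The sentinel object stands in for a real torch.fx.Proxy:

    import torch
    from torch.fx.experimental.proxy_tensor import PythonKeyTracer, get_proxy_slot, set_proxy_slot

    tracer = PythonKeyTracer()
    t = torch.randn(2)
    sentinel = object()  # stand-in for a real torch.fx.Proxy

    set_proxy_slot(t, tracer, sentinel)            # recorded in tracer.tensor_tracker
    assert get_proxy_slot(t, tracer) is sentinel   # looked up through the tracer, not t.__dict__
    assert not t.__dict__                          # the tensor carries no tracing state
    assert get_proxy_slot(t, PythonKeyTracer(), default=None) is None  # other tracers see nothing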

def snapshot_fake(val):
return val.detach()
@@ -404,6 +385,8 @@ def can_handle_tensor(x):
class PythonKeyTracer(Tracer):
def __init__(self):
super().__init__()
self.tensor_tracker = WeakTensorKeyDictionary()
self.symnode_tracker = weakref.WeakKeyDictionary() # type: ignore[var-annotated]

# In general, we don't want to make modules leaves. In principle, users of
# this tracer might want to override this in order to turn a couple specific
@@ -580,6 +563,9 @@ def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, deco
super().__init__(module, **kwargs)
self.new_graph = new_graph
self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
# Blegh
self.tracer.tensor_tracker = WeakTensorKeyDictionary() # type: ignore[attr-defined]
self.tracer.symnode_tracker = weakref.WeakKeyDictionary() # type: ignore[attr-defined]
self.decomposition_table = decomposition_table
if self.decomposition_table is None:
self.decomposition_table = {}
Expand Down Expand Up @@ -715,6 +701,13 @@ def get_torch_dispatch_modes():
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()


def get_innermost_proxy_mode():
for m in reversed(torch.utils._python_dispatch._get_current_dispatch_mode_stack()):
if isinstance(m, ProxyTorchDispatchMode):
return m
return None
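
A small usage sketch (not from the PR, and assuming make_fx keeps its ProxyTorchDispatchMode on the dispatch-mode stack while the traced function runs). This is the check that comm_tensor.py's rewritten _get_tracer now relies on instead of per-tensor slots:

    import torch
    from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode, make_fx

    def f(x):
        mode = get_innermost_proxy_mode()
        assert mode is not None          # while tracing, the proxy mode is ambient
        assert mode.tracer is not None   # and it exposes the tracer that owns this trace
        return x.cos()

    assert get_innermost_proxy_mode() is None  # eager mode: no ambient proxy mode
    make_fx(f)(torch.randn(3))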


@contextlib.contextmanager
def disable_proxy_modes_tracing():
# TODO: This probably doesn't correctly also disable ProxySymDispatchMode
79 changes: 79 additions & 0 deletions torch/utils/weak.py
@@ -0,0 +1,79 @@
import weakref
from collections.abc import MutableMapping
from typing import Dict


__all__ = ['WeakTensorRefKey', 'WeakTensorKeyDictionary']


# Utility classes for working with weak references to tensors

# torch.Tensors cannot be used as a key in a dictionary
# because they define a custom __eq__ function which when used
# to resolve hash collisions will throw when comparing tensors:
# "RuntimeError: bool value of Tensor with more than one value is ambiguous."
# To avoid that, we use an object which will hold a Tensor and use
# its id for both hashing and equality.
# In order to use this as a weak key reference, we cannot simply use
# weakref.WeakKeyDictionary: the dictionary would be the only thing holding
# the newly constructed WeakTensorRefKey, so nothing would keep a strong
# reference to it and it would be collected immediately.
# To get around this issue, we use it as a normal (strong) key, and register
# `weakref.finalize` to delete the key when its contained tensor dies.


class WeakTensorRefKey(object):
def __init__(self, ten):
self.ten = weakref.ref(ten)
# store id since as soon as ten is deallocated
# the old id will no longer be recoverable, and
# we need to be able to remove the WeakTensorRefKey
# from the dictionary by hashing it to the same
# value it had when ten was alive
self.id = id(self.ten())

def __hash__(self):
return self.id

def __eq__(self, other):
if id(self) == id(other):
return True
return self.id == other.id

class WeakTensorKeyDictionary(MutableMapping):
data: Dict[WeakTensorRefKey, object]

def __init__(self):
self.data = {}

def __contains__(self, k):
return WeakTensorRefKey(k) in self.data

def __len__(self):
return len(self.data)

def __iter__(self):
def generator():
for wk in self.data:
k = wk.ten()
if k is not None:
yield k
return generator()

def __getitem__(self, k):
return self.data[WeakTensorRefKey(k)]

def __setitem__(self, k, v):
wk = WeakTensorRefKey(k)
self_weak_ref = weakref.ref(self)

def del_ten():
self_ref = self_weak_ref()
if self_ref is None:
return
self_ref.data.pop(wk, None)
weakref.finalize(k, del_ten)
self.data[wk] = v

def __delitem__(self, k):
del self.data[WeakTensorRefKey(k)]
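
A minimal usage sketch for the new container (not part of the PR; it assumes CPython refcounting, so the tensor is freed as soon as the last strong reference goes away). Entries are keyed by tensor identity and vanish automatically when the tensor dies, which is what the per-tracer tensor_tracker above needs:

    import torch
    from torch.utils.weak import WeakTensorKeyDictionary

    d = WeakTensorKeyDictionary()
    t = torch.randn(3)
    d[t] = "proxy for t"
    assert t in d and len(d) == 1   # keyed by identity, not by value equality

    del t                           # drop the last strong reference to the tensor
    assert len(d) == 0              # weakref.finalize removed the dead key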
