Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit 4b212e1

Browse files
authored
FakeTensor should be wrapped as TensorVariable (#931)
1 parent aa17b78 commit 4b212e1

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

test/test_misc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,21 @@ def fn(x, m):
20602060
res = opt_fn(x, m)
20612061
self.assertEqual(ref, res)
20622062

2063+
def test_tensor_dot_grad_no_graph_break(self):
2064+
def fn(a, b):
2065+
y = 3 * a**3 - b**2
2066+
y.backward(gradient=torch.tensor([1.0, 1.0]))
2067+
b.grad.zero_()
2068+
return a.grad, b.grad
2069+
2070+
a = torch.tensor([2.0, 3.0], requires_grad=True)
2071+
b = torch.tensor([6.0, 4.0], requires_grad=True)
2072+
cnts = torchdynamo.testing.CompileCounter()
2073+
with torchdynamo.optimize(cnts):
2074+
_, b_grad = fn(a, b)
2075+
self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0])))
2076+
self.assertEqual(cnts.frame_count, 2)
2077+
20632078
def test_change_backends(self):
20642079
@torchdynamo.optimize("eager", nopython=True)
20652080
def fn1():

torchdynamo/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,14 @@ def is_numpy_float_type(value):
188188

189189
def istensor(obj):
190190
"""Check of obj is a tensor"""
191-
return istype(
192-
obj, (torch.Tensor, torch.nn.Parameter, *config.traceable_tensor_subclasses)
191+
tensor_list = (
192+
torch.Tensor,
193+
torch.nn.Parameter,
194+
*config.traceable_tensor_subclasses,
193195
)
196+
if fake_tensors_available:
197+
tensor_list = tensor_list + (torch._subclasses.FakeTensor,)
198+
return istype(obj, tensor_list)
194199

195200

196201
def is_lazy_module(mod):

0 commit comments

Comments
 (0)