Skip to content

Commit 092fe71

Browse files
drisspgpytorchmergebot
authored and committed
[Nested Tensor] detach (pytorch#84078)
## Summary Add a detach op for nested tensors. Nested tensors are not part of the composite explicit dispatch key set and therefore need to be added manually. The Detach test is failing only for dtype=torch.float32, torch.float16 and device=cuda. The chain of ops that are called is sum.backward() -> from_padded() -> unbind(). This populates the grad for a and b. Does this potentially indicate that the CUDA implementation for one of these ops, likely from_padded(), is incorrect? Pull Request resolved: pytorch#84078 Approved by: https://github.com/albanD
1 parent 43620b7 commit 092fe71

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4667,6 +4667,7 @@
46674667
variants: function, method
46684668
dispatch:
46694669
CompositeExplicitAutograd: detach
4670+
NestedTensorCPU, NestedTensorCUDA: detach
46704671

46714672
# Like `detach()`, but modifies this `Variable` in-place. This method may
46724673
# only be called on non-view `Variable`s. You can use `is_view()` to check

test/test_nestedtensor.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
skipMeta,
1111
onlyCPU
1212
)
13+
from torch.testing._internal.common_dtype import floating_types_and_half
1314
from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state, parametrize, gradcheck
1415
from torch import nested_tensor
1516

@@ -92,12 +93,6 @@ def test_unbind_1(self):
9293
torch.tensor([1]), torch.tensor([7]),
9394
)
9495

95-
# @torch.inference_mode()
96-
# def test_unbind_2(self):
97-
# self._test_unbind_case(
98-
# torch.tensor(1), torch.tensor(7),
99-
# )
100-
10196
@torch.inference_mode()
10297
def test_unbind_3(self):
10398
self._test_unbind_case(
@@ -302,6 +297,36 @@ def random_nt_pair(self, device, dtype, num_tensors, max_dims):
302297
return (torch.nested_tensor(ts1, device=device, dtype=dtype),
303298
torch.nested_tensor(ts2, device=device, dtype=dtype))
304299

300+
@dtypes(*floating_types_and_half())
301+
@dtypesIfCUDA(torch.float64)
302+
def test_detach(self, device, dtype):
303+
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
304+
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
305+
x = torch.nested_tensor([a, b]).requires_grad_()
306+
307+
x_detach = x.detach()
308+
309+
z = x_detach * 4
310+
self.assertFalse(x_detach.requires_grad)
311+
self.assertFalse(z.requires_grad)
312+
313+
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
314+
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
315+
x = torch.nested_tensor([a, b])
316+
317+
y = x * 2
318+
y = y.detach()
319+
self.assertFalse(y.requires_grad)
320+
self.assertIsNone(y.grad_fn)
321+
322+
z = x + y
323+
z.to_padded_tensor(0).sum().backward()
324+
# This is an incorrect gradient, but we assume that's what the user
325+
# wanted. detach() is an advanced option.
326+
self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
327+
self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
328+
329+
305330
@dtypes(torch.float, torch.float16, torch.double)
306331
def test_unbind_noncontiguous(self, device, dtype):
307332
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)

0 commit comments

Comments
 (0)