Skip to content

Commit f4f54c7

Browse files
Revert "[Nested Tensor] detach (pytorch#84078)"
This reverts commit 092fe71. Reverted pytorch#84078 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
1 parent 5cf4542 commit f4f54c7

File tree

2 files changed

+6
-32
lines changed

2 files changed

+6
-32
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4667,7 +4667,6 @@
46674667
variants: function, method
46684668
dispatch:
46694669
CompositeExplicitAutograd: detach
4670-
NestedTensorCPU, NestedTensorCUDA: detach
46714670

46724671
# Like `detach()`, but modifies this `Variable` in-place. This method may
46734672
# only be called on non-view `Variable`s. You can use `is_view()` to check

test/test_nestedtensor.py

Lines changed: 6 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
skipMeta,
1111
onlyCPU
1212
)
13-
from torch.testing._internal.common_dtype import floating_types_and_half
1413
from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state, parametrize, gradcheck
1514
from torch import nested_tensor
1615

@@ -93,6 +92,12 @@ def test_unbind_1(self):
9392
torch.tensor([1]), torch.tensor([7]),
9493
)
9594

95+
# @torch.inference_mode()
96+
# def test_unbind_2(self):
97+
# self._test_unbind_case(
98+
# torch.tensor(1), torch.tensor(7),
99+
# )
100+
96101
@torch.inference_mode()
97102
def test_unbind_3(self):
98103
self._test_unbind_case(
@@ -297,36 +302,6 @@ def random_nt_pair(self, device, dtype, num_tensors, max_dims):
297302
return (torch.nested_tensor(ts1, device=device, dtype=dtype),
298303
torch.nested_tensor(ts2, device=device, dtype=dtype))
299304

300-
@dtypes(*floating_types_and_half())
301-
@dtypesIfCUDA(torch.float64)
302-
def test_detach(self, device, dtype):
303-
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
304-
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
305-
x = torch.nested_tensor([a, b]).requires_grad_()
306-
307-
x_detach = x.detach()
308-
309-
z = x_detach * 4
310-
self.assertFalse(x_detach.requires_grad)
311-
self.assertFalse(z.requires_grad)
312-
313-
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
314-
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
315-
x = torch.nested_tensor([a, b])
316-
317-
y = x * 2
318-
y = y.detach()
319-
self.assertFalse(y.requires_grad)
320-
self.assertIsNone(y.grad_fn)
321-
322-
z = x + y
323-
z.to_padded_tensor(0).sum().backward()
324-
# This is an incorrect gradient, but we assume that's what the user
325-
# wanted. detach() is an advanced option.
326-
self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
327-
self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
328-
329-
330305
@dtypes(torch.float, torch.float16, torch.double)
331306
def test_unbind_noncontiguous(self, device, dtype):
332307
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)

0 commit comments

Comments (0)