     skipMeta,
     onlyCPU
 )
+from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state, parametrize, gradcheck
 from torch import nested_tensor
 
@@ -92,12 +93,6 @@ def test_unbind_1(self):
             torch.tensor([1]), torch.tensor([7]),
         )
 
-    # @torch.inference_mode()
-    # def test_unbind_2(self):
-    #     self._test_unbind_case(
-    #         torch.tensor(1), torch.tensor(7),
-    #     )
-
     @torch.inference_mode()
     def test_unbind_3(self):
         self._test_unbind_case(
@@ -302,6 +297,36 @@ def random_nt_pair(self, device, dtype, num_tensors, max_dims):
         return (torch.nested_tensor(ts1, device=device, dtype=dtype),
                 torch.nested_tensor(ts2, device=device, dtype=dtype))
 
+    @dtypes(*floating_types_and_half())
+    @dtypesIfCUDA(torch.float64)
+    def test_detach(self, device, dtype):
+        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
+        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
+        x = torch.nested_tensor([a, b]).requires_grad_()
+
+        x_detach = x.detach()
+
+        z = x_detach * 4
+        self.assertFalse(x_detach.requires_grad)
+        self.assertFalse(z.requires_grad)
+
+        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
+        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
+        x = torch.nested_tensor([a, b])
+
+        y = x * 2
+        y = y.detach()
+        self.assertFalse(y.requires_grad)
+        self.assertIsNone(y.grad_fn)
+
+        z = x + y
+        z.to_padded_tensor(0).sum().backward()
+        # This is an incorrect gradient, but we assume that's what the user
+        # wanted. detach() is an advanced option.
+        self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
+        self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
+
+
     @dtypes(torch.float, torch.float16, torch.double)
     def test_unbind_noncontiguous(self, device, dtype):
         nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
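
For reference, a minimal standalone sketch of the detach() behavior the new test_detach exercises, outside the test harness. It assumes a PyTorch build that provides the torch.nested_tensor constructor used in this diff together with nested-tensor support for detach(); the variable names are illustrative only.

import torch

# Build a nested tensor from two ragged components and mark it as
# requiring grad, as in the first half of test_detach.
a = torch.randn(2, 4)
b = torch.randn(5, 4)
x = torch.nested_tensor([a, b]).requires_grad_()

# detach() produces a tensor that is cut off from the autograd graph:
# it does not require grad, has no grad_fn, and anything computed from
# it stays outside the graph as well.
x_detach = x.detach()
z = x_detach * 4
print(x_detach.requires_grad)  # False
print(z.requires_grad)         # False
print(z.grad_fn)               # None

This mirrors the assertFalse/assertIsNone checks in the test; the second half of the test then confirms that backward() through x alone ignores the detached branch.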