     skipMeta,
     onlyCPU
 )
-from torch.testing._internal.common_dtype import floating_types_and_half
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state, parametrize, gradcheck
 from torch import nested_tensor
 
@@ -93,6 +92,12 @@ def test_unbind_1(self):
             torch.tensor([1]), torch.tensor([7]),
         )
 
+    # @torch.inference_mode()
+    # def test_unbind_2(self):
+    #     self._test_unbind_case(
+    #         torch.tensor(1), torch.tensor(7),
+    #     )
+
     @torch.inference_mode()
     def test_unbind_3(self):
         self._test_unbind_case(
@@ -297,36 +302,6 @@ def random_nt_pair(self, device, dtype, num_tensors, max_dims):
         return (torch.nested_tensor(ts1, device=device, dtype=dtype),
                 torch.nested_tensor(ts2, device=device, dtype=dtype))
 
-    @dtypes(*floating_types_and_half())
-    @dtypesIfCUDA(torch.float64)
-    def test_detach(self, device, dtype):
-        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
-        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
-        x = torch.nested_tensor([a, b]).requires_grad_()
-
-        x_detach = x.detach()
-
-        z = x_detach * 4
-        self.assertFalse(x_detach.requires_grad)
-        self.assertFalse(z.requires_grad)
-
-        a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
-        b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
-        x = torch.nested_tensor([a, b])
-
-        y = x * 2
-        y = y.detach()
-        self.assertFalse(y.requires_grad)
-        self.assertIsNone(y.grad_fn)
-
-        z = x + y
-        z.to_padded_tensor(0).sum().backward()
-        # This is an incorrect gradient, but we assume that's what the user
-        # wanted. detach() is an advanced option.
-        self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
-        self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
-
-
     @dtypes(torch.float, torch.float16, torch.double)
     def test_unbind_noncontiguous(self, device, dtype):
         nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)