From dc20b2d9e01baab943d28bfe10092e3df0710ef4 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 30 Oct 2024 13:46:50 -0300 Subject: [PATCH] DLPack: add test using PyTorch DLPack functions. (#8294) Co-authored-by: iefgnoix --- .torch_pin | 1 + test/test_operations.py | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 .torch_pin diff --git a/.torch_pin b/.torch_pin new file mode 100644 index 00000000000..9eb602820f8 --- /dev/null +++ b/.torch_pin @@ -0,0 +1 @@ +#138470 diff --git a/test/test_operations.py b/test/test_operations.py index 1af928e6a47..545831acf49 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2912,6 +2912,17 @@ def test_dlpack_xla_to_pytorch_cuda(self): cuda_t1[0] = cuda_t1[0] + 20 self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu())) + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self): + xla_t1 = torch.arange(5).to(xm.xla_device()) + caps_t1 = torch.utils.dlpack.to_dlpack(xla_t1) + cuda_t1 = torch.utils.dlpack.from_dlpack(caps_t1) + self.assertEqual(cuda_t1.device.type, 'cuda') + self.assertEqual(cuda_t1.device.index, xla_t1.device.index) + cuda_t1[0] = cuda_t1[0] + 20 + self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu())) + @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_non_default_layout(self):