diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index eff49a5a0d..d700e7a96d 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -260,5 +260,3 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) - - reset_tmp_dir() diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index e5ef955930..5633fba0cc 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -2,7 +2,6 @@ import os import shutil -import pytest import torch import triton @@ -38,10 +37,11 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: - pytest.skip("FIXME: Port get_device_capability to XPU") + cc = 0 + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor - major, minor = torch.cuda.get_device_capability(0) - cc = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') @@ -65,11 +65,11 @@ def kernel_dot(Z): def test_compile_in_forked_subproc() -> None: - pytest.skip("FIXME: Port get_device_capability to XPU") + capability = 0 + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability(0) + capability = major * 10 + minor - reset_tmp_dir() - major, minor = torch.cuda.get_device_capability(0) - capability = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' @@ -77,3 +77,5 @@ def test_compile_in_forked_subproc() -> None: proc.start() proc.join() assert proc.exitcode == 0 + + reset_tmp_dir()