Skip to content

Commit 21f1cca

Browse files
Skip prod tests for complex output types on Gen9
1 parent 662bc45 commit 21f1cca

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,13 +245,18 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
245245

246246
out_dtype = dpt.dtype(out_dtype)
247247
arg_dtype = dpt.dtype(arg_dtype)
248-
if dpt.isdtype(out_dtype, "complex floating") and du._is_gen9(
249-
q.sycl_device
250-
):
251-
pytest.skip(
252-
"Product reduction for complex output are known "
253-
"to fail for Gen9 with 2024.0 compiler"
248+
any_complex = any(
249+
dpt.isdtype(dt, "complex floating") for dt in (arg_dtype, out_dtype)
250+
)
251+
if any_complex:
252+
device_mask = (
253+
du.intel_device_info(q.sycl_device).get("device_id", 0) & 0xFF00
254254
)
255+
if device_mask in [0x3E00, 0x9B00]:
256+
pytest.skip(
257+
"Product reduction for complex output are known "
258+
"to fail for Gen9 with 2024.0 compiler"
259+
)
255260

256261
m = dpt.ones(100, dtype=arg_dtype)
257262
r = dpt.prod(m, dtype=out_dtype)

0 commit comments

Comments
 (0)