|
18 | 18 |
|
19 | 19 | import os |
20 | 20 |
|
| 21 | +import numpy as np |
21 | 22 | import pytest |
22 | 23 |
|
23 | 24 | import dpctl |
| 25 | +import dpctl.memory as dpm |
24 | 26 | import dpctl.program as dpctl_prog |
25 | 27 |
|
26 | 28 |
|
def _get_opencl_queue_or_skip():
    """Return a queue created with the "opencl" filter, or skip the test.

    The calling test is skipped via ``pytest.skip`` when queue construction
    raises ``dpctl.SyclQueueCreationError``.
    """
    try:
        opencl_queue = dpctl.SyclQueue("opencl")
    except dpctl.SyclQueueCreationError:
        # pytest.skip raises, so the return below is reached only on success.
        pytest.skip("No OpenCL queue is available")
    return opencl_queue
| 34 | + |
| 35 | + |
def _skip_if_no_sycl_source_compilation(q):
    """Skip the calling test unless SYCL source can be compiled for ``q``.

    Two conditions are checked in order: the module-level availability query
    in ``dpctl.program``, then whether the device behind ``q`` reports that
    it can compile the "sycl" source language.
    """
    extension_available = dpctl.program.is_sycl_source_compilation_available()
    if not extension_available:
        pytest.skip("SYCL source compilation extension not available")

    device = q.get_sycl_device()
    if device.can_compile("sycl"):
        return
    pytest.skip("SYCL source compilation not supported")
| 41 | + |
| 42 | + |
27 | 43 | def get_spirv_abspath(fn): |
28 | 44 | curr_dir = os.path.dirname(os.path.abspath(__file__)) |
29 | 45 | spirv_file = os.path.join(curr_dir, "input_files", fn) |
@@ -266,13 +282,8 @@ def test_create_program_from_invalid_src_ocl(): |
266 | 282 |
|
267 | 283 |
|
268 | 284 | def test_create_program_from_sycl_source(): |
269 | | - try: |
270 | | - q = dpctl.SyclQueue("opencl") |
271 | | - except dpctl.SyclQueueCreationError: |
272 | | - pytest.skip("No OpenCL queue is available") |
273 | | - |
274 | | - if not q.get_sycl_device().can_compile("sycl"): |
275 | | - pytest.skip("SYCL source compilation not supported") |
| 285 | + q = _get_opencl_queue_or_skip() |
| 286 | + _skip_if_no_sycl_source_compilation(q) |
276 | 287 |
|
277 | 288 | sycl_source = """ |
278 | 289 | #include <sycl/sycl.hpp> |
@@ -376,13 +387,8 @@ def test_create_program_from_sycl_source(): |
376 | 387 |
|
377 | 388 |
|
378 | 389 | def test_create_program_from_invalid_src_sycl(): |
379 | | - try: |
380 | | - q = dpctl.SyclQueue("opencl") |
381 | | - except dpctl.SyclQueueCreationError: |
382 | | - pytest.skip("No OpenCL queue is available") |
383 | | - |
384 | | - if not q.get_sycl_device().can_compile("sycl"): |
385 | | - pytest.skip("SYCL source compilation not supported") |
| 390 | + q = _get_opencl_queue_or_skip() |
| 391 | + _skip_if_no_sycl_source_compilation(q) |
386 | 392 |
|
387 | 393 | sycl_source = """ |
388 | 394 | #include <sycl/sycl.hpp> |
@@ -410,3 +416,75 @@ def test_create_program_from_invalid_src_sycl(): |
410 | 416 | except dpctl_prog.SyclProgramCompilationError as prog_error: |
411 | 417 | print(str(prog_error)) |
412 | 418 | assert "error: expected ';' at end of declaration" in str(prog_error) |
| 419 | + |
| 420 | + |
def test_sycl_source_compilation_is_available_returns_bool():
    """The availability query must return an object whose type is ``bool``."""
    result_type = type(dpctl.program.is_sycl_source_compilation_available())
    assert result_type is bool
| 424 | + |
| 425 | + |
def test_sycl_source_vector_add_correctness():
    """Build ``vector_add`` from SYCL source and check its output.

    The kernel source includes an in-memory header (``math_ops.hpp``) that
    supplies ``math_op``. Two int32 input vectors are copied to USM device
    allocations, the kernel is submitted over them, and the result copied
    back to the host is compared element-wise with the NumPy sum.

    The test is skipped when no OpenCL queue can be created, when SYCL
    source compilation is unavailable, or when kernel submission raises
    ``SyclKernelSubmitError``.
    """
    q = _get_opencl_queue_or_skip()
    _skip_if_no_sycl_source_compilation(q)

    sycl_source = """
    #include <sycl/sycl.hpp>
    #include "math_ops.hpp"

    namespace syclext = sycl::ext::oneapi::experimental;

    extern "C" SYCL_EXTERNAL
    SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclext::nd_range_kernel<1>))
    void vector_add(int* in1, int* in2, int* out){
        sycl::nd_item<1> item =
            sycl::ext::oneapi::this_work_item::get_nd_item<1>();
        size_t globalID = item.get_global_linear_id();
        out[globalID] = math_op(in1[globalID], in2[globalID]);
    }
    """

    header_content = """
    int math_op(int a, int b){
        return a + b;
    }
    """

    prog = dpctl.program.create_program_from_sycl_source(
        q,
        sycl_source,
        headers=[("math_ops.hpp", header_content)],
        registered_names=[],
        copts=["-fno-fast-math"],
    )

    kernel = prog.get_sycl_kernel("vector_add")

    local_size = 16
    global_size = local_size * 8

    in1 = np.arange(global_size, dtype=np.int32)
    in2 = (np.arange(global_size, dtype=np.int32) * 3 - 7).astype(np.int32)
    out = np.empty(global_size, dtype=np.int32)
    expected = (in1 + in2).astype(np.int32)

    in1_usm = dpm.MemoryUSMDevice(in1.nbytes, queue=q)
    in2_usm = dpm.MemoryUSMDevice(in2.nbytes, queue=q)
    out_usm = dpm.MemoryUSMDevice(out.nbytes, queue=q)

    # Host-to-device copies are asynchronous; the kernel depends on both.
    ev1 = q.memcpy_async(dest=in1_usm, src=in1, count=in1.nbytes)
    ev2 = q.memcpy_async(dest=in2_usm, src=in2, count=in2.nbytes)

    try:
        ev3 = q.submit(
            kernel,
            [in1_usm, in2_usm, out_usm],
            [global_size],
            [local_size],
            dEvents=[ev1, ev2],
        )
    except dpctl._sycl_queue.SyclKernelSubmitError:
        # Let the pending input copies finish before the skip unwinds this
        # frame and releases the host arrays and USM allocations they use.
        ev1.wait()
        ev2.wait()
        pytest.skip(f"Kernel submission to {q.sycl_device} failed")

    # Device-to-host copy is ordered after the kernel via its event.
    ev4 = q.memcpy_async(
        dest=out, src=out_usm, count=out.nbytes, dEvents=[ev3]
    )
    ev4.wait()
    assert np.array_equal(out, expected)
0 commit comments