|
22 | 22 | import numpy as np
|
23 | 23 |
|
24 | 24 | import dpctl
|
25 |
| -from dpctl.memory import MemoryUSMDevice, MemoryUSMHost, MemoryUSMShared |
| 25 | +from dpctl.memory import ( |
| 26 | + MemoryUSMDevice, |
| 27 | + MemoryUSMHost, |
| 28 | + MemoryUSMShared, |
| 29 | + as_usm_memory, |
| 30 | +) |
26 | 31 |
|
27 | 32 | from ._helper import has_cpu, has_gpu, has_sycl_platforms
|
28 | 33 |
|
@@ -249,18 +254,21 @@ def setUp(self):
|
249 | 254 |
|
250 | 255 |
|
251 | 256 | class View:
|
252 |
| - def __init__(self, buf, shape, strides, offset): |
253 |
| - self.buffer = buf |
254 |
| - self.shape = shape |
255 |
| - self.strides = strides |
256 |
| - self.offset = offset |
| 257 | + def __init__(self, buf, shape, strides, offset, syclobj=None): |
| 258 | + self.buffer_ = buf |
| 259 | + self.shape_ = shape |
| 260 | + self.strides_ = strides |
| 261 | + self.offset_ = offset |
| 262 | + self.syclobj_ = syclobj |
257 | 263 |
|
258 | 264 | @property
|
259 | 265 | def __sycl_usm_array_interface__(self):
|
260 |
| - sua_iface = self.buffer.__sycl_usm_array_interface__ |
261 |
| - sua_iface["offset"] = self.offset |
262 |
| - sua_iface["shape"] = self.shape |
263 |
| - sua_iface["strides"] = self.strides |
| 266 | + sua_iface = self.buffer_.__sycl_usm_array_interface__ |
| 267 | + sua_iface["offset"] = self.offset_ |
| 268 | + sua_iface["shape"] = self.shape_ |
| 269 | + sua_iface["strides"] = self.strides_ |
| 270 | + if self.syclobj_: |
| 271 | + sua_iface["syclobj"] = self.syclobj_ |
264 | 272 | return sua_iface
|
265 | 273 |
|
266 | 274 |
|
@@ -330,5 +338,74 @@ def test_suai_non_contig_2D(self):
|
330 | 338 | self.assertTrue(np.array_equal(res, expected_res))
|
331 | 339 |
|
332 | 340 |
|
| 341 | +class TestAsUSMMemory(unittest.TestCase): |
| 342 | + def _with_constructor(self, buffer_cls): |
| 343 | + try: |
| 344 | + buf = buffer_cls(64) |
| 345 | + except Exception: |
| 346 | + self.SkipTest( |
| 347 | + "{} could not be allocated".format(buffer_cls.__name__) |
| 348 | + ) |
| 349 | + # reuse queue from buffer's SUAI |
| 350 | + v = View(buf, shape=(64,), strides=(1,), offset=0) |
| 351 | + m = as_usm_memory(v) |
| 352 | + self.assertTrue(m.get_usm_type() == buf.get_usm_type()) |
| 353 | + self.assertTrue(m._pointer == buf._pointer) |
| 354 | + self.assertTrue(m.sycl_device == buf.sycl_device) |
| 355 | + # Use SyclContext |
| 356 | + v = View( |
| 357 | + buf, shape=(64,), strides=(1,), offset=0, syclobj=buf.sycl_context |
| 358 | + ) |
| 359 | + m = as_usm_memory(v) |
| 360 | + self.assertTrue(m.get_usm_type() == buf.get_usm_type()) |
| 361 | + self.assertTrue(m._pointer == buf._pointer) |
| 362 | + self.assertTrue(m.sycl_device == buf.sycl_device) |
| 363 | + # Use queue capsule |
| 364 | + v = View( |
| 365 | + buf, |
| 366 | + shape=(64,), |
| 367 | + strides=(1,), |
| 368 | + offset=0, |
| 369 | + syclobj=buf._queue._get_capsule(), |
| 370 | + ) |
| 371 | + m = as_usm_memory(v) |
| 372 | + self.assertTrue(m.get_usm_type() == buf.get_usm_type()) |
| 373 | + self.assertTrue(m._pointer == buf._pointer) |
| 374 | + self.assertTrue(m.sycl_device == buf.sycl_device) |
| 375 | + # Use context capsule |
| 376 | + v = View( |
| 377 | + buf, |
| 378 | + shape=(64,), |
| 379 | + strides=(1,), |
| 380 | + offset=0, |
| 381 | + syclobj=buf.sycl_context._get_capsule(), |
| 382 | + ) |
| 383 | + m = as_usm_memory(v) |
| 384 | + self.assertTrue(m.get_usm_type() == buf.get_usm_type()) |
| 385 | + self.assertTrue(m._pointer == buf._pointer) |
| 386 | + self.assertTrue(m.sycl_device == buf.sycl_device) |
| 387 | + # Use filter string |
| 388 | + v = View( |
| 389 | + buf, |
| 390 | + shape=(64,), |
| 391 | + strides=(1,), |
| 392 | + offset=0, |
| 393 | + syclobj=buf.sycl_device.filter_string, |
| 394 | + ) |
| 395 | + m = as_usm_memory(v) |
| 396 | + self.assertTrue(m.get_usm_type() == buf.get_usm_type()) |
| 397 | + self.assertTrue(m._pointer == buf._pointer) |
| 398 | + self.assertTrue(m.sycl_device == buf.sycl_device) |
| 399 | + |
| 400 | + def test_from_device(self): |
| 401 | + self._with_constructor(MemoryUSMDevice) |
| 402 | + |
| 403 | + def test_from_shared(self): |
| 404 | + self._with_constructor(MemoryUSMShared) |
| 405 | + |
| 406 | + def test_from_host(self): |
| 407 | + self._with_constructor(MemoryUSMHost) |
| 408 | + |
| 409 | + |
333 | 410 | if __name__ == "__main__":
|
334 | 411 | unittest.main()
|
0 commit comments