Skip to content

Commit db8fe2f

Browse files
Merge pull request #1048 from IntelPython/make-dpctl-tensor-Device-hashable
Make dpctl tensor device hashable, added equality comparison
2 parents b17589e + 3dda4f4 commit db8fe2f

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

dpctl/tensor/_device.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,18 @@ def wait(self):
124124
"""
125125
self.sycl_queue_.wait()
126126

127+
def __eq__(self, other):
128+
"""Equality comparison based on underlying ``sycl_queue``."""
129+
if isinstance(other, Device):
130+
return self.sycl_queue.__eq__(other.sycl_queue)
131+
elif isinstance(other, dpctl.SyclQueue):
132+
return self.sycl_queue.__eq__(other)
133+
return False
134+
135+
def __hash__(self):
136+
"""Compute object's hash value."""
137+
return self.sycl_queue.__hash__()
138+
127139

128140
def normalize_queue_device(sycl_queue=None, device=None):
129141
"""

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,3 +1575,19 @@ def test_asarray_uint64():
15751575
Xnp = np.ndarray(1, dtype=np.uint64)
15761576
X = dpt.asarray(Xnp)
15771577
assert X.dtype == Xnp.dtype
1578+
1579+
1580+
def test_Device():
1581+
try:
1582+
dev = dpctl.select_default_device()
1583+
d1 = dpt.Device.create_device(dev)
1584+
d2 = dpt.Device.create_device(dev)
1585+
except (dpctl.SyclQueueCreationError, dpctl.SyclDeviceCreationError):
1586+
pytest.skip(
1587+
"Could not create default device, or a queue that targets it"
1588+
)
1589+
assert d1 == d2
1590+
dict = {d1: 1}
1591+
assert dict[d2] == 1
1592+
assert d1 == d2.sycl_queue
1593+
assert not d1 == Ellipsis

0 commit comments

Comments
 (0)