Skip to content

BUG: add missing _check_type_device calls #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ def __imod__(self, other: Array | float, /) -> Array:
"""
Performs the operation __imod__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
if other is NotImplemented:
return other
Expand All @@ -1126,6 +1127,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __imul__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__imul__")
if other is NotImplemented:
return other
Expand All @@ -1148,6 +1150,7 @@ def __ior__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ior__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
if other is NotImplemented:
return other
Expand All @@ -1170,6 +1173,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __ipow__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
Expand All @@ -1182,6 +1186,7 @@ def __rpow__(self, other: Array | complex, /) -> Array:
"""
from ._elementwise_functions import pow # type: ignore[attr-defined]

self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
Expand All @@ -1193,6 +1198,7 @@ def __irshift__(self, other: Array | int, /) -> Array:
"""
Performs the operation __irshift__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer", "__irshift__")
if other is NotImplemented:
return other
Expand All @@ -1215,6 +1221,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __isub__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__isub__")
if other is NotImplemented:
return other
Expand All @@ -1237,6 +1244,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __itruediv__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
if other is NotImplemented:
return other
Expand All @@ -1259,6 +1267,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ixor__.
"""
self._check_type_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
if other is NotImplemented:
return other
Expand Down
7 changes: 7 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ def _array_vals():
getattr(x, _op)(y)
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y))
# finally, test that array op ndarray raises
# XXX: as long as there is __array__ or __buffer__, __rop__s
# still return ndarrays
if not _op.startswith("__r"):
with assert_raises(TypeError):
getattr(x, _op)(y._array)


for op, dtypes in unary_op_dtypes.items():
for a in _array_vals():
Expand Down
Loading