Skip to content

[WIP] [Data API] Backend refactoring #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
75c1d43
Refactoring of basic functionality to create an empty Array
Jan 24, 2023
b14aa91
Replace dim4 with CShape
roaffix Jan 24, 2023
eadbe9b
Add tests. Minor fixes. Update CI
roaffix Jan 24, 2023
c13a59f
Fix CI
roaffix Jan 24, 2023
f0f57e8
Add arithmetic operators w/o tests
roaffix Jan 26, 2023
8cef774
Fix array init bug. Add __getitem__. Change pytest for active debug mode
roaffix Jan 27, 2023
a4c7ac9
Add reflected arithmetic and array operators
roaffix Jan 27, 2023
4140527
Place TODO for repr
roaffix Jan 28, 2023
4374d93
Add bitwise operators. Add in-place operators. Add missing reflected …
roaffix Jan 28, 2023
5a29ffa
Fix tests
roaffix Jan 28, 2023
4187b27
Add tests for arithmetic operators
roaffix Jan 28, 2023
cdb7a92
Added to_list and to_ctypes_array
roaffix Jan 28, 2023
9c0435a
Fix bug when scalar is empty returns None
roaffix Jan 28, 2023
769c16c
Fix typing in array object. Add tests
roaffix Jan 29, 2023
fb27e46
Change tests and found bug with reflected operators
roaffix Jan 29, 2023
0afb92e
Fix reflected operators bug. Add test coverage for the rest of the ar…
roaffix Jan 29, 2023
1d071be
Add required by specification methods
roaffix Jan 30, 2023
04fbb1b
Change file names
roaffix Jan 30, 2023
2d91b04
Change utils. Add docstrings
roaffix Jan 30, 2023
5939388
Add docstrings for operators
roaffix Jan 30, 2023
0231e27
Change TODOs
roaffix Jan 30, 2023
07c4206
Add docstrings for other operators. Remove docstrings from mocks
roaffix Jan 30, 2023
908447b
Change tags and typings
roaffix Feb 4, 2023
fa3ad06
Change typings from python 3.10 to python 3.8
roaffix Feb 4, 2023
0de9955
Add readme with reference to run tests
roaffix Feb 4, 2023
ae6be05
Revert changes accidentally made in original array
roaffix Feb 5, 2023
cfa9114
Add initial refactoring with backend mock
roaffix Feb 8, 2023
5de8694
Add c library methods for operators
roaffix Feb 9, 2023
b9ac1c5
Remove dependency on default backend
roaffix Feb 10, 2023
171ec88
Refactor backend and project structure
roaffix Feb 10, 2023
e984caa
Refactor backend library operators
roaffix Feb 11, 2023
0b164d4
Refactor used in array_object backend methods
roaffix Feb 11, 2023
282f860
Minor test fix
roaffix Feb 11, 2023
54f7ada
Refactor tests
roaffix Feb 13, 2023
51f6efd
Add comparison operators tests
roaffix Feb 13, 2023
23e2635
Minor fixes for tests
roaffix Feb 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix reflected operators bug. Add test coverage for the rest of the ar…
…ithmetic operators
  • Loading branch information
roaffix committed Jan 29, 2023
commit 0afb92eeb047e0b60cbe1eeeb05cf8182f8084fe
58 changes: 35 additions & 23 deletions arrayfire/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ._dtypes import int64 as af_int64
from ._dtypes import supported_dtypes
from ._dtypes import uint64 as af_uint64
from ._utils import PointerSource, is_number, to_str
from ._utils import PointerSource, to_str

ShapeType = tuple[int, ...]
_bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
Expand Down Expand Up @@ -286,25 +286,25 @@ def __radd__(self, other: Array, /) -> Array:
"""
Return other + self.
"""
return _process_c_function(self, other, backend.get().af_add)
return _process_c_function(other, self, backend.get().af_add)

def __rsub__(self, other: Array, /) -> Array:
"""
Return other - self.
"""
return _process_c_function(self, other, backend.get().af_sub)
return _process_c_function(other, self, backend.get().af_sub)

def __rmul__(self, other: Array, /) -> Array:
"""
Return other * self.
"""
return _process_c_function(self, other, backend.get().af_mul)
return _process_c_function(other, self, backend.get().af_mul)

def __rtruediv__(self, other: Array, /) -> Array:
"""
Return other / self.
"""
return _process_c_function(self, other, backend.get().af_div)
return _process_c_function(other, self, backend.get().af_div)

def __rfloordiv__(self, other: Array, /) -> Array:
# TODO
Expand All @@ -314,13 +314,13 @@ def __rmod__(self, other: Array, /) -> Array:
"""
Return other / self.
"""
return _process_c_function(self, other, backend.get().af_mod)
return _process_c_function(other, self, backend.get().af_mod)

def __rpow__(self, other: Array, /) -> Array:
"""
Return other ** self.
"""
return _process_c_function(self, other, backend.get().af_pow)
return _process_c_function(other, self, backend.get().af_pow)

# Reflected Array Operators

Expand All @@ -334,31 +334,31 @@ def __rand__(self, other: Array, /) -> Array:
"""
Return other & self.
"""
return _process_c_function(self, other, backend.get().af_bitand)
return _process_c_function(other, self, backend.get().af_bitand)

def __ror__(self, other: Array, /) -> Array:
"""
Return other & self.
"""
return _process_c_function(self, other, backend.get().af_bitor)
return _process_c_function(other, self, backend.get().af_bitor)

def __rxor__(self, other: Array, /) -> Array:
"""
Return other ^ self.
"""
return _process_c_function(self, other, backend.get().af_bitxor)
return _process_c_function(other, self, backend.get().af_bitxor)

def __rlshift__(self, other: Array, /) -> Array:
"""
Return other << self.
"""
return _process_c_function(self, other, backend.get().af_bitshiftl)
return _process_c_function(other, self, backend.get().af_bitshiftl)

def __rrshift__(self, other: Array, /) -> Array:
"""
Return other >> self.
"""
return _process_c_function(self, other, backend.get().af_bitshiftr)
return _process_c_function(other, self, backend.get().af_bitshiftr)

# In-place Arithmetic Operators

Expand Down Expand Up @@ -614,20 +614,32 @@ def _str_to_dtype(value: int) -> Dtype:


def _process_c_function(
target: Array, other: int | float | bool | complex | Array, c_function: Any) -> Array:
lhs: int | float | bool | complex | Array, rhs: int | float | bool | complex | Array,
c_function: Any) -> Array:
out = Array()

# TODO discuss the difference between binary_func and binary_funcr
# because implementation looks like exectly the same.
# consider chaging to __iadd__ = __radd__ = __add__ interfce if no difference
if isinstance(other, Array):
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
elif is_number(other):
other_dtype = _implicit_dtype(other, target.dtype)
other_array = _constant_array(other, CShape(*target.shape), other_dtype)
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
if isinstance(lhs, Array) and isinstance(rhs, Array):
lhs_array = lhs.arr
rhs_array = rhs.arr

elif isinstance(lhs, Array) and isinstance(rhs, int | float | bool | complex):
rhs_dtype = _implicit_dtype(rhs, lhs.dtype)
rhs_constant_array = _constant_array(rhs, CShape(*lhs.shape), rhs_dtype)

lhs_array = lhs.arr
rhs_array = rhs_constant_array.arr

elif isinstance(lhs, int | float | bool | complex) and isinstance(rhs, Array):
lhs_dtype = _implicit_dtype(lhs, rhs.dtype)
lhs_constant_array = _constant_array(lhs, CShape(*rhs.shape), lhs_dtype)

lhs_array = lhs_constant_array.arr
rhs_array = rhs.arr

else:
raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
raise TypeError(f"{type(rhs)} is not supported and can not be passed to C binary function.")

safe_call(c_function(ctypes.pointer(out.arr), lhs_array, rhs_array, _bcast_var))

return out

Expand Down
5 changes: 0 additions & 5 deletions arrayfire/array_api/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ctypes
import enum
import numbers


class PointerSource(enum.Enum):
Expand All @@ -14,7 +13,3 @@ class PointerSource(enum.Enum):

def to_str(c_str: ctypes.c_char_p) -> str:
return str(c_str.value.decode("utf-8")) # type: ignore[union-attr]


def is_number(number: int | float | bool | complex) -> bool:
return isinstance(number, numbers.Number)
17 changes: 9 additions & 8 deletions arrayfire/array_api/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,6 @@ def setup_method(self, method: Any) -> None:
self.tuple = (1, 2, 3)
self.const_str = "15"

def teardown_method(self, method: Any) -> None:
self.array = Array(self.list)

def test_add_int(self) -> None:
res = self.array + self.const_int
assert res[0].scalar() == 3
Expand Down Expand Up @@ -220,10 +217,10 @@ def test_add_inplace_and_reflected(self) -> None:

def test_add_raises_type_error(self) -> None:
with pytest.raises(TypeError):
Array([1, 2, 3]) + self.const_str # type: ignore[operator]
self.array + self.const_str # type: ignore[operator]

with pytest.raises(TypeError):
Array([1, 2, 3]) + self.tuple # type: ignore[operator]
self.array + self.tuple # type: ignore[operator]

# Test __sub__, __isub__, __rsub__

Expand Down Expand Up @@ -251,9 +248,13 @@ def test_sub_inplace_and_reflected(self) -> None:
ires -= self.const_int
rres = self.const_int - self.array # type: ignore[operator]

assert res[0].scalar() == ires[0].scalar() == rres[0].scalar() == -1
assert res[1].scalar() == ires[1].scalar() == rres[1].scalar() == 0
assert res[2].scalar() == ires[2].scalar() == rres[2].scalar() == 1
assert res[0].scalar() == ires[0].scalar() == -1
assert res[1].scalar() == ires[1].scalar() == 0
assert res[2].scalar() == ires[2].scalar() == 1

assert rres[0].scalar() == 1
assert rres[1].scalar() == 0
assert rres[2].scalar() == -1

assert res.dtype == ires.dtype == rres.dtype
assert res.ndim == ires.ndim == rres.ndim
Expand Down