Skip to content

Commit 90edf4f

Browse files
committed
Added support of AFArray indexing
1 parent 983500e commit 90edf4f

File tree

1 file changed

+12
-17
lines changed
  • arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing

1 file changed

+12
-17
lines changed

arrayfire_wrapper/lib/create_and_modify_array/assignment_and_indexing/_indexing.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from typing import Any
66

77
from arrayfire_wrapper.lib._broadcast import bcast_var
8-
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import release_array
9-
8+
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import release_array, retain_array
9+
from arrayfire_wrapper.defines import AFArray
1010

1111
class _IndexSequence(ctypes.Structure):
1212
"""
@@ -186,7 +186,7 @@ class IndexStructure(ctypes.Structure):
186186
-----------
187187
188188
idx: key
189-
- If of type af.Array, self.idx.arr = idx, self.isSeq = False
189+
- If of type AFArray, self.idx.arr = idx, self.isSeq = False
190190
- If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True
191191
- Default:, self.idx.seq = af._IndexSequence(idx)
192192
@@ -197,26 +197,21 @@ class IndexStructure(ctypes.Structure):
197197
198198
"""
199199

200-
def __init__(self, idx: Any) -> None:
200+
def __init__(self, idx: int | slice | AFArray) -> None:
201201
self.idx = _IndexUnion()
202202
self.isBatch = False
203203
self.isSeq = True
204204

205-
# BUG cyclic reimport
206-
# if isinstance(idx, Array):
207-
# if idx.dtype == af_bool:
208-
# self.idx.arr = everything.where(idx.arr)
209-
# else:
210-
# self.idx.arr = everything.retain_array(idx.arr)
211-
212-
# self.isSeq = False
213-
214-
if isinstance(idx, ParallelRange):
205+
if isinstance(idx, int) or isinstance(idx, slice):
206+
self.idx.seq = _IndexSequence(idx)
207+
elif isinstance(idx, ParallelRange):
215208
self.idx.seq = idx
216209
self.isBatch = True
217-
210+
elif isinstance(idx, AFArray):
211+
self.idx.arr = retain_array(idx)
212+
self.isSeq = False
218213
else:
219-
self.idx.seq = _IndexSequence(idx)
214+
raise IndexError("Invalid type while indexing arrayfire.array")
220215

221216
def __del__(self) -> None:
222217
if not self.isSeq:
@@ -247,7 +242,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None:
247242
self.idxs[idx] = value
248243

249244

250-
def get_indices(key: int | slice | tuple[int | slice, ...]) -> CIndexStructure: # BUG
245+
def get_indices(key: int | slice | tuple[int | slice | AFArray, ...] | AFArray) -> CIndexStructure: # BUG
251246
indices = CIndexStructure()
252247

253248
if isinstance(key, tuple):

0 commit comments

Comments
 (0)