5
5
from typing import Any
6
6
7
7
from arrayfire_wrapper .lib ._broadcast import bcast_var
8
- from arrayfire_wrapper .lib .create_and_modify_array .manage_array import release_array
9
-
8
+ from arrayfire_wrapper .lib .create_and_modify_array .manage_array import release_array , retain_array
9
+ from arrayfire_wrapper . defines import AFArray
10
10
11
11
class _IndexSequence (ctypes .Structure ):
12
12
"""
@@ -186,7 +186,7 @@ class IndexStructure(ctypes.Structure):
186
186
-----------
187
187
188
188
idx: key
189
- - If of type af.Array , self.idx.arr = idx, self.isSeq = False
189
+ - If of type AFArray , self.idx.arr = idx, self.isSeq = False
190
190
- If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True
191
191
- Default:, self.idx.seq = af._IndexSequence(idx)
192
192
@@ -197,26 +197,21 @@ class IndexStructure(ctypes.Structure):
197
197
198
198
"""
199
199
200
- def __init__ (self , idx : Any ) -> None :
200
+ def __init__ (self , idx : int | slice | AFArray ) -> None :
201
201
self .idx = _IndexUnion ()
202
202
self .isBatch = False
203
203
self .isSeq = True
204
204
205
- # BUG cyclic reimport
206
- # if isinstance(idx, Array):
207
- # if idx.dtype == af_bool:
208
- # self.idx.arr = everything.where(idx.arr)
209
- # else:
210
- # self.idx.arr = everything.retain_array(idx.arr)
211
-
212
- # self.isSeq = False
213
-
214
- if isinstance (idx , ParallelRange ):
205
+ if isinstance (idx , int ) or isinstance (idx , slice ):
206
+ self .idx .seq = _IndexSequence (idx )
207
+ elif isinstance (idx , ParallelRange ):
215
208
self .idx .seq = idx
216
209
self .isBatch = True
217
-
210
+ elif isinstance (idx , AFArray ):
211
+ self .idx .arr = retain_array (idx )
212
+ self .isSeq = False
218
213
else :
219
- self . idx . seq = _IndexSequence ( idx )
214
+ raise IndexError ( "Invalid type while indexing arrayfire.array" )
220
215
221
216
def __del__ (self ) -> None :
222
217
if not self .isSeq :
@@ -247,7 +242,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None:
247
242
self .idxs [idx ] = value
248
243
249
244
250
- def get_indices (key : int | slice | tuple [int | slice , ...]) -> CIndexStructure : # BUG
245
+ def get_indices (key : int | slice | tuple [int | slice | AFArray , ...] | AFArray ) -> CIndexStructure : # BUG
251
246
indices = CIndexStructure ()
252
247
253
248
if isinstance (key , tuple ):
0 commit comments