@@ -325,6 +325,21 @@ def as_integer_slice(value):
325325 return slice (start , stop , step )
326326
327327
328+ class IndexCallable :
329+ """Provide getitem syntax for a callable object."""
330+
331+ __slots__ = ("func" ,)
332+
333+ def __init__ (self , func ):
334+ self .func = func
335+
336+ def __getitem__ (self , key ):
337+ return self .func (key )
338+
339+ def __setitem__ (self , key , value ):
340+ raise NotImplementedError
341+
342+
328343class BasicIndexer (ExplicitIndexer ):
329344 """Tuple for basic indexing.
330345
@@ -470,6 +485,13 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
470485 # Note this is the base class for all lazy indexing classes
471486 return np .asarray (self .get_duck_array (), dtype = dtype )
472487
488+ def _oindex_get (self , key ):
489+ raise NotImplementedError ("This method should be overridden" )
490+
491+ @property
492+ def oindex (self ):
493+ return IndexCallable (self ._oindex_get )
494+
473495
474496class ImplicitToExplicitIndexingAdapter (NDArrayMixin ):
475497 """Wrap an array, converting tuples into the indicated explicit indexer."""
@@ -560,6 +582,9 @@ def get_duck_array(self):
560582 def transpose (self , order ):
561583 return LazilyVectorizedIndexedArray (self .array , self .key ).transpose (order )
562584
585+ def _oindex_get (self , indexer ):
586+ return type (self )(self .array , self ._updated_key (indexer ))
587+
563588 def __getitem__ (self , indexer ):
564589 if isinstance (indexer , VectorizedIndexer ):
565590 array = LazilyVectorizedIndexedArray (self .array , self .key )
@@ -663,6 +688,9 @@ def _ensure_copied(self):
663688 def get_duck_array (self ):
664689 return self .array .get_duck_array ()
665690
691+ def _oindex_get (self , key ):
692+ return type (self )(_wrap_numpy_scalars (self .array [key ]))
693+
666694 def __getitem__ (self , key ):
667695 return type (self )(_wrap_numpy_scalars (self .array [key ]))
668696
@@ -696,6 +724,9 @@ def get_duck_array(self):
696724 self ._ensure_cached ()
697725 return self .array .get_duck_array ()
698726
727+ def _oindex_get (self , key ):
728+ return type (self )(_wrap_numpy_scalars (self .array [key ]))
729+
699730 def __getitem__ (self , key ):
700731 return type (self )(_wrap_numpy_scalars (self .array [key ]))
701732
@@ -1332,6 +1363,10 @@ def _indexing_array_and_key(self, key):
13321363 def transpose (self , order ):
13331364 return self .array .transpose (order )
13341365
1366+ def _oindex_get (self , key ):
1367+ array , key = self ._indexing_array_and_key (key )
1368+ return array [key ]
1369+
13351370 def __getitem__ (self , key ):
13361371 array , key = self ._indexing_array_and_key (key )
13371372 return array [key ]
@@ -1376,16 +1411,19 @@ def __init__(self, array):
13761411 )
13771412 self .array = array
13781413
1414+ def _oindex_get (self , key ):
1415+ # manual orthogonal indexing (implemented like DaskIndexingAdapter)
1416+ key = key .tuple
1417+ value = self .array
1418+ for axis , subkey in reversed (list (enumerate (key ))):
1419+ value = value [(slice (None ),) * axis + (subkey , Ellipsis )]
1420+ return value
1421+
13791422 def __getitem__ (self , key ):
13801423 if isinstance (key , BasicIndexer ):
13811424 return self .array [key .tuple ]
13821425 elif isinstance (key , OuterIndexer ):
1383- # manual orthogonal indexing (implemented like DaskIndexingAdapter)
1384- key = key .tuple
1385- value = self .array
1386- for axis , subkey in reversed (list (enumerate (key ))):
1387- value = value [(slice (None ),) * axis + (subkey , Ellipsis )]
1388- return value
1426+ return self .oindex [key ]
13891427 else :
13901428 if isinstance (key , VectorizedIndexer ):
13911429 raise TypeError ("Vectorized indexing is not supported" )
@@ -1395,11 +1433,10 @@ def __getitem__(self, key):
13951433 def __setitem__ (self , key , value ):
13961434 if isinstance (key , (BasicIndexer , OuterIndexer )):
13971435 self .array [key .tuple ] = value
1436+ elif isinstance (key , VectorizedIndexer ):
1437+ raise TypeError ("Vectorized indexing is not supported" )
13981438 else :
1399- if isinstance (key , VectorizedIndexer ):
1400- raise TypeError ("Vectorized indexing is not supported" )
1401- else :
1402- raise TypeError (f"Unrecognized indexer: { key } " )
1439+ raise TypeError (f"Unrecognized indexer: { key } " )
14031440
14041441 def transpose (self , order ):
14051442 xp = self .array .__array_namespace__ ()
@@ -1417,24 +1454,25 @@ def __init__(self, array):
14171454 """
14181455 self .array = array
14191456
1420- def __getitem__ (self , key ):
1457+ def _oindex_get (self , key ):
1458+ key = key .tuple
1459+ try :
1460+ return self .array [key ]
1461+ except NotImplementedError :
1462+ # manual orthogonal indexing
1463+ value = self .array
1464+ for axis , subkey in reversed (list (enumerate (key ))):
1465+ value = value [(slice (None ),) * axis + (subkey ,)]
1466+ return value
14211467
1468+ def __getitem__ (self , key ):
14221469 if isinstance (key , BasicIndexer ):
14231470 return self .array [key .tuple ]
14241471 elif isinstance (key , VectorizedIndexer ):
14251472 return self .array .vindex [key .tuple ]
14261473 else :
14271474 assert isinstance (key , OuterIndexer )
1428- key = key .tuple
1429- try :
1430- return self .array [key ]
1431- except NotImplementedError :
1432- # manual orthogonal indexing.
1433- # TODO: port this upstream into dask in a saner way.
1434- value = self .array
1435- for axis , subkey in reversed (list (enumerate (key ))):
1436- value = value [(slice (None ),) * axis + (subkey ,)]
1437- return value
1475+ return self .oindex [key ]
14381476
14391477 def __setitem__ (self , key , value ):
14401478 if isinstance (key , BasicIndexer ):
@@ -1510,6 +1548,9 @@ def _convert_scalar(self, item):
15101548 # a NumPy array.
15111549 return to_0d_array (item )
15121550
1551+ def _oindex_get (self , key ):
1552+ return self .__getitem__ (key )
1553+
15131554 def __getitem__ (
15141555 self , indexer
15151556 ) -> (
0 commit comments