Skip to content

Commit 96e71b3

Browse files
authored
fix: lib for typetracers (#558)
* fix: lib for typetracers should use its nplike when possible * restrict fix to typetracer backend * improve comment for assertion helper function
1 parent aef0d0e commit 96e71b3

File tree

3 files changed

+731
-461
lines changed

3 files changed

+731
-461
lines changed

src/vector/_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4400,7 +4400,7 @@ def _lib_of(*objects: VectorProtocol) -> Module: # NumPy-like module
44004400
if isinstance(obj, Vector):
44014401
if lib is None:
44024402
lib = obj.lib
4403-
elif lib is not obj.lib:
4403+
elif lib != obj.lib:
44044404
raise TypeError(
44054405
f"cannot use {lib} and {obj.lib} in the same calculation"
44064406
)

src/vector/backends/awkward.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,38 @@ class AwkwardProtocol(Protocol):
622622
def __getitem__(self, where: typing.Any) -> float | ak.Array | ak.Record | None: ...
623623

624624

625+
class _lib(typing.NamedTuple):
626+
"""a wrapper that respects the numpy-like interface of awkward-array and the module interface of numpy"""
627+
628+
module: types.ModuleType
629+
nplike: ak._nplikes.numpy_like.NumpyLike
630+
631+
def __eq__(self, other: typing.Any) -> bool:
632+
if isinstance(other, _lib):
633+
return self.module is other.module and self.nplike is other.nplike
634+
else:
635+
return self.module is other
636+
637+
def __ne__(self, other: typing.Any) -> bool:
638+
return not self.__eq__(other)
639+
640+
def __getattr__(self, name: str) -> typing.Any:
641+
if fun := getattr(self.nplike, name, False):
642+
return fun
643+
else:
644+
return getattr(self.module, name)
645+
646+
625647
class VectorAwkward:
626648
"""Mixin class for Awkward vectors."""
627649

628-
lib: types.ModuleType = numpy
650+
@property
651+
def lib(self): # type:ignore[no-untyped-def]
652+
if (
653+
nplike := self.layout.backend.nplike # type:ignore[attr-defined]
654+
) is ak._nplikes.typetracer.TypeTracer.instance():
655+
return _lib(module=numpy, nplike=nplike)
656+
return numpy
629657

630658
def _wrap_result(
631659
self: AwkwardProtocol,

0 commit comments

Comments
 (0)