Skip to content

feat: allow overriding of protocol methods in subclasses #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# or https://github.com/scikit-hep/vector for details.

import typing
from contextlib import suppress

from vector._typeutils import (
BoolCollection,
Expand Down Expand Up @@ -2533,6 +2534,16 @@ def _from_signature(
]


def _get_handler_index(obj: VectorProtocol) -> int:
"""Returns the index of the first valid handler checking the list of parent classes"""
for cls in type(obj).__mro__:
with suppress(ValueError):
return _handler_priority.index(cls.__module__)
raise AssertionError(
f"Could not find a valid handler for {obj}! This should not happen."
)


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand All @@ -2544,13 +2555,12 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
handler = None
for obj in objects:
if isinstance(obj, Vector):
if handler is None:
handler = obj
elif _handler_priority.index(
type(obj).__module__
) > _handler_priority.index(type(handler).__module__):
handler = obj
if not isinstance(obj, Vector):
continue
if handler is None:
handler = obj
elif _get_handler_index(obj) > _get_handler_index(handler):
handler = obj

assert handler is not None
return handler
Expand Down
17 changes: 17 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2019-2021, Jonas Eschle, Jim Pivarski, Eduardo Rodrigues, and Henry Schreiner.
#
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/vector for details.

import vector


class CustomVector(vector.VectorObject4D):
pass


def test_handler_of():
object_a = CustomVector.from_xyzt(0.0, 0.0, 0.0, 0.0)
object_b = CustomVector.from_xyzt(1.0, 1.0, 1.0, 1.0)
protocol = vector._methods._handler_of(object_a, object_b)
assert protocol == object_a