Skip to content

Commit

Permalink
(bugfix) Use a defined __hash__() for unhashable, unnamed middleware:
Browse files Browse the repository at this point in the history
Fixes a bug where the middleware which are not hashable, classes or the
``build`` method of builder classes, end up being given the same
identifier when a name is not provided for the middleware in the onion.
This fixes the issue by defining a unique hash for the base
``Web3Middleware`` class and using that hash in the case where the
middleware is unhashable and therefore indistinguishable from another
middleware.
  • Loading branch information
fselmo committed Aug 28, 2024
1 parent 9e21be0 commit 02c38ac
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
1 change: 1 addition & 0 deletions newsfragments/3463.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Specify a unique ``__hash__()`` for unhashable ``Web3Middleware`` types and use this hash as the middleware onion key when a name is not provided for the middleware. This fixes a bug where different middleware were given the same name and therefore raised errors.
32 changes: 21 additions & 11 deletions web3/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def add(self, element: TValue, name: Optional[TKey] = None) -> None:
if name is None:
name = cast(TKey, element)

name = self._repr_if_not_hashable(name)
name = self._build_tkey(name)

if name in self._queue:
if name is element:
Expand Down Expand Up @@ -219,7 +219,7 @@ def inject(
if name is None:
name = cast(TKey, element)

name = self._repr_if_not_hashable(name)
name = self._build_tkey(name)

self._queue.move_to_end(name, last=False)
elif layer == len(self._queue):
Expand All @@ -233,7 +233,7 @@ def clear(self) -> None:
self._queue.clear()

def replace(self, old: TKey, new: TKey) -> TValue:
old_name = self._repr_if_not_hashable(old)
old_name = self._build_tkey(old)

if old_name not in self._queue:
raise Web3ValueError(
Expand All @@ -248,15 +248,25 @@ def replace(self, old: TKey, new: TKey) -> TValue:
self._queue[old_name] = new
return to_be_replaced

def _repr_if_not_hashable(self, value: TKey) -> TKey:
@staticmethod
def _build_tkey(value: TKey) -> TKey:
try:
value.__hash__()
return value
except TypeError:
value = cast(TKey, repr(value))
return value
# unhashable, unnamed elements
if not callable(value):
raise Web3TypeError(
f"Expected a callable or hashable type, got {type(value)}"
)
# This will either be ``Web3Middleware`` class or the ``build`` method of a
# ``Web3MiddlewareBuilder``. Instantiate with empty ``Web3`` and use a
# unique identifier with the ``__hash__()`` as the TKey.
v = value(None)
return cast(TKey, f"{v.__class__}<{v.__hash__()}>")

def remove(self, old: TKey) -> None:
old_name = self._repr_if_not_hashable(old)
old_name = self._build_tkey(old)
if old_name not in self._queue:
raise Web3ValueError("You can only remove something that has been added")
del self._queue[old_name]
Expand All @@ -270,8 +280,8 @@ def middleware(self) -> Sequence[Any]:
return [(val, key) for key, val in reversed(self._queue.items())]

def _replace_with_new_name(self, old: TKey, new: TKey) -> None:
old_name = self._repr_if_not_hashable(old)
new_name = self._repr_if_not_hashable(new)
old_name = self._build_tkey(old)
new_name = self._build_tkey(new)

self._queue[new_name] = new
found_old = False
Expand All @@ -293,11 +303,11 @@ def __add__(self, other: Any) -> "NamedElementOnion[TKey, TValue]":
return NamedElementOnion(cast(List[Any], combined.items()))

def __contains__(self, element: Any) -> bool:
element_name = self._repr_if_not_hashable(element)
element_name = self._build_tkey(element)
return element_name in self._queue

def __getitem__(self, element: TKey) -> TValue:
element_name = self._repr_if_not_hashable(element)
element_name = self._build_tkey(element)
return self._queue[element_name]

def __len__(self) -> int:
Expand Down
8 changes: 8 additions & 0 deletions web3/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class Web3Middleware:
def __init__(self, w3: Union["AsyncWeb3", "Web3"]) -> None:
self._w3 = w3

def __hash__(self) -> int:
return hash(f"{self.__class__.__name__}({str(self.__dict__)})")

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Web3Middleware):
return False
return self.__hash__() == other.__hash__()

# -- sync -- #

def wrap_make_request(self, make_request: "MakeRequestFn") -> "MakeRequestFn":
Expand Down

0 comments on commit 02c38ac

Please sign in to comment.