From 02c38ac6c1bb86d2efe10bd44fd7416886bc4ad4 Mon Sep 17 00:00:00 2001 From: fselmo Date: Mon, 26 Aug 2024 21:43:12 -0600 Subject: [PATCH] (bugfix) Use a defined __hash__() for unhashable, unnamed middleware: Fixes a bug where the middleware which are not hashable, classes or the ``build`` method of builder classes, end up being given the same identifier when a name is not provided for the middleware in the onion. This fixes the issue by defining a unique hash for the base ``Web3Middleware`` class and using that hash in the case where the middleware is unhashable and therefore indistinguishable from another middleware. --- newsfragments/3463.bugfix.rst | 1 + web3/datastructures.py | 32 +++++++++++++++++++++----------- web3/middleware/base.py | 8 ++++++++ 3 files changed, 30 insertions(+), 11 deletions(-) create mode 100644 newsfragments/3463.bugfix.rst diff --git a/newsfragments/3463.bugfix.rst b/newsfragments/3463.bugfix.rst new file mode 100644 index 0000000000..40c73c441d --- /dev/null +++ b/newsfragments/3463.bugfix.rst @@ -0,0 +1 @@ +Specify a unique ``__hash__()`` for unhashable ``Web3Middleware`` types and use this hash as the middleware onion key when a name is not provided for the middleware. This fixes a bug where different middleware were given the same name and therefore raised errors. diff --git a/web3/datastructures.py b/web3/datastructures.py index 7b6ea08739..e9b07348b8 100644 --- a/web3/datastructures.py +++ b/web3/datastructures.py @@ -182,7 +182,7 @@ def add(self, element: TValue, name: Optional[TKey] = None) -> None: if name is None: name = cast(TKey, element) - name = self._repr_if_not_hashable(name) + name = self._build_tkey(name) if name in self._queue: if name is element: @@ -219,7 +219,7 @@ def inject( if name is None: name = cast(TKey, element) - name = self._repr_if_not_hashable(name) + name = self._build_tkey(name) self._queue.move_to_end(name, last=False) elif layer == len(self._queue): @@ -233,7 +233,7 @@ def clear(self) -> None: self._queue.clear() def replace(self, old: TKey, new: TKey) -> TValue: - old_name = self._repr_if_not_hashable(old) + old_name = self._build_tkey(old) if old_name not in self._queue: raise Web3ValueError( @@ -248,15 +248,25 @@ def replace(self, old: TKey, new: TKey) -> TValue: self._queue[old_name] = new return to_be_replaced - def _repr_if_not_hashable(self, value: TKey) -> TKey: + @staticmethod + def _build_tkey(value: TKey) -> TKey: try: value.__hash__() + return value except TypeError: - value = cast(TKey, repr(value)) - return value + # unhashable, unnamed elements + if not callable(value): + raise Web3TypeError( + f"Expected a callable or hashable type, got {type(value)}" + ) + # This will either be ``Web3Middleware`` class or the ``build`` method of a + # ``Web3MiddlewareBuilder``. Instantiate with empty ``Web3`` and use a + # unique identifier with the ``__hash__()`` as the TKey. + v = value(None) + return cast(TKey, f"{v.__class__}<{v.__hash__()}>") def remove(self, old: TKey) -> None: - old_name = self._repr_if_not_hashable(old) + old_name = self._build_tkey(old) if old_name not in self._queue: raise Web3ValueError("You can only remove something that has been added") del self._queue[old_name] @@ -270,8 +280,8 @@ def middleware(self) -> Sequence[Any]: return [(val, key) for key, val in reversed(self._queue.items())] def _replace_with_new_name(self, old: TKey, new: TKey) -> None: - old_name = self._repr_if_not_hashable(old) - new_name = self._repr_if_not_hashable(new) + old_name = self._build_tkey(old) + new_name = self._build_tkey(new) self._queue[new_name] = new found_old = False @@ -293,11 +303,11 @@ def __add__(self, other: Any) -> "NamedElementOnion[TKey, TValue]": return NamedElementOnion(cast(List[Any], combined.items())) def __contains__(self, element: Any) -> bool: - element_name = self._repr_if_not_hashable(element) + element_name = self._build_tkey(element) return element_name in self._queue def __getitem__(self, element: TKey) -> TValue: - element_name = self._repr_if_not_hashable(element) + element_name = self._build_tkey(element) return self._queue[element_name] def __len__(self) -> int: diff --git a/web3/middleware/base.py b/web3/middleware/base.py index 565d2f1415..590fce58e7 100644 --- a/web3/middleware/base.py +++ b/web3/middleware/base.py @@ -40,6 +40,14 @@ class Web3Middleware: def __init__(self, w3: Union["AsyncWeb3", "Web3"]) -> None: self._w3 = w3 + def __hash__(self) -> int: + return hash(f"{self.__class__.__name__}({str(self.__dict__)})") + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Web3Middleware): + return False + return self.__hash__() == other.__hash__() + # -- sync -- # def wrap_make_request(self, make_request: "MakeRequestFn") -> "MakeRequestFn":