Skip to content

Commit

Permalink
fix: issue with networks that started PoA but currently are not (#1911)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Feb 3, 2024
1 parent a3e14db commit e979257
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 15 deletions.
5 changes: 3 additions & 2 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,9 +1031,10 @@ def get_provider(

if provider_name in self.providers:
provider = self.providers[provider_name](provider_settings=provider_settings)
if provider.connection_id in ProviderContextManager.connected_providers:
connection_id = provider.connection_id
if connection_id in ProviderContextManager.connected_providers:
# Likely multi-chain testing or utilizing multiple on-going connections.
return ProviderContextManager.connected_providers[provider.connection_id]
return ProviderContextManager.connected_providers[connection_id]

return provider

Expand Down
28 changes: 18 additions & 10 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,7 +1219,6 @@ def _complete_connect(self):
logger.info(f"Connecting to a '{client_name}' node.")

self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy)

# Check for chain errors, including syncing
try:
chain_id = self.web3.eth.chain_id
Expand All @@ -1230,15 +1229,24 @@ def _complete_connect(self):
else "Error getting chain id."
)

try:
block = self.web3.eth.get_block("latest")
except ExtraDataLengthError:
is_likely_poa = True
else:
is_likely_poa = (
"proofOfAuthorityData" in block
or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH
)
is_likely_poa = False

# NOTE: We have to check both earliest and latest
# because if the chain was _ever_ PoA, we need
# this middleware.
for option in ("earliest", "latest"):
try:
block = self.web3.eth.get_block(option) # type: ignore[arg-type]
except ExtraDataLengthError:
is_likely_poa = True
break
else:
is_likely_poa = (
"proofOfAuthorityData" in block
or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH
)
if is_likely_poa:
break

if is_likely_poa and geth_poa_middleware not in self.web3.middleware_onion:
self.web3.middleware_onion.inject(geth_poa_middleware, layer=0)
Expand Down
33 changes: 30 additions & 3 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from eth_typing import HexStr
from evmchains import PUBLIC_CHAIN_META
from hexbytes import HexBytes
from web3.exceptions import ExtraDataLengthError
from web3.middleware import geth_poa_middleware

from ape.exceptions import (
APINotImplementedError,
Expand All @@ -25,6 +27,11 @@
from tests.conftest import GETH_URI, geth_process_test


@pytest.fixture
def web3_factory(mocker):
return mocker.patch("ape_ethereum.provider._create_web3")


@geth_process_test
def test_uri(geth_provider):
assert geth_provider.http_uri == GETH_URI
Expand Down Expand Up @@ -107,15 +114,14 @@ def test_chain_id_live_network_connected_uses_web3_chain_id(mocker, geth_provide


@geth_process_test
def test_connect_wrong_chain_id(mocker, ethereum, geth_provider):
def test_connect_wrong_chain_id(ethereum, geth_provider, web3_factory):
start_network = geth_provider.network

try:
geth_provider.network = ethereum.get_network("goerli")

# Ensure when reconnecting, it does not use HTTP
factory = mocker.patch("ape_ethereum.provider._create_web3")
factory.return_value = geth_provider._web3
web3_factory.return_value = geth_provider._web3
expected_error_message = (
f"Provider connected to chain ID '{geth_provider._web3.eth.chain_id}', "
"which does not match network chain ID '5'. "
Expand All @@ -128,6 +134,27 @@ def test_connect_wrong_chain_id(mocker, ethereum, geth_provider):
geth_provider.network = start_network


@geth_process_test
def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum):
"""
Ensure that when connecting to a chain that
started out as PoA, such as Goerli, we include
the right middleware. Note: even if the chain
is no longer PoA, we still need the middleware
to fetch blocks during the PoA portion of the chain.
"""
mock_web3.eth.get_block.side_effect = ExtraDataLengthError
mock_web3.eth.chain_id = ethereum.goerli.chain_id
web3_factory.return_value = mock_web3
provider = ethereum.goerli.get_provider("geth")
provider.provider_settings = {"uri": "http://node.example.com"} # fake
provider.connect()

# Verify PoA middleware was added.
assert mock_web3.middleware_onion.inject.call_args[0] == (geth_poa_middleware,)
assert mock_web3.middleware_onion.inject.call_args[1] == {"layer": 0}


@geth_process_test
@pytest.mark.parametrize("block_id", (0, "0", "0x0", HexStr("0x0")))
def test_get_block(geth_provider, block_id):
Expand Down

0 comments on commit e979257

Please sign in to comment.