diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index fb847577d5..3eb11252b4 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -1031,9 +1031,10 @@ def get_provider( if provider_name in self.providers: provider = self.providers[provider_name](provider_settings=provider_settings) - if provider.connection_id in ProviderContextManager.connected_providers: + connection_id = provider.connection_id + if connection_id in ProviderContextManager.connected_providers: # Likely multi-chain testing or utilizing multiple on-going connections. - return ProviderContextManager.connected_providers[provider.connection_id] + return ProviderContextManager.connected_providers[connection_id] return provider diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 0c8dd2d17c..aa1bc6443d 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -1219,7 +1219,6 @@ def _complete_connect(self): logger.info(f"Connecting to a '{client_name}' node.") self.web3.eth.set_gas_price_strategy(rpc_gas_price_strategy) - # Check for chain errors, including syncing try: chain_id = self.web3.eth.chain_id @@ -1230,15 +1229,24 @@ def _complete_connect(self): else "Error getting chain id." ) - try: - block = self.web3.eth.get_block("latest") - except ExtraDataLengthError: - is_likely_poa = True - else: - is_likely_poa = ( - "proofOfAuthorityData" in block - or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH - ) + is_likely_poa = False + + # NOTE: We have to check both earliest and latest + # because if the chain was _ever_ PoA, we need + # this middleware. + for option in ("earliest", "latest"): + try: + block = self.web3.eth.get_block(option) # type: ignore[arg-type] + except ExtraDataLengthError: + is_likely_poa = True + break + else: + is_likely_poa = ( + "proofOfAuthorityData" in block + or len(block.get("extraData", "")) > MAX_EXTRADATA_LENGTH + ) + if is_likely_poa: + break if is_likely_poa and geth_poa_middleware not in self.web3.middleware_onion: self.web3.middleware_onion.inject(geth_poa_middleware, layer=0) diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index 31bf2386ea..d657ed2f44 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -6,6 +6,8 @@ from eth_typing import HexStr from evmchains import PUBLIC_CHAIN_META from hexbytes import HexBytes +from web3.exceptions import ExtraDataLengthError +from web3.middleware import geth_poa_middleware from ape.exceptions import ( APINotImplementedError, @@ -25,6 +27,11 @@ from tests.conftest import GETH_URI, geth_process_test +@pytest.fixture +def web3_factory(mocker): + return mocker.patch("ape_ethereum.provider._create_web3") + + @geth_process_test def test_uri(geth_provider): assert geth_provider.http_uri == GETH_URI @@ -107,15 +114,14 @@ def test_chain_id_live_network_connected_uses_web3_chain_id(mocker, geth_provide @geth_process_test -def test_connect_wrong_chain_id(mocker, ethereum, geth_provider): +def test_connect_wrong_chain_id(ethereum, geth_provider, web3_factory): start_network = geth_provider.network try: geth_provider.network = ethereum.get_network("goerli") # Ensure when reconnecting, it does not use HTTP - factory = mocker.patch("ape_ethereum.provider._create_web3") - factory.return_value = geth_provider._web3 + web3_factory.return_value = geth_provider._web3 expected_error_message = ( f"Provider connected to chain ID '{geth_provider._web3.eth.chain_id}', " "which does not match network chain ID '5'. " @@ -128,6 +134,27 @@ def test_connect_wrong_chain_id(mocker, ethereum, geth_provider): geth_provider.network = start_network +@geth_process_test +def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum): + """ + Ensure that when connecting to a chain that + started out as PoA, such as Goerli, we include + the right middleware. Note: even if the chain + is no longer PoA, we still need the middleware + to fetch blocks during the PoA portion of the chain. + """ + mock_web3.eth.get_block.side_effect = ExtraDataLengthError + mock_web3.eth.chain_id = ethereum.goerli.chain_id + web3_factory.return_value = mock_web3 + provider = ethereum.goerli.get_provider("geth") + provider.provider_settings = {"uri": "http://node.example.com"} # fake + provider.connect() + + # Verify PoA middleware was added. + assert mock_web3.middleware_onion.inject.call_args[0] == (geth_poa_middleware,) + assert mock_web3.middleware_onion.inject.call_args[1] == {"layer": 0} + + @geth_process_test @pytest.mark.parametrize("block_id", (0, "0", "0x0", HexStr("0x0"))) def test_get_block(geth_provider, block_id):