From 8ca995abc41a243320cf2183c7c97d2d783defe3 Mon Sep 17 00:00:00 2001 From: Andrew Fleming Date: Fri, 16 Dec 2022 00:48:23 -0500 Subject: [PATCH] Fix network handling (#303) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add other networks to node.json by default * clean up test * Apply suggestions from code review Co-authored-by: Eric Nordelo * centralize file writes * remove writing DEFAULT_NETWORKS to node.json * fix formatting * compare gateway to default * fix get_gateways, refactor create_node_json * fix node timer and test * fix node_json and get_gateways tests * fix test * Update src/nile/common.py Co-authored-by: Martín Triay * improve write_node_json * change gateways variable to custom_gateways Co-authored-by: Eric Nordelo Co-authored-by: Martín Triay --- src/nile/common.py | 38 +++++++++++++++++---------- src/nile/core/node.py | 10 +++---- tests/test_cli.py | 11 ++++---- tests/test_common.py | 61 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 93 insertions(+), 27 deletions(-) diff --git a/src/nile/common.py b/src/nile/common.py index 1d7dd480..0180e149 100644 --- a/src/nile/common.py +++ b/src/nile/common.py @@ -28,25 +28,35 @@ # subject to change "0x041a78e741e5af2fec34b695679bc6891742439f7afb8484ecd7766661ad02bf" ) +DEFAULT_GATEWAYS = { + "localhost": "http://127.0.0.1:5050/", + "goerli2": "https://alpha4-2.starknet.io", + "integration": "https://external.integration.starknet.io", +} def get_gateways(): """Get the StarkNet node details.""" - try: + if os.path.exists(NODE_FILENAME): with open(NODE_FILENAME, "r") as f: - gateway = json.load(f) - return gateway - - except FileNotFoundError: - with open(NODE_FILENAME, "w") as f: - networks = { - "localhost": "http://127.0.0.1:5050/", - "goerli2": "https://alpha4-2.starknet.io", - "integration": "https://external.integration.starknet.io", - } - f.write(json.dumps(networks, indent=2)) - - return networks + custom_gateways = json.load(f) + gateways = {**DEFAULT_GATEWAYS, **custom_gateways} + return gateways + else: + return DEFAULT_GATEWAYS + + +def write_node_json(network, gateway_url): + """Create or update node.json with custom network.""" + if not os.path.exists(NODE_FILENAME): + with open(NODE_FILENAME, "w") as fp: + json.dump({network: gateway_url}, fp) + else: + with open(NODE_FILENAME, "r+") as fp: + gateways = json.load(fp) + gateways[network] = gateway_url + fp.seek(0) + json.dump(gateways, fp, indent=2) GATEWAYS = get_gateways() diff --git a/src/nile/core/node.py b/src/nile/core/node.py index 127fa28d..58aed672 100644 --- a/src/nile/core/node.py +++ b/src/nile/core/node.py @@ -1,25 +1,21 @@ """Command to start StarkNet local network.""" -import json import logging import subprocess -from nile.common import NODE_FILENAME +from nile.common import DEFAULT_GATEWAYS, write_node_json def node(host="127.0.0.1", port=5050, seed=None, lite_mode=False): """Start StarkNet local network.""" try: # Save host and port information to be used by other commands - file = NODE_FILENAME if host == "127.0.0.1": network = "localhost" else: network = host gateway_url = f"http://{host}:{port}/" - gateway = {network: gateway_url} - - with open(file, "w+") as f: - json.dump(gateway, f) + if DEFAULT_GATEWAYS.get(network) != gateway_url: + write_node_json(network, gateway_url) command = ["starknet-devnet", "--host", host, "--port", str(port)] diff --git a/tests/test_cli.py b/tests/test_cli.py index bf981c41..aee43637 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -148,7 +148,7 @@ async def test_node_forwards_args(mock_subprocess): ) async def test_node_runs_gateway(opts, expected): # Node life - seconds = 15 + seconds = 20 host = opts.get("--host", "127.0.0.1") port = opts.get("--port", "5050") @@ -177,10 +177,11 @@ async def test_node_runs_gateway(opts, expected): assert status == 200 # Assert network and gateway_url is correct in node.json file - file = NODE_FILENAME - with open(file, "r") as f: - gateway = json.load(f) - assert gateway.get(network) == expected + if expected != "http://127.0.0.1:5050/": + file = NODE_FILENAME + with open(file, "r") as f: + gateway = json.load(f) + assert gateway.get(network) == expected @pytest.mark.asyncio diff --git a/tests/test_common.py b/tests/test_common.py index a9302867..12d61327 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,7 +1,17 @@ """Tests for common library.""" +import json + import pytest -from nile.common import parse_information, prepare_params, stringify +from nile.common import ( + DEFAULT_GATEWAYS, + NODE_FILENAME, + write_node_json, + get_gateways, + parse_information, + prepare_params, + stringify, +) NETWORK = "goerli" ARGS = ["1", "2", "3"] @@ -12,6 +22,12 @@ STDOUT_2 = "SDTOUT_2" +@pytest.fixture(autouse=True) +def tmp_working_dir(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + return tmp_path + + @pytest.mark.parametrize( "args, expected", [ @@ -48,3 +64,46 @@ def test_parse_information(): _a, _b = parse_information(target) assert _a, _b == (a, b) + + +@pytest.mark.parametrize( + "network, url, gateway", + [ + (None, None, {}), + ("localhost", "5051", {"localhost": "5051"}), + ("host", "port", {"host": "port"}), + ], +) +def test_get_gateways(network, url, gateway): + if network is not None: + write_node_json(network, url) + + gateways = get_gateways() + expected = {**DEFAULT_GATEWAYS, **gateway} + assert gateways == expected + + # Check that node.json gateway returns in the case of duplicate keys + if network == "localhost": + assert expected["localhost"] != "5050" + assert expected["localhost"] == "5051" + + +@pytest.mark.parametrize( + "args1, args2, gateways", + [ + ( + ["NETWORK1", "URL1"], + ["NETWORK2", "URL2"], + {"NETWORK1": "URL1", "NETWORK2": "URL2"}, + ), + ], +) +def test_write_node_json(args1, args2, gateways): + # Check that node.json is created and adds keys + write_node_json(*args1) + write_node_json(*args2) + + with open(NODE_FILENAME, "r") as fp: + result = fp.read() + expected = json.dumps(gateways, indent=2) + assert result == expected