From 56392e40ac039a6080fb675754175694746a2a4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C4=83lina=20Cenan?= Date: Tue, 9 Jan 2024 17:08:04 +0200 Subject: [PATCH] Add some test coverage (#488) * Adds a line of coverage to test. * Add coverage for csvs module. * Add coverage to check_network. * Add coverage to predictions and traction info. * Adds coverage to predictoor stats. * Adds full coverage to arg cli classes. * Adds cli arguments coverage and fix a wrong parameter in cli arguments. * Adds coverage to cli module and timeframe. * Some reformats and coverage in contract module. * Adds coverage and simplifications to contracts, except token. * Add some coverage to tokens to complete contract coverage work. --- .../aimodel/test/test_aimodel_factory.py | 15 +++ pdr_backend/analytics/check_network.py | 9 +- .../analytics/test/test_check_network.py | 37 ++++++ .../test/test_get_predictions_info.py | 23 ++++ .../analytics/test/test_get_traction_info.py | 21 ++++ .../analytics/test/test_predictoor_stats.py | 11 ++ pdr_backend/cli/arg_exchange.py | 3 - pdr_backend/cli/arg_feed.py | 3 - pdr_backend/cli/arg_pair.py | 11 +- pdr_backend/cli/cli_arguments.py | 62 ++++----- .../cli/test_noganache/test_arg_exchange.py | 5 +- .../cli/test_noganache/test_arg_feed.py | 12 +- .../cli/test_noganache/test_arg_pair.py | 7 ++ .../cli/test_noganache/test_cli_arguments.py | 41 ++++++ .../cli/test_noganache/test_cli_module.py | 80 +++++++++++- .../cli/test_noganache/test_timeframe.py | 10 ++ pdr_backend/conftest_ganache.py | 26 ++++ pdr_backend/contract/base_contract.py | 32 +++++ pdr_backend/contract/data_nft.py | 11 +- pdr_backend/contract/dfrewards.py | 5 +- pdr_backend/contract/erc721_factory.py | 7 +- pdr_backend/contract/predictoor_batcher.py | 13 +- pdr_backend/contract/predictoor_contract.py | 69 +++++----- .../contract/test/test_base_contract.py | 64 ++++++++++ .../contract/test/test_erc721_factory.py | 19 +++ .../contract/test/test_predictoor_contract.py | 18 +++ pdr_backend/contract/test/test_token.py | 19 ++- .../contract/test/test_wrapped_token.py | 21 ++++ pdr_backend/contract/token.py | 10 +- pdr_backend/contract/wrapped_token.py | 3 +- pdr_backend/ppss/test/test_web3_pp.py | 29 +++++ pdr_backend/ppss/web3_pp.py | 21 ++++ pdr_backend/util/csvs.py | 78 +++++------- pdr_backend/util/networkutil.py | 71 ----------- .../util/test_ganache/test_networkutil.py | 119 ------------------ .../util/test_ganache/test_web3_config.py | 12 ++ pdr_backend/util/test_noganache/test_csvs.py | 58 +++++++++ pdr_backend/util/web3_config.py | 17 +++ 38 files changed, 719 insertions(+), 353 deletions(-) create mode 100644 pdr_backend/cli/test_noganache/test_cli_arguments.py create mode 100644 pdr_backend/contract/test/test_wrapped_token.py create mode 100644 pdr_backend/util/test_noganache/test_csvs.py diff --git a/pdr_backend/aimodel/test/test_aimodel_factory.py b/pdr_backend/aimodel/test/test_aimodel_factory.py index 3b4b3c835..6081c07d1 100644 --- a/pdr_backend/aimodel/test/test_aimodel_factory.py +++ b/pdr_backend/aimodel/test/test_aimodel_factory.py @@ -1,6 +1,8 @@ import warnings +from unittest.mock import Mock import numpy as np +import pytest from enforce_typing import enforce_types from pdr_backend.aimodel.aimodel_factory import AimodelFactory @@ -82,3 +84,16 @@ def test_aimodel_accuracy_from_create_xy(aimodel_factory): y_train_hat = aimodel.predict(X_train) assert sum(abs(y_train - y_train_hat)) < 1e-10 # near-perfect since linear + + +@enforce_types +def test_aimodel_factory_bad_approach(): + aimodel_ss = Mock(spec=AimodelSS) + aimodel_ss.approach = "BAD" + factory = AimodelFactory(aimodel_ss) + + X_train, y_train, _, _ = _data() + + # forcefully change the model + with pytest.raises(ValueError): + factory.build(X_train, y_train) diff --git a/pdr_backend/analytics/check_network.py b/pdr_backend/analytics/check_network.py index bd7144f54..ef157f4c2 100644 --- a/pdr_backend/analytics/check_network.py +++ b/pdr_backend/analytics/check_network.py @@ -166,12 +166,13 @@ def check_network_main(ppss: PPSS, lookback_hours: int): ocean_bal = from_wei(OCEAN.balanceOf(address)) native_bal = from_wei(web3_pp.web3_config.w3.eth.get_balance(address)) - ocean_warning = " WARNING LOW OCEAN BALANCE!" if ocean_bal < 10 else " OK " + ocean_warning = ( + " WARNING LOW OCEAN BALANCE!" + if ocean_bal < 10 and name != "trueval" + else " OK " + ) native_warning = " WARNING LOW NATIVE BALANCE!" if native_bal < 10 else " OK " - if name == "trueval": - ocean_warning = " OK " - print( f"{name}: OCEAN: {ocean_bal:.2f}{ocean_warning}" f", Native: {native_bal:.2f}{native_warning}" diff --git a/pdr_backend/analytics/test/test_check_network.py b/pdr_backend/analytics/test/test_check_network.py index dc3d42b58..cb0cc2d27 100644 --- a/pdr_backend/analytics/test/test_check_network.py +++ b/pdr_backend/analytics/test/test_check_network.py @@ -115,3 +115,40 @@ def test_check_network_main( # pylint: disable=unused-argument assert mock_query_subgraph.call_count == 1 mock_token.assert_called() ppss.web3_pp.web3_config.w3.eth.get_balance.assert_called() + + +@enforce_types +@patch(f"{PATH}.check_dfbuyer") +@patch(f"{PATH}.get_opf_addresses") +@patch(f"{PATH}.Token") +def test_check_network_others( # pylint: disable=unused-argument + mock_token, + mock_get_opf_addresses, + mock_check_dfbuyer, + tmpdir, + monkeypatch, +): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + mock_query_subgraph = Mock() + + # test if predictoor contracts are found, iterates through them + with patch(f"{PATH}.query_subgraph") as mock_query_subgraph: + mock_query_subgraph.return_value = { + "data": { + "predictContracts": [ + { + "slots": {}, + "token": {"name": "aa"}, + "secondsPerEpoch": 86400, + }, + { + "slots": {}, + "token": {"name": "bb"}, + "secondsPerEpoch": 86400, + }, + ] + } + } + check_network_main(ppss, lookback_hours=24) + assert mock_query_subgraph.call_count == 1 + assert mock_check_dfbuyer.call_count == 1 diff --git a/pdr_backend/analytics/test/test_get_predictions_info.py b/pdr_backend/analytics/test/test_get_predictions_info.py index 2d6ef1bc9..9b5df5948 100644 --- a/pdr_backend/analytics/test/test_get_predictions_info.py +++ b/pdr_backend/analytics/test/test_get_predictions_info.py @@ -46,3 +46,26 @@ def test_get_predictions_info_main_mainnet( ) mock_save.assert_called() mock_getstats.assert_called_with(_sample_first_predictions) + + +@enforce_types +def test_get_predictions_info_empty(tmpdir, capfd): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_getids = Mock(return_value=[]) + mock_fetch = Mock(return_value={}) + + PATH = "pdr_backend.analytics.get_predictions_info" + with patch(f"{PATH}.get_all_contract_ids_by_owner", mock_getids), patch( + f"{PATH}.fetch_filtered_predictions", mock_fetch + ): + st_timestr = "2023-11-02" + fin_timestr = "2023-11-05" + + get_predictions_info_main( + ppss, "0x123", st_timestr, fin_timestr, "parquet_data/" + ) + + assert ( + "No records found. Please adjust start and end times" in capfd.readouterr().out + ) diff --git a/pdr_backend/analytics/test/test_get_traction_info.py b/pdr_backend/analytics/test/test_get_traction_info.py index e7f4345e9..1964eee6a 100644 --- a/pdr_backend/analytics/test/test_get_traction_info.py +++ b/pdr_backend/analytics/test/test_get_traction_info.py @@ -73,3 +73,24 @@ def test_get_traction_info_main_mainnet( pl.DataFrame.equals(mock_traction_stat.call_args, preds_df) mock_plot_cumsum.assert_called() mock_plot_daily.assert_called() + + +@enforce_types +def test_get_traction_info_empty( + tmpdir, + capfd, +): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_empty = Mock(return_value=[]) + + PATH = "pdr_backend.analytics.get_traction_info" + with patch(f"{PATH}.GQLDataFactory.get_gql_dfs", mock_empty): + st_timestr = "2023-11-02" + fin_timestr = "2023-11-05" + + get_traction_info_main(ppss, st_timestr, fin_timestr, "parquet_data/") + + assert ( + "No records found. Please adjust start and end times." in capfd.readouterr().out + ) diff --git a/pdr_backend/analytics/test/test_predictoor_stats.py b/pdr_backend/analytics/test/test_predictoor_stats.py index 8bc351a9d..c87a37978 100644 --- a/pdr_backend/analytics/test/test_predictoor_stats.py +++ b/pdr_backend/analytics/test/test_predictoor_stats.py @@ -69,6 +69,17 @@ def test_get_cli_statistics(capsys, _sample_first_predictions): assert "Accuracy for Pair" in output assert "Accuracy for Predictoor Address" in output + get_cli_statistics([]) + assert "No predictions found" in capsys.readouterr().out + + with patch( + "pdr_backend.analytics.predictoor_stats.aggregate_prediction_statistics" + ) as mock: + mock.return_value = ({}, 0) + get_cli_statistics(_sample_first_predictions) + + assert "No correct predictions found" in capsys.readouterr().out + @enforce_types @patch("matplotlib.pyplot.savefig") diff --git a/pdr_backend/cli/arg_exchange.py b/pdr_backend/cli/arg_exchange.py index 072c61d6b..fc0cb670b 100644 --- a/pdr_backend/cli/arg_exchange.py +++ b/pdr_backend/cli/arg_exchange.py @@ -40,7 +40,4 @@ def __init__(self, exchanges: Union[List[str], List[ArgExchange]]): super().__init__(converted) def __str__(self): - if not self: - return "" - return ",".join([str(exchange) for exchange in self]) diff --git a/pdr_backend/cli/arg_feed.py b/pdr_backend/cli/arg_feed.py index 48aac2fb4..6fc33f4bc 100644 --- a/pdr_backend/cli/arg_feed.py +++ b/pdr_backend/cli/arg_feed.py @@ -55,9 +55,6 @@ def __str__(self): return feed_str - def __repr__(self): - return self.__str__() - def __eq__(self, other): return ( self.exchange == other.exchange diff --git a/pdr_backend/cli/arg_pair.py b/pdr_backend/cli/arg_pair.py index 5bc95d502..d5d4c687b 100644 --- a/pdr_backend/cli/arg_pair.py +++ b/pdr_backend/cli/arg_pair.py @@ -18,7 +18,7 @@ def __init__( base_str: Optional[str] = None, quote_str: Optional[str] = None, ): - if not pair_str and [None in [base_str, quote_str]]: + if not pair_str and None in [base_str, quote_str]: raise ValueError( "Must provide either pair_str, or both base_str and quote_str" ) @@ -55,12 +55,7 @@ def __hash__(self): class ArgPairs(List[ArgPair]): @staticmethod def from_str(pairs_str: str) -> "ArgPairs": - pairs = ArgPairs(_unpack_pairs_str(pairs_str)) - - if not pairs: - raise ValueError(pairs_str) - - return pairs + return ArgPairs(_unpack_pairs_str(pairs_str)) def __eq__(self, other): return set(self) == set(other) @@ -106,7 +101,7 @@ def _unpack_pairs_str(pairs_str: str) -> List[str]: pairs_str = pairs_str.replace("-", "/") # ETH/USDT -> ETH-USDT. Safer files. pair_str_list = pairs_str.split(",") - if not pair_str_list: + if not any(pair_str_list): raise ValueError(pairs_str) return pair_str_list diff --git a/pdr_backend/cli/cli_arguments.py b/pdr_backend/cli/cli_arguments.py index 7adb2ad58..d590998a0 100644 --- a/pdr_backend/cli/cli_arguments.py +++ b/pdr_backend/cli/cli_arguments.py @@ -177,7 +177,9 @@ class _ArgParser_ST_END_PQDIR_NETWORK_PPSS_PDRS( @enforce_types def __init__(self, description: str, command_name: str): super().__init__(description=description) - self.add_arguments_bulk(command_name, ["ST", "END", "PQDIR", "PPSS", "NETWORK"]) + self.add_arguments_bulk( + command_name, ["ST", "END", "PQDIR", "PPSS", "NETWORK", "PDRS"] + ) @enforce_types @@ -249,37 +251,35 @@ def print_args(arguments: Namespace): TopupArgParser = _ArgParser_PPSS_NETWORK +defined_parsers = { + "do_sim": SimArgParser("Run simulation", "sim"), + "do_predictoor": PredictoorArgParser("Run a predictoor bot", "predictoor"), + "do_trader": TraderArgParser("Run a trader bot", "trader"), + "do_lake": LakeArgParser("Run the lake tool", "lake"), + "do_claim_OCEAN": ClaimOceanArgParser("Claim OCEAN", "claim_OCEAN"), + "do_claim_ROSE": ClaimRoseArgParser("Claim ROSE", "claim_ROSE"), + "do_get_predictoors_info": GetPredictoorsInfoArgParser( + "For specified predictoors, report {accuracy, ..} of each predictoor", + "get_predictoors_info", + ), + "do_get_predictions_info": GetPredictionsInfoArgParser( + "For specified feeds, report {accuracy, ..} of each predictoor", + "get_predictions_info", + ), + "do_get_traction_info": GetTractionInfoArgParser( + "Get traction info: # predictoors vs time, etc", + "get_traction_info", + ), + "do_check_network": CheckNetworkArgParser("Check network", "check_network"), + "do_trueval": TruevalArgParser("Run trueval bot", "trueval"), + "do_dfbuyer": DfbuyerArgParser("Run dfbuyer bot", "dfbuyer"), + "do_publisher": PublisherArgParser("Publish feeds", "publisher"), + "do_topup": TopupArgParser("Topup OCEAN and ROSE in dfbuyer, trueval, ..", "topup"), +} + def get_arg_parser(func_name): - parsers = { - "do_sim": SimArgParser("Run simulation", "sim"), - "do_predictoor": PredictoorArgParser("Run a predictoor bot", "predictoor"), - "do_trader": TraderArgParser("Run a trader bot", "trader"), - "do_lake": LakeArgParser("Run the lake tool", "lake"), - "do_claim_OCEAN": ClaimOceanArgParser("Claim OCEAN", "claim_OCEAN"), - "do_claim_ROSE": ClaimRoseArgParser("Claim ROSE", "claim_ROSE"), - "do_get_predictoors_info": GetPredictoorsInfoArgParser( - "For specified predictoors, report {accuracy, ..} of each predictoor", - "get_predictoors_info", - ), - "do_get_predictions_info": GetPredictionsInfoArgParser( - "For specified feeds, report {accuracy, ..} of each predictoor", - "get_predictions_info", - ), - "do_get_traction_info": GetTractionInfoArgParser( - "Get traction info: # predictoors vs time, etc", - "get_traction_info", - ), - "do_check_network": CheckNetworkArgParser("Check network", "check_network"), - "do_trueval": TruevalArgParser("Run trueval bot", "trueval"), - "do_dfbuyer": DfbuyerArgParser("Run dfbuyer bot", "dfbuyer"), - "do_publisher": PublisherArgParser("Publish feeds", "publisher"), - "do_topup": TopupArgParser( - "Topup OCEAN and ROSE in dfbuyer, trueval, ..", "topup" - ), - } - - if func_name not in parsers: + if func_name not in defined_parsers: raise ValueError(f"Unknown function name: {func_name}") - return parsers[func_name] + return defined_parsers[func_name] diff --git a/pdr_backend/cli/test_noganache/test_arg_exchange.py b/pdr_backend/cli/test_noganache/test_arg_exchange.py index 133b086d3..34005ecb7 100644 --- a/pdr_backend/cli/test_noganache/test_arg_exchange.py +++ b/pdr_backend/cli/test_noganache/test_arg_exchange.py @@ -1,7 +1,7 @@ import pytest from enforce_typing import enforce_types -from pdr_backend.cli.arg_exchange import ArgExchanges +from pdr_backend.cli.arg_exchange import ArgExchange, ArgExchanges @enforce_types @@ -15,6 +15,9 @@ def test_pack_exchange_str_list(): with pytest.raises(TypeError): ArgExchanges(None) + with pytest.raises(ValueError): + ArgExchange(None) + with pytest.raises(TypeError): ArgExchanges("") diff --git a/pdr_backend/cli/test_noganache/test_arg_feed.py b/pdr_backend/cli/test_noganache/test_arg_feed.py index d78780923..b4368eabe 100644 --- a/pdr_backend/cli/test_noganache/test_arg_feed.py +++ b/pdr_backend/cli/test_noganache/test_arg_feed.py @@ -15,12 +15,16 @@ def test_unpack_feeds_strs(): assert ArgFeeds.from_strs(["binance ADA-USDT o 1h"]) == target_feeds # 1 str w 2 feeds, 2 feeds total - target_feeds = [ - ArgFeed("binance", "open", "ADA/USDT"), - ArgFeed("binance", "high", "ADA/USDT"), - ] + target_feeds = ArgFeeds( + [ + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("binance", "high", "ADA/USDT"), + ] + ) assert ArgFeeds.from_strs(["binance ADA/USDT oh"]) == target_feeds assert ArgFeeds.from_strs(["binance ADA-USDT oh"]) == target_feeds + assert target_feeds.signals == set(["open", "high"]) + assert target_feeds.exchanges == set(["binance"]) # 2 strs each w 1 feed, 2 feeds total target_feeds = [ diff --git a/pdr_backend/cli/test_noganache/test_arg_pair.py b/pdr_backend/cli/test_noganache/test_arg_pair.py index 91ab5e482..0ec79e237 100644 --- a/pdr_backend/cli/test_noganache/test_arg_pair.py +++ b/pdr_backend/cli/test_noganache/test_arg_pair.py @@ -4,6 +4,7 @@ from pdr_backend.cli.arg_pair import ( ArgPair, ArgPairs, + _unpack_pairs_str, _verify_base_str, _verify_quote_str, ) @@ -19,6 +20,9 @@ def test_unpack_pair_str(): @enforce_types def test_unpack_pairs_str(): + with pytest.raises(ValueError): + _unpack_pairs_str("") + assert ArgPairs.from_str("ADA-USDT BTC/USDT") == ["ADA/USDT", "BTC/USDT"] assert ArgPairs.from_str("ADA/USDT,BTC/USDT") == ["ADA/USDT", "BTC/USDT"] assert ArgPairs.from_str("ADA/USDT, BTC/USDT") == ["ADA/USDT", "BTC/USDT"] @@ -53,6 +57,9 @@ def test_pack_pair_str_list(): with pytest.raises(ValueError): ArgPairs(["ADA-USDT fgds"]) + pair_from_base_and_quote = ArgPair(base_str="BTC", quote_str="USDT") + assert str(ArgPair(pair_from_base_and_quote)) == "BTC/USDT" + @enforce_types def test_verify_pairs_str__and__verify_pair_str(): diff --git a/pdr_backend/cli/test_noganache/test_cli_arguments.py b/pdr_backend/cli/test_noganache/test_cli_arguments.py new file mode 100644 index 000000000..bd0bbc4f9 --- /dev/null +++ b/pdr_backend/cli/test_noganache/test_cli_arguments.py @@ -0,0 +1,41 @@ +import pytest + +from pdr_backend.cli.cli_arguments import ( + CustomArgParser, + defined_parsers, + do_help_long, + get_arg_parser, + print_args, +) + + +def test_arg_parser(): + for arg in defined_parsers: + parser = get_arg_parser(arg) + assert isinstance(parser, CustomArgParser) + + with pytest.raises(ValueError): + get_arg_parser("xyz") + + +def test_do_help_long(capfd): + with pytest.raises(SystemExit): + do_help_long() + + out, _ = capfd.readouterr() + assert "Predictoor tool" in out + assert "Main tools:" in out + + +def test_print_args(capfd): + SimArgParser = defined_parsers["do_sim"] + parser = SimArgParser + args = ["sim", "ppss.yaml"] + parsed_args = parser.parse_args(args) + + print_args(parsed_args) + + out, _ = capfd.readouterr() + assert "dftool sim: Begin" in out + assert "Arguments:" in out + assert "PPSS_FILE=ppss.yaml" in out diff --git a/pdr_backend/cli/test_noganache/test_cli_module.py b/pdr_backend/cli/test_noganache/test_cli_module.py index 9709f3ccd..c8220d13b 100644 --- a/pdr_backend/cli/test_noganache/test_cli_module.py +++ b/pdr_backend/cli/test_noganache/test_cli_module.py @@ -1,10 +1,12 @@ import os from argparse import Namespace -from unittest.mock import Mock +from unittest.mock import Mock, patch +import pytest from enforce_typing import enforce_types from pdr_backend.cli.cli_module import ( + _do_main, do_check_network, do_claim_OCEAN, do_claim_ROSE, @@ -12,8 +14,10 @@ do_get_predictions_info, do_get_predictoors_info, do_get_traction_info, + do_lake, do_predictoor, do_publisher, + do_sim, do_topup, do_trader, do_trueval, @@ -25,6 +29,18 @@ class _APPROACH: APPROACH = 1 +class _APPROACH2: + APPROACH = 2 + + +class _APPROACH3: + APPROACH = 3 + + +class _APPROACH_BAD: + APPROACH = 99 + + class _PPSS: PPSS_FILE = os.path.abspath("ppss.yaml") @@ -79,8 +95,12 @@ class MockArgs(Namespace, _PPSS, _NETWORK): class MockArgParser_APPROACH_PPSS_NETWORK(_Base): + def __init__(self, approach=_APPROACH): + self.approach = approach + super().__init__() + def parse_args(self): - class MockArgs(Namespace, _APPROACH, _PPSS, _NETWORK): + class MockArgs(Namespace, self.approach, _PPSS, _NETWORK): pass return MockArgs() @@ -137,6 +157,15 @@ def test_do_check_network(monkeypatch): mock_f.assert_called() +@enforce_types +def test_do_lake(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.OhlcvDataFactory.get_mergedohlcv_df", mock_f) + + do_lake(MockArgParser_PPSS_NETWORK().parse_args()) + mock_f.assert_called() + + @enforce_types def test_do_claim_OCEAN(monkeypatch): mock_f = Mock() @@ -199,6 +228,14 @@ def test_do_predictoor(monkeypatch): do_predictoor(MockArgParser_APPROACH_PPSS_NETWORK().parse_args()) assert MockAgent.was_run + monkeypatch.setattr(f"{_CLI_PATH}.PredictoorAgent3", MockAgent) + + do_predictoor(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH3).parse_args()) + assert MockAgent.was_run + + with pytest.raises(ValueError): + do_predictoor(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH_BAD).parse_args()) + @enforce_types def test_do_publisher(monkeypatch): @@ -225,6 +262,14 @@ def test_do_trader(monkeypatch): do_trader(MockArgParser_APPROACH_PPSS_NETWORK().parse_args()) assert MockAgent.was_run + monkeypatch.setattr(f"{_CLI_PATH}.TraderAgent2", MockAgent) + + do_trader(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH2).parse_args()) + assert MockAgent.was_run + + with pytest.raises(ValueError): + do_trader(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH_BAD).parse_args()) + @enforce_types def test_do_trueval(monkeypatch): @@ -232,3 +277,34 @@ def test_do_trueval(monkeypatch): do_trueval(MockArgParser_PPSS_NETWORK().parse_args()) assert MockAgent.was_run + + +@enforce_types +def test_do_sim(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f) + + do_sim(MockArgParser_PPSS_NETWORK().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_main(monkeypatch, capfd): + with patch("sys.argv", ["dftool", "help"]): + with pytest.raises(SystemExit): + _do_main() + + assert "Predictoor tool" in capfd.readouterr().out + + with patch("sys.argv", ["dftool", "undefined_function"]): + with pytest.raises(SystemExit): + _do_main() + + assert "Predictoor tool" in capfd.readouterr().out + + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f) + with patch("sys.argv", ["dftool", "sim", "ppss.yaml"]): + _do_main() + + assert mock_f.called diff --git a/pdr_backend/cli/test_noganache/test_timeframe.py b/pdr_backend/cli/test_noganache/test_timeframe.py index ab6dbc78f..5ed01c5ad 100644 --- a/pdr_backend/cli/test_noganache/test_timeframe.py +++ b/pdr_backend/cli/test_noganache/test_timeframe.py @@ -36,6 +36,13 @@ def test_timeframe_class_bad(): with pytest.raises(ValueError): Timeframe("foo") + t = Timeframe("1h") + # forcefully change the model + t.timeframe_str = "BAD" + + with pytest.raises(ValueError): + _ = t.m + @enforce_types def test_pack_timeframe_str_list(): @@ -45,6 +52,9 @@ def test_pack_timeframe_str_list(): assert str(Timeframes.from_str("1h,5m")) == "1h,5m" + with pytest.raises(TypeError): + Timeframes.from_str(None) + with pytest.raises(TypeError): Timeframes("") diff --git a/pdr_backend/conftest_ganache.py b/pdr_backend/conftest_ganache.py index 656525aec..76e925852 100644 --- a/pdr_backend/conftest_ganache.py +++ b/pdr_backend/conftest_ganache.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock + import pytest from pdr_backend.contract.predictoor_batcher import PredictoorBatcher @@ -97,6 +99,30 @@ def predictoor_contract2(): return PredictoorContract(w3p, dt_addr) +@pytest.fixture(scope="module") # "module" = invoke once per test module +def predictoor_contract_empty(): + w3p = _web3_pp() + w3c = w3p.web3_config + _, _, _, _, logs = publish_asset( + s_per_epoch=S_PER_EPOCH, + s_per_subscription=S_PER_EPOCH * 24, + base="ETH", + quote="USDT", + source="kraken", + timeframe="5m", + trueval_submitter_addr=w3c.owner, + feeCollector_addr=w3c.owner, + rate=3, + cut=0.2, + web3_pp=w3p, + ) + dt_addr = logs["newTokenAddress"] + predictoor_c = PredictoorContract(w3p, dt_addr) + predictoor_c.get_exchanges = Mock(return_value=[]) + + return predictoor_c + + # pylint: disable=redefined-outer-name @pytest.fixture(scope="module") def predictoor_batcher(): diff --git a/pdr_backend/contract/base_contract.py b/pdr_backend/contract/base_contract.py index 8d2d980a2..454b29c29 100644 --- a/pdr_backend/contract/base_contract.py +++ b/pdr_backend/contract/base_contract.py @@ -1,6 +1,7 @@ from abc import ABC from enforce_typing import enforce_types +from sapphirepy import wrapper @enforce_types @@ -22,3 +23,34 @@ def __init__(self, web3_pp, address: str, contract_name: str): address=self.config.w3.to_checksum_address(address), abi=get_contract_abi(contract_name, web3_pp.address_file), ) + + def send_encrypted_tx( + self, + function_name, + args, + sender=None, + receiver=None, + pk=None, + value=0, # in wei + gasLimit=10000000, + gasCost=0, # in wei + nonce=0, + ) -> tuple: + sender = self.config.owner if sender is None else sender + receiver = self.contract_instance.address if receiver is None else receiver + pk = self.config.account.key.hex()[2:] if pk is None else pk + + data = self.contract_instance.encodeABI(fn_name=function_name, args=args) + rpc_url = self.config.rpc_url + + return wrapper.send_encrypted_sapphire_tx( + pk, + sender, + receiver, + rpc_url, + value, + gasLimit, + data, + gasCost, + nonce, + ) diff --git a/pdr_backend/contract/data_nft.py b/pdr_backend/contract/data_nft.py index d6ff7e38b..3d7029b6b 100644 --- a/pdr_backend/contract/data_nft.py +++ b/pdr_backend/contract/data_nft.py @@ -7,7 +7,6 @@ from web3.types import HexBytes, TxReceipt from pdr_backend.contract.base_contract import BaseContract -from pdr_backend.util.networkutil import tx_call_params @enforce_types @@ -20,7 +19,7 @@ def set_data(self, field_label, field_value, wait_for_receipt=True): field_label_hash = Web3.keccak(text=field_label) # to keccak256 hash field_value_bytes = field_value.encode() # to array of bytes - call_params = tx_call_params(self.web3_pp, gas=100000) + call_params = self.web3_pp.tx_call_params(gas=100000) tx = self.contract_instance.functions.setNewData( field_label_hash, field_value_bytes ).transact(call_params) @@ -29,7 +28,7 @@ def set_data(self, field_label, field_value, wait_for_receipt=True): return tx def add_erc20_deployer(self, address, wait_for_receipt=True): - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.addToCreateERC20List( self.config.w3.to_checksum_address(address) ).transact(call_params) @@ -40,7 +39,7 @@ def add_erc20_deployer(self, address, wait_for_receipt=True): def set_ddo(self, ddo, wait_for_receipt=True): js = json.dumps(ddo) stored_ddo = Web3.to_bytes(text=js) - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.setMetaData( 1, "", @@ -57,10 +56,12 @@ def set_ddo(self, ddo, wait_for_receipt=True): def add_to_create_erc20_list( self, addr: str, wait_for_receipt=True ) -> Union[HexBytes, TxReceipt]: - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.addToCreateERC20List(addr).transact( call_params ) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) diff --git a/pdr_backend/contract/dfrewards.py b/pdr_backend/contract/dfrewards.py index be9c00fe7..4878f847a 100644 --- a/pdr_backend/contract/dfrewards.py +++ b/pdr_backend/contract/dfrewards.py @@ -2,7 +2,6 @@ from pdr_backend.contract.base_contract import BaseContract from pdr_backend.util.mathutil import from_wei -from pdr_backend.util.networkutil import tx_call_params @enforce_types @@ -11,12 +10,14 @@ def __init__(self, web3_pp, address: str): super().__init__(web3_pp, address, "DFRewards") def claim_rewards(self, user_addr: str, token_addr: str, wait_for_receipt=True): - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.claimFor(user_addr, token_addr).transact( call_params ) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) def get_claimable_rewards(self, user_addr: str, token_addr: str) -> float: diff --git a/pdr_backend/contract/erc721_factory.py b/pdr_backend/contract/erc721_factory.py index 54b213026..b26d05870 100644 --- a/pdr_backend/contract/erc721_factory.py +++ b/pdr_backend/contract/erc721_factory.py @@ -3,25 +3,28 @@ from pdr_backend.contract.base_contract import BaseContract from pdr_backend.util.contract import get_address -from pdr_backend.util.networkutil import tx_call_params @enforce_types class Erc721Factory(BaseContract): def __init__(self, web3_pp): address = get_address(web3_pp, "ERC721Factory") + if not address: raise ValueError("Cannot figure out Erc721Factory address") + super().__init__(web3_pp, address, "ERC721Factory") def createNftWithErc20WithFixedRate(self, NftCreateData, ErcCreateData, FixedData): - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.createNftWithErc20WithFixedRate( NftCreateData, ErcCreateData, FixedData ).transact(call_params) receipt = self.config.w3.eth.wait_for_transaction_receipt(tx) + if receipt["status"] != 1: raise ValueError(f"createNftWithErc20WithFixedRate failed in {tx.hex()}") + # print(receipt) logs_nft = self.contract_instance.events.NFTCreated().process_receipt( receipt, errors=DISCARD diff --git a/pdr_backend/contract/predictoor_batcher.py b/pdr_backend/contract/predictoor_batcher.py index 8259e41a3..385c2eeb5 100644 --- a/pdr_backend/contract/predictoor_batcher.py +++ b/pdr_backend/contract/predictoor_batcher.py @@ -5,7 +5,6 @@ from pdr_backend.contract.base_contract import BaseContract from pdr_backend.ppss.web3_pp import Web3PP -from pdr_backend.util.networkutil import tx_call_params class PredictoorBatcher(BaseContract): @@ -29,12 +28,14 @@ def consume_multiple( token_addr: str, wait_for_receipt=True, ): - call_params = tx_call_params(self.web3_pp, gas=14_000_000) + call_params = self.web3_pp.tx_call_params(gas=14_000_000) tx = self.contract_instance.functions.consumeMultiple( addresses, times, token_addr ).transact(call_params) + if not wait_for_receipt: return tx + return self.w3.eth.wait_for_transaction_receipt(tx) @enforce_types @@ -46,12 +47,14 @@ def submit_truevals_contracts( cancelRounds: List[List[bool]], wait_for_receipt=True, ): - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.submitTruevalContracts( contract_addrs, epoch_starts, trueVals, cancelRounds ).transact(call_params) + if not wait_for_receipt: return tx + return self.w3.eth.wait_for_transaction_receipt(tx) @enforce_types @@ -63,12 +66,14 @@ def submit_truevals( cancelRounds: List[bool], wait_for_receipt=True, ): - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.submitTruevals( contract_addr, epoch_starts, trueVals, cancelRounds ).transact(call_params) + if not wait_for_receipt: return tx + return self.w3.eth.wait_for_transaction_receipt(tx) diff --git a/pdr_backend/contract/predictoor_contract.py b/pdr_backend/contract/predictoor_contract.py index 3a70b1d39..00599c4d6 100644 --- a/pdr_backend/contract/predictoor_contract.py +++ b/pdr_backend/contract/predictoor_contract.py @@ -8,12 +8,6 @@ from pdr_backend.contract.token import Token from pdr_backend.util.constants import MAX_UINT, ZERO_ADDRESS from pdr_backend.util.mathutil import from_wei, string_to_bytes32, to_wei -from pdr_backend.util.networkutil import ( - get_max_gas, - is_sapphire_network, - send_encrypted_tx, - tx_call_params, -) @enforce_types @@ -59,7 +53,7 @@ def buy_and_start_subscription(self, gasLimit=None, wait_for_receipt=True): print(" Approve spend OCEAN: done") # buy 1 DT - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() orderParams = ( # OrderParams self.config.owner, # consumer 0, # serviceIndex @@ -94,9 +88,11 @@ def buy_and_start_subscription(self, gasLimit=None, wait_for_receipt=True): orderParams, freParams ).estimate_gas(call_params) except Exception as e: - print(f" Estimate gasLimit had error in estimate_gas(): {e}") - print(" Because of error, use get_max_gas() as workaround") - gasLimit = get_max_gas(self.config) + print( + f" Estimate gasLimit had error in estimate_gas(): {e}" + " Because of error, use get_max_gas() as workaround" + ) + gasLimit = self.config.get_max_gas() assert gasLimit is not None, "should have non-None gasLimit by now" print(f" Estimate gasLimit: done. gasLimit={gasLimit}") call_params["gas"] = gasLimit + 1 @@ -113,8 +109,7 @@ def buy_and_start_subscription(self, gasLimit=None, wait_for_receipt=True): print(" buyFromFreAndOrder: waited around, it's done") return tx except Exception as e: - print(" buyFromFreAndOrder hit an error:") - print(e) + print(f" buyFromFreAndOrder hit an error: {e}") return None def buy_many(self, n_to_buy: int, gasLimit=None, wait_for_receipt=False): @@ -195,7 +190,7 @@ def get_agg_predval(self, timestamp: int) -> Tuple[float, float]: denom - denominator = total # OCEAN staked ("") """ auth = self.config.get_auth_signature() - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() (nom_wei, denom_wei) = self.contract_instance.functions.getAggPredval( timestamp, auth ).call(call_params) @@ -203,13 +198,15 @@ def get_agg_predval(self, timestamp: int) -> Tuple[float, float]: def payout_multiple(self, slots: List[int], wait_for_receipt: bool = True): """Claims the payout for given slots""" - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() try: tx = self.contract_instance.functions.payoutMultiple( slots, self.config.owner ).transact(call_params) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) except Exception as e: print(e) @@ -217,13 +214,15 @@ def payout_multiple(self, slots: List[int], wait_for_receipt: bool = True): def payout(self, slot, wait_for_receipt=False): """Claims the payout for one slot""" - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() try: tx = self.contract_instance.functions.payout( slot, self.config.owner ).transact(call_params) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) except Exception as e: print(e) @@ -273,29 +272,12 @@ def submit_prediction( print("Error while approving the contract to spend tokens:", e) return None - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() try: txhash = None - if is_sapphire_network(self.config.w3.eth.chain_id): - self.contract_instance.encodeABI( - fn_name="submitPredval", - args=[predicted_value, stake_amt_wei, prediction_ts], - ) - sender = self.config.owner - receiver = self.contract_instance.address - pk = self.config.account.key.hex()[2:] - res, txhash = send_encrypted_tx( - self.contract_instance, - "submitPredval", - [predicted_value, stake_amt_wei, prediction_ts], - pk, - sender, - receiver, - self.config.rpc_url, - 0, - 1000000, - 0, - 0, + if self.config.is_sapphire: + res, txhash = self.send_encrypted_tx( + "submitPredval", [predicted_value, stake_amt_wei, prediction_ts] ) print("Encrypted transaction status code:", res) else: @@ -305,8 +287,10 @@ def submit_prediction( txhash = tx.hex() self.last_allowance -= stake_amt_wei print(f"Submitted prediction, txhash: {txhash}") + if not wait_for_receipt: return txhash + return self.config.w3.eth.wait_for_transaction_receipt(txhash) except Exception as e: print(e) @@ -331,25 +315,28 @@ def submit_trueval(self, trueval, timestamp, cancel_round, wait_for_receipt=True Can only be called by the owner. Returns the hash of the transaction. """ - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.submitTrueVal( timestamp, trueval, cancel_round ).transact(call_params) print(f"Submit trueval: txhash={tx.hex()}") - if not wait_for_receipt: - return tx - tx = self.config.w3.eth.wait_for_transaction_receipt(tx) + + if wait_for_receipt: + tx = self.config.w3.eth.wait_for_transaction_receipt(tx) + return tx def redeem_unused_slot_revenue(self, timestamp, wait_for_receipt=True): """Redeem unused slot revenue.""" - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() try: tx = self.contract_instance.functions.redeemUnusedSlotRevenue( timestamp ).transact(call_params) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) except Exception as e: print(e) diff --git a/pdr_backend/contract/test/test_base_contract.py b/pdr_backend/contract/test/test_base_contract.py index c4a177bb1..4b8ece71c 100644 --- a/pdr_backend/contract/test/test_base_contract.py +++ b/pdr_backend/contract/test/test_base_contract.py @@ -1,3 +1,6 @@ +import os +from unittest.mock import Mock + import pytest from enforce_typing import enforce_types @@ -5,6 +8,13 @@ from pdr_backend.util.contract import get_address +@pytest.fixture +def mock_send_encrypted_sapphire_tx(monkeypatch): + mock_function = Mock(return_value=(0, "dummy_tx_hash")) + monkeypatch.setattr("sapphirepy.wrapper.send_encrypted_sapphire_tx", mock_function) + return mock_function + + @enforce_types def test_base_contract(web3_pp, web3_config): OCEAN_address = get_address(web3_pp, "Ocean") @@ -16,3 +26,57 @@ def test_base_contract(web3_pp, web3_config): web3_config = web3_pp.web3_config with pytest.raises(ValueError): Token(web3_config, OCEAN_address) + + +@enforce_types +def test_send_encrypted_tx( + mock_send_encrypted_sapphire_tx, # pylint: disable=redefined-outer-name + ocean_token, + web3_pp, +): + OCEAN_address = get_address(web3_pp, "Ocean") + contract = Token(web3_pp, OCEAN_address) + + # Set up dummy return value for the mocked function + mock_send_encrypted_sapphire_tx.return_value = ( + 0, + "dummy_tx_hash", + ) + + # Sample inputs for send_encrypted_tx + function_name = "transfer" + args = [web3_pp.web3_config.owner, 100] + sender = web3_pp.web3_config.owner + receiver = web3_pp.web3_config.w3.eth.accounts[1] + rpc_url = "http://localhost:8545" + value = 0 + gasLimit = 10000000 + gasCost = 0 + nonce = 0 + pk = os.getenv("PRIVATE_KEY") + + tx_hash, encrypted_data = contract.send_encrypted_tx( + function_name, + args, + sender, + receiver, + pk, + value, + gasLimit, + gasCost, + nonce, + ) + assert tx_hash == 0 + assert encrypted_data == "dummy_tx_hash" + + mock_send_encrypted_sapphire_tx.assert_called_once_with( + pk, + sender, + receiver, + rpc_url, + value, + gasLimit, + ocean_token.contract_instance.encodeABI(fn_name=function_name, args=args), + gasCost, + nonce, + ) diff --git a/pdr_backend/contract/test/test_erc721_factory.py b/pdr_backend/contract/test/test_erc721_factory.py index f366b0cb0..183b8c206 100644 --- a/pdr_backend/contract/test/test_erc721_factory.py +++ b/pdr_backend/contract/test/test_erc721_factory.py @@ -1,3 +1,6 @@ +from unittest.mock import Mock, patch + +import pytest from enforce_typing import enforce_types from pdr_backend.contract.erc721_factory import Erc721Factory @@ -53,3 +56,19 @@ def test_Erc721Factory(web3_pp, web3_config): assert len(logs_nft) > 0 assert len(logs_erc) > 0 + + config = Mock() + receipt = {"status": 0} + config.w3.eth.wait_for_transaction_receipt.return_value = receipt + + with patch.object(factory, "config") as mock_config: + mock_config.return_value = config + with pytest.raises(ValueError): + factory.createNftWithErc20WithFixedRate(nft_data, erc_data, fre_data) + + +@enforce_types +def test_Erc721Factory_no_address(web3_pp): + with patch("pdr_backend.contract.erc721_factory.get_address", return_value=None): + with pytest.raises(ValueError): + Erc721Factory(web3_pp) diff --git a/pdr_backend/contract/test/test_predictoor_contract.py b/pdr_backend/contract/test/test_predictoor_contract.py index d93b0b27f..eed413bae 100644 --- a/pdr_backend/contract/test/test_predictoor_contract.py +++ b/pdr_backend/contract/test/test_predictoor_contract.py @@ -1,3 +1,6 @@ +from unittest.mock import Mock + +import pytest from enforce_typing import enforce_types from pytest import approx @@ -28,11 +31,19 @@ def test_buy_and_start_subscription(predictoor_contract): assert is_valid_sub +@enforce_types +def test_buy_and_start_subscription_empty(predictoor_contract_empty): + with pytest.raises(ValueError): + assert predictoor_contract_empty.buy_and_start_subscription() + + @enforce_types def test_buy_many(predictoor_contract): receipts = predictoor_contract.buy_many(2, None, True) assert len(receipts) == 2 + assert predictoor_contract.buy_many(0, None, True) is None + @enforce_types def test_get_exchanges(predictoor_contract): @@ -53,6 +64,13 @@ def test_get_price(predictoor_contract): assert price / 1e18 == approx(3.603) +@enforce_types +def test_get_price_no_exchanges(predictoor_contract_empty): + predictoor_contract_empty.get_exchanges = Mock(return_value=[]) + with pytest.raises(ValueError): + predictoor_contract_empty.get_price() + + @enforce_types def test_get_current_epoch(predictoor_contract): current_epoch = predictoor_contract.get_current_epoch() diff --git a/pdr_backend/contract/test/test_token.py b/pdr_backend/contract/test/test_token.py index 11989516d..bd541e0de 100644 --- a/pdr_backend/contract/test/test_token.py +++ b/pdr_backend/contract/test/test_token.py @@ -1,10 +1,10 @@ import time +from unittest.mock import patch from enforce_typing import enforce_types -from pdr_backend.contract.token import Token +from pdr_backend.contract.token import NativeToken, Token from pdr_backend.util.contract import get_address -from pdr_backend.util.networkutil import tx_call_params @enforce_types @@ -16,7 +16,7 @@ def test_token(web3_pp, web3_config): owner_addr = web3_config.owner alice = accounts[1] - call_params = tx_call_params(web3_pp) + call_params = web3_pp.tx_call_params() token.contract_instance.functions.mint(owner_addr, 1000000000).transact(call_params) allowance_start = token.allowance(owner_addr, alice) @@ -29,3 +29,16 @@ def test_token(web3_pp, web3_config): token.transfer(alice, 100, owner_addr) balance_end = token.balanceOf(alice) assert balance_end - balance_start == 100 + + +@enforce_types +def test_native_token(web3_pp): + token = NativeToken(web3_pp) + assert token.w3 + + owner = web3_pp.web3_config.owner + assert token.balanceOf(owner) + + with patch("web3.eth.Eth.send_transaction") as mock: + token.transfer(owner, 100, "0x123", False) + assert mock.called diff --git a/pdr_backend/contract/test/test_wrapped_token.py b/pdr_backend/contract/test/test_wrapped_token.py new file mode 100644 index 000000000..0d8b94379 --- /dev/null +++ b/pdr_backend/contract/test/test_wrapped_token.py @@ -0,0 +1,21 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.contract.wrapped_token import WrappedToken +from pdr_backend.util.contract import get_address + + +@enforce_types +def test_native_token(web3_pp): + token_address = get_address(web3_pp, "Ocean") + mock_wrapped_contract = Mock() + mock_transaction = Mock() + mock_transaction.transact.return_value = "mock_tx" + mock_wrapped_contract.functions.withdraw.return_value = mock_transaction + + with patch("web3.eth.Eth.contract") as mock: + mock.return_value = mock_wrapped_contract + token = WrappedToken(web3_pp, token_address) + + assert token.withdraw(100, False) == "mock_tx" diff --git a/pdr_backend/contract/token.py b/pdr_backend/contract/token.py index 172ffc763..acec042cf 100644 --- a/pdr_backend/contract/token.py +++ b/pdr_backend/contract/token.py @@ -2,7 +2,6 @@ from web3.types import TxParams, Wei from pdr_backend.contract.base_contract import BaseContract -from pdr_backend.util.networkutil import tx_call_params, tx_gas_price @enforce_types @@ -17,7 +16,7 @@ def balanceOf(self, account): return self.contract_instance.functions.balanceOf(account).call() def transfer(self, to: str, amount: int, sender, wait_for_receipt=True): - gas_price = tx_gas_price(self.web3_pp) + gas_price = self.web3_pp.tx_gas_price() call_params = {"from": sender, "gasPrice": gas_price} tx = self.contract_instance.functions.transfer(to, int(amount)).transact( call_params @@ -25,16 +24,19 @@ def transfer(self, to: str, amount: int, sender, wait_for_receipt=True): if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) def approve(self, spender, amount, wait_for_receipt=True): - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() # print(f"Approving {amount} for {spender} on contract {self.contract_address}") tx = self.contract_instance.functions.approve(spender, amount).transact( call_params ) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) @@ -53,7 +55,7 @@ def balanceOf(self, account): @enforce_types def transfer(self, to: str, amount: int, sender, wait_for_receipt=True): - gas_price = tx_gas_price(self.web3_pp) + gas_price = self.web3_pp.tx_gas_price() call_params: TxParams = { "from": sender, "gas": 25000, diff --git a/pdr_backend/contract/wrapped_token.py b/pdr_backend/contract/wrapped_token.py index b0d38e168..c06b62a7a 100644 --- a/pdr_backend/contract/wrapped_token.py +++ b/pdr_backend/contract/wrapped_token.py @@ -1,5 +1,4 @@ from pdr_backend.contract.token import Token -from pdr_backend.util.networkutil import tx_call_params class WrappedToken(Token): @@ -24,7 +23,7 @@ def withdraw(self, amount: int, wait_for_receipt=True): """ Converts Wrapped Token to Token, amount is in wei. """ - call_params = tx_call_params(self.web3_pp) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance_wrapped.functions.withdraw(amount).transact( call_params ) diff --git a/pdr_backend/ppss/test/test_web3_pp.py b/pdr_backend/ppss/test/test_web3_pp.py index f22173d02..25af89305 100644 --- a/pdr_backend/ppss/test/test_web3_pp.py +++ b/pdr_backend/ppss/test/test_web3_pp.py @@ -205,3 +205,32 @@ def test_inplace_mocks(monkeypatch): c = mock_predictoor_contract(feed.address) inplace_mock_get_contracts(web3_pp, feed, c) + + +@enforce_types +def test_tx_gas_price__and__tx_call_params(): + web3_pp = mock_web3_pp("sapphire-testnet") + eth_mock = Mock() + eth_mock.gas_price = 12 + web3_pp.web3_config.w3.eth = eth_mock + web3_pp.web3_config.owner = "0xowner" + + web3_pp.network = "sapphire-testnet" + assert web3_pp.tx_gas_price() == 12 + assert web3_pp.tx_call_params() == {"from": "0xowner", "gasPrice": 12} + + web3_pp.network = "sapphire-mainnet" + assert web3_pp.tx_gas_price() == 12 + + web3_pp.network = "development" + assert web3_pp.tx_gas_price() == 0 + assert web3_pp.tx_call_params() == {"from": "0xowner", "gasPrice": 0} + + web3_pp.network = "barge-pytest" + assert web3_pp.tx_gas_price() == 0 + + web3_pp.network = "foo" + with pytest.raises(ValueError): + web3_pp.tx_gas_price() + with pytest.raises(ValueError): + web3_pp.tx_call_params() diff --git a/pdr_backend/ppss/web3_pp.py b/pdr_backend/ppss/web3_pp.py index be32fb7de..aeef155ed 100644 --- a/pdr_backend/ppss/web3_pp.py +++ b/pdr_backend/ppss/web3_pp.py @@ -144,6 +144,27 @@ def get_pending_slots( allowed_feeds=allowed_feeds, ) + @enforce_types + def tx_call_params(self, gas=None) -> dict: + call_params = { + "from": self.web3_config.owner, + "gasPrice": self.tx_gas_price(), + } + if gas is not None: + call_params["gas"] = gas + return call_params + + @enforce_types + def tx_gas_price(self) -> int: + """Return gas price for use in call_params of transaction calls.""" + network = self.network + if network in ["sapphire-testnet", "sapphire-mainnet"]: + return self.web3_config.w3.eth.gas_price + # return 100000000000 + if network in ["development", "barge-predictoor-bot", "barge-pytest"]: + return 0 + raise ValueError(f"Unknown network {network}") + # ========================================================================= # utilities for testing diff --git a/pdr_backend/util/csvs.py b/pdr_backend/util/csvs.py index e4267aed5..5b4e6f31e 100644 --- a/pdr_backend/util/csvs.py +++ b/pdr_backend/util/csvs.py @@ -45,7 +45,12 @@ def check_and_create_dir(dir_path: str): @enforce_types -def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str): +def _save_prediction_csv( + all_predictions: List[Prediction], + csv_output_dir: str, + headers: List, + attribute_names: List, +): check_and_create_dir(csv_output_dir) data = generate_prediction_data_structure(all_predictions) @@ -56,58 +61,43 @@ def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str): with open(filename, "w", newline="") as file: writer = csv.writer(file) - writer.writerow( - ["Predicted Value", "True Value", "Timestamp", "Stake", "Payout"] - ) + writer.writerow(headers) for prediction in predictions: writer.writerow( [ - prediction.prediction, - prediction.trueval, - prediction.timestamp, - prediction.stake, - prediction.payout, + getattr(prediction, attribute_name) + for attribute_name in attribute_names ] ) + print(f"CSV file '{filename}' created successfully.") @enforce_types -def save_analysis_csv(all_predictions: List[Prediction], csv_output_dir: str): - check_and_create_dir(csv_output_dir) - - data = generate_prediction_data_structure(all_predictions) +def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str): + _save_prediction_csv( + all_predictions, + csv_output_dir, + ["Predicted Value", "True Value", "Timestamp", "Stake", "Payout"], + ["prediction", "trueval", "timestamp", "stake", "payout"], + ) - for key, predictions in data.items(): - predictions.sort(key=lambda x: x.timestamp) - filename = key_csv_filename_with_dir(csv_output_dir, key) - with open(filename, "w", newline="") as file: - writer = csv.writer(file) - writer.writerow( - [ - "PredictionID", - "Timestamp", - "Slot", - "Stake", - "Wallet", - "Payout", - "True Value", - "Predicted Value", - ] - ) - for prediction in predictions: - writer.writerow( - [ - prediction.ID, - prediction.timestamp, - prediction.slot, - prediction.stake, - prediction.user, - prediction.payout, - prediction.trueval, - prediction.prediction, - ] - ) - print(f"CSV file '{filename}' created successfully.") +@enforce_types +def save_analysis_csv(all_predictions: List[Prediction], csv_output_dir: str): + _save_prediction_csv( + all_predictions, + csv_output_dir, + [ + "PredictionID", + "Timestamp", + "Slot", + "Stake", + "Wallet", + "Payout", + "True Value", + "Predicted Value", + ], + ["ID", "timestamp", "slot", "stake", "user", "payout", "trueval", "prediction"], + ) diff --git a/pdr_backend/util/networkutil.py b/pdr_backend/util/networkutil.py index e6b155c99..75c68c428 100644 --- a/pdr_backend/util/networkutil.py +++ b/pdr_backend/util/networkutil.py @@ -1,15 +1,4 @@ from enforce_typing import enforce_types -from sapphirepy import wrapper - -from pdr_backend.util.constants import ( - SAPPHIRE_MAINNET_CHAINID, - SAPPHIRE_TESTNET_CHAINID, -) - - -@enforce_types -def is_sapphire_network(chain_id: int) -> bool: - return chain_id in [SAPPHIRE_TESTNET_CHAINID, SAPPHIRE_MAINNET_CHAINID] @enforce_types @@ -22,34 +11,6 @@ def get_sapphire_postfix(network: str) -> str: raise ValueError(f"'{network}' is not valid name") -@enforce_types -def send_encrypted_tx( - contract_instance, - function_name, - args, - pk, - sender, - receiver, - rpc_url, - value=0, # in wei - gasLimit=10000000, - gasCost=0, # in wei - nonce=0, -) -> tuple: - data = contract_instance.encodeABI(fn_name=function_name, args=args) - return wrapper.send_encrypted_sapphire_tx( - pk, - sender, - receiver, - rpc_url, - value, - gasLimit, - data, - gasCost, - nonce, - ) - - @enforce_types def get_subgraph_url(network: str) -> str: """ @@ -68,35 +29,3 @@ def get_subgraph_url(network: str) -> str: # pylint: disable=line-too-long return f"https://v4.subgraph.sapphire-{network}.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph" - - -@enforce_types -def tx_gas_price(web3_pp) -> int: - """Return gas price for use in call_params of transaction calls.""" - network = web3_pp.network - if network in ["sapphire-testnet", "sapphire-mainnet"]: - return web3_pp.web3_config.w3.eth.gas_price - # return 100000000000 - if network in ["development", "barge-predictoor-bot", "barge-pytest"]: - return 0 - raise ValueError(f"Unknown network {network}") - - -@enforce_types -def tx_call_params(web3_pp, gas=None) -> dict: - call_params = { - "from": web3_pp.web3_config.owner, - "gasPrice": tx_gas_price(web3_pp), - } - if gas is not None: - call_params["gas"] = gas - return call_params - - -@enforce_types -def get_max_gas(web3_config) -> int: - """Returns max block gas""" - block = web3_config.get_block( - web3_config.w3.eth.block_number, full_transactions=False - ) - return int(block["gasLimit"] * 0.99) diff --git a/pdr_backend/util/test_ganache/test_networkutil.py b/pdr_backend/util/test_ganache/test_networkutil.py index af680f7f0..8cbf79d68 100644 --- a/pdr_backend/util/test_ganache/test_networkutil.py +++ b/pdr_backend/util/test_ganache/test_networkutil.py @@ -1,30 +1,9 @@ -import os -from unittest.mock import Mock - import pytest from enforce_typing import enforce_types -from pdr_backend.util.constants import ( - MAX_UINT, - SAPPHIRE_MAINNET_CHAINID, - SAPPHIRE_TESTNET_CHAINID, -) from pdr_backend.util.networkutil import ( - get_max_gas, get_sapphire_postfix, - is_sapphire_network, - send_encrypted_tx, - tx_call_params, - tx_gas_price, ) -from pdr_backend.util.web3_config import Web3Config - - -@enforce_types -def test_is_sapphire_network(): - assert not is_sapphire_network(0) - assert is_sapphire_network(SAPPHIRE_TESTNET_CHAINID) - assert is_sapphire_network(SAPPHIRE_MAINNET_CHAINID) @enforce_types @@ -44,101 +23,3 @@ def test_get_sapphire_postfix(): for unwanted in unwanteds: with pytest.raises(ValueError): assert get_sapphire_postfix(unwanted) - - -@enforce_types -def test_send_encrypted_tx( - mock_send_encrypted_sapphire_tx, # pylint: disable=redefined-outer-name - ocean_token, - web3_config, -): - # Set up dummy return value for the mocked function - mock_send_encrypted_sapphire_tx.return_value = ( - 0, - "dummy_tx_hash", - ) - # Sample inputs for send_encrypted_tx - function_name = "transfer" - args = [web3_config.owner, 100] - pk = os.getenv("PRIVATE_KEY") - sender = web3_config.owner - receiver = web3_config.w3.eth.accounts[1] - rpc_url = "http://localhost:8545" - value = 0 - gasLimit = 10000000 - gasCost = 0 - nonce = 0 - tx_hash, encrypted_data = send_encrypted_tx( - ocean_token.contract_instance, - function_name, - args, - pk, - sender, - receiver, - rpc_url, - value, - gasLimit, - gasCost, - nonce, - ) - assert tx_hash == 0 - assert encrypted_data == "dummy_tx_hash" - mock_send_encrypted_sapphire_tx.assert_called_once_with( - pk, - sender, - receiver, - rpc_url, - value, - gasLimit, - ocean_token.contract_instance.encodeABI(fn_name=function_name, args=args), - gasCost, - nonce, - ) - - -@pytest.fixture -def mock_send_encrypted_sapphire_tx(monkeypatch): - mock_function = Mock(return_value=(0, "dummy_tx_hash")) - monkeypatch.setattr("sapphirepy.wrapper.send_encrypted_sapphire_tx", mock_function) - return mock_function - - -@enforce_types -def test_tx_gas_price__and__tx_call_params(): - web3_pp = Mock() - web3_pp.web3_config = Mock() - web3_pp.web3_config.owner = "0xowner" - web3_pp.web3_config.w3 = Mock() - web3_pp.web3_config.w3.eth = Mock() - web3_pp.web3_config.w3.eth.gas_price = 12 - - web3_pp.network = "sapphire-testnet" - assert tx_gas_price(web3_pp) == 12 - assert tx_call_params(web3_pp) == {"from": "0xowner", "gasPrice": 12} - - web3_pp.network = "sapphire-mainnet" - assert tx_gas_price(web3_pp) == 12 - - web3_pp.network = "development" - assert tx_gas_price(web3_pp) == 0 - assert tx_call_params(web3_pp) == {"from": "0xowner", "gasPrice": 0} - - web3_pp.network = "barge-pytest" - assert tx_gas_price(web3_pp) == 0 - - web3_pp.network = "foo" - with pytest.raises(ValueError): - tx_gas_price(web3_pp) - with pytest.raises(ValueError): - tx_call_params(web3_pp) - - -@enforce_types -def test_get_max_gas(rpc_url): - private_key = os.getenv("PRIVATE_KEY") - web3_config = Web3Config(rpc_url=rpc_url, private_key=private_key) - max_gas = get_max_gas(web3_config) - assert 0 < max_gas < MAX_UINT - - target_max_gas = int(web3_config.get_block("latest").gasLimit * 0.99) - assert max_gas == target_max_gas diff --git a/pdr_backend/util/test_ganache/test_web3_config.py b/pdr_backend/util/test_ganache/test_web3_config.py index df0374057..5c4aceb79 100644 --- a/pdr_backend/util/test_ganache/test_web3_config.py +++ b/pdr_backend/util/test_ganache/test_web3_config.py @@ -3,6 +3,7 @@ import pytest from enforce_typing import enforce_types +from pdr_backend.util.constants import MAX_UINT from pdr_backend.util.web3_config import Web3Config @@ -73,3 +74,14 @@ def test_Web3Config_get_auth_signature(rpc_url): # just a super basic test assert sorted(auth.keys()) == sorted(["userAddress", "v", "r", "s", "validUntil"]) + + +@enforce_types +def test_get_max_gas(rpc_url): + private_key = os.getenv("PRIVATE_KEY") + web3_config = Web3Config(rpc_url=rpc_url, private_key=private_key) + max_gas = web3_config.get_max_gas() + assert 0 < max_gas < MAX_UINT + + target_max_gas = int(web3_config.get_block("latest").gasLimit * 0.99) + assert max_gas == target_max_gas diff --git a/pdr_backend/util/test_noganache/test_csvs.py b/pdr_backend/util/test_noganache/test_csvs.py new file mode 100644 index 000000000..35a5f91e0 --- /dev/null +++ b/pdr_backend/util/test_noganache/test_csvs.py @@ -0,0 +1,58 @@ +import csv +import os + +from pdr_backend.subgraph.prediction import mock_daily_predictions +from pdr_backend.util.csvs import save_analysis_csv, save_prediction_csv + + +def test_save_analysis_csv(tmpdir): + predictions = mock_daily_predictions() + key = ( + predictions[0].pair.replace("/", "-") + + predictions[0].timeframe + + predictions[0].source + ) + save_analysis_csv(predictions, str(tmpdir)) + + with open(os.path.join(str(tmpdir), key + ".csv")) as f: + data = csv.DictReader(f) + data_rows = list(data) + + assert data_rows[0]["Predicted Value"] == str(predictions[0].prediction) + assert data_rows[0]["True Value"] == str(predictions[0].trueval) + assert data_rows[0]["Timestamp"] == str(predictions[0].timestamp) + assert list(data_rows[0].keys()) == [ + "PredictionID", + "Timestamp", + "Slot", + "Stake", + "Wallet", + "Payout", + "True Value", + "Predicted Value", + ] + + +def test_save_prediction_csv(tmpdir): + predictions = mock_daily_predictions() + key = ( + predictions[0].pair.replace("/", "-") + + predictions[0].timeframe + + predictions[0].source + ) + save_prediction_csv(predictions, str(tmpdir)) + + with open(os.path.join(str(tmpdir), key + ".csv")) as f: + data = csv.DictReader(f) + data_rows = list(row for row in data) + + assert data_rows[0]["Predicted Value"] == str(predictions[0].prediction) + assert data_rows[0]["True Value"] == str(predictions[0].trueval) + assert data_rows[0]["Timestamp"] == str(predictions[0].timestamp) + assert list(data_rows[0].keys()) == [ + "Predicted Value", + "True Value", + "Timestamp", + "Stake", + "Payout", + ] diff --git a/pdr_backend/util/web3_config.py b/pdr_backend/util/web3_config.py index 2b89089a2..481b075b3 100644 --- a/pdr_backend/util/web3_config.py +++ b/pdr_backend/util/web3_config.py @@ -13,6 +13,10 @@ from web3.types import BlockData from pdr_backend.util.constants import WEB3_MAX_TRIES +from pdr_backend.util.constants import ( + SAPPHIRE_MAINNET_CHAINID, + SAPPHIRE_TESTNET_CHAINID, +) _KEYS = KeyAPI(NativeECCBackend) @@ -76,3 +80,16 @@ def get_auth_signature(self): "validUntil": valid_until, } return auth + + @property + def is_sapphire(self): + return self.w3.eth.chain_id in [ + SAPPHIRE_TESTNET_CHAINID, + SAPPHIRE_MAINNET_CHAINID, + ] + + @enforce_types + def get_max_gas(self) -> int: + """Returns max block gas""" + block = self.get_block(self.w3.eth.block_number, full_transactions=False) + return int(block["gasLimit"] * 0.99)