From 33f8f4936d3185920fe59b404792c91b614e30ae Mon Sep 17 00:00:00 2001 From: protolambda Date: Fri, 20 Mar 2020 20:38:36 +0100 Subject: [PATCH] Fix base-reward memoization bug, improve memoization with LRU, and improve misc rewards test --- Makefile | 4 ++ setup.py | 40 ++++++++++++++----- tests/core/pyspec/eth2spec/test/context.py | 7 +++- .../test_process_rewards_and_penalties.py | 31 ++++++++------ 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/Makefile b/Makefile index b468e648c6..e8f3d21bc5 100644 --- a/Makefile +++ b/Makefile @@ -77,6 +77,10 @@ test: pyspec . venv/bin/activate; cd $(PY_SPEC_DIR); \ python -m pytest -n 4 --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec +find_test: pyspec + . venv/bin/activate; cd $(PY_SPEC_DIR); \ + python -m pytest -k=$(K) --cov=eth2spec.phase0.spec --cov=eth2spec.phase1.spec --cov-report="html:$(COV_HTML_OUT)" --cov-branch eth2spec + citest: pyspec mkdir -p tests/core/pyspec/test-reports/eth2spec; . venv/bin/activate; cd $(PY_SPEC_DIR); \ python -m pytest -n 4 --junitxml=eth2spec/test_results.xml eth2spec diff --git a/setup.py b/setup.py index 906e302404..9b5a6cabc5 100644 --- a/setup.py +++ b/setup.py @@ -92,6 +92,8 @@ def get_spec(file_name: str) -> SpecObject: field, ) +from lru import LRU + from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_typing import ( View, boolean, Container, List, Vector, uint64, @@ -114,6 +116,8 @@ def get_spec(file_name: str) -> SpecObject: field, ) +from lru import LRU + from eth2spec.utils.ssz.ssz_impl import hash_tree_root from eth2spec.utils.ssz.ssz_typing import ( View, boolean, Container, List, Vector, uint64, uint8, bit, @@ -152,8 +156,8 @@ def hash(x: bytes) -> Bytes32: # type: ignore return hash_cache[x] -def cache_this(key_fn, value_fn): # type: ignore - cache_dict = {} # type: ignore +def cache_this(key_fn, value_fn, lru_size): # type: ignore + cache_dict = LRU(size=lru_size) def wrapper(*args, **kw): # type: ignore key = key_fn(*args, **kw) @@ -164,35 +168,50 @@ def wrapper(*args, **kw): # type: ignore return wrapper +_compute_shuffled_index = compute_shuffled_index +compute_shuffled_index = cache_this( + lambda index, index_count, seed: (index, index_count, seed), + _compute_shuffled_index, lru_size=SLOTS_PER_EPOCH * 3) + +_get_total_active_balance = get_total_active_balance +get_total_active_balance = cache_this( + lambda state: (state.validators.hash_tree_root(), state.slot), + _get_total_active_balance, lru_size=10) + _get_base_reward = get_base_reward get_base_reward = cache_this( - lambda state, index: (state.validators.hash_tree_root(), state.slot), - _get_base_reward) + lambda state, index: (state.validators.hash_tree_root(), state.slot, index), + _get_base_reward, lru_size=10) _get_committee_count_at_slot = get_committee_count_at_slot get_committee_count_at_slot = cache_this( lambda state, epoch: (state.validators.hash_tree_root(), epoch), - _get_committee_count_at_slot) + _get_committee_count_at_slot, lru_size=SLOTS_PER_EPOCH * 3) _get_active_validator_indices = get_active_validator_indices get_active_validator_indices = cache_this( lambda state, epoch: (state.validators.hash_tree_root(), epoch), - _get_active_validator_indices) + _get_active_validator_indices, lru_size=3) _get_beacon_committee = get_beacon_committee get_beacon_committee = cache_this( lambda state, slot, index: (state.validators.hash_tree_root(), state.randao_mixes.hash_tree_root(), slot, index), - _get_beacon_committee) + _get_beacon_committee, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3) _get_matching_target_attestations = get_matching_target_attestations get_matching_target_attestations = cache_this( lambda state, epoch: (state.hash_tree_root(), epoch), - _get_matching_target_attestations) + _get_matching_target_attestations, lru_size=10) _get_matching_head_attestations = get_matching_head_attestations get_matching_head_attestations = cache_this( lambda state, epoch: (state.hash_tree_root(), epoch), - _get_matching_head_attestations)''' + _get_matching_head_attestations, lru_size=10) + +_get_attesting_indices = get_attesting_indices +get_attesting_indices = cache_this(lambda state, data, bits: + (state.validators.hash_tree_root(), data.hash_tree_root(), bits.hash_tree_root()), + _get_attesting_indices, lru_size=SLOTS_PER_EPOCH * MAX_COMMITTEES_PER_SLOT * 3)''' def objects_to_spec(spec_object: SpecObject, imports: str, fork: str) -> str: @@ -481,6 +500,7 @@ def run(self): "py_ecc==2.0.0", "dataclasses==0.6", "remerkleable==0.1.12", - "ruamel.yaml==0.16.5" + "ruamel.yaml==0.16.5", + "lru-dict==1.1.6" ] ) diff --git a/tests/core/pyspec/eth2spec/test/context.py b/tests/core/pyspec/eth2spec/test/context.py index 5338ccb9d4..6e24c4cfe5 100644 --- a/tests/core/pyspec/eth2spec/test/context.py +++ b/tests/core/pyspec/eth2spec/test/context.py @@ -6,6 +6,7 @@ from .utils import vector_test, with_meta_tags +from random import Random from typing import Any, Callable, Sequence, TypedDict, Protocol from importlib import reload @@ -100,8 +101,10 @@ def misc_balances(spec): Usage: `@with_custom_state(balances_fn=misc_balances, ...)` """ num_validators = spec.SLOTS_PER_EPOCH * 8 - num_misc_validators = spec.SLOTS_PER_EPOCH - return [spec.MAX_EFFECTIVE_BALANCE] * num_validators + [spec.MIN_DEPOSIT_AMOUNT] * num_misc_validators + balances = [spec.MAX_EFFECTIVE_BALANCE * 2 * i // num_validators for i in range(num_validators)] + rng = Random(1234) + rng.shuffle(balances) + return balances def single_phase(fn): diff --git a/tests/core/pyspec/eth2spec/test/phase_0/epoch_processing/test_process_rewards_and_penalties.py b/tests/core/pyspec/eth2spec/test/phase_0/epoch_processing/test_process_rewards_and_penalties.py index 1110337999..692260585d 100644 --- a/tests/core/pyspec/eth2spec/test/phase_0/epoch_processing/test_process_rewards_and_penalties.py +++ b/tests/core/pyspec/eth2spec/test/phase_0/epoch_processing/test_process_rewards_and_penalties.py @@ -1,8 +1,6 @@ -from copy import deepcopy - from eth2spec.test.context import ( spec_state_test, with_all_phases, spec_test, - misc_balances, with_custom_state, default_activation_threshold, + misc_balances, with_custom_state, single_phase, ) from eth2spec.test.helpers.state import ( @@ -24,7 +22,7 @@ def run_process_rewards_and_penalties(spec, state): @with_all_phases @spec_state_test def test_genesis_epoch_no_attestations_no_penalties(spec, state): - pre_state = deepcopy(state) + pre_state = state.copy() assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH @@ -52,7 +50,7 @@ def test_genesis_epoch_full_attestations_no_rewards(spec, state): # ensure has not cross the epoch boundary assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH - pre_state = deepcopy(state) + pre_state = state.copy() yield from run_process_rewards_and_penalties(spec, state) @@ -84,7 +82,7 @@ def prepare_state_with_full_attestations(spec, state): def test_full_attestations(spec, state): attestations = prepare_state_with_full_attestations(spec, state) - pre_state = deepcopy(state) + pre_state = state.copy() yield from run_process_rewards_and_penalties(spec, state) @@ -122,18 +120,19 @@ def test_full_attestations_random_incorrect_fields(spec, state): @with_all_phases @spec_test -@with_custom_state(balances_fn=misc_balances, threshold_fn=default_activation_threshold) +@with_custom_state(balances_fn=misc_balances, threshold_fn=lambda spec: spec.MAX_EFFECTIVE_BALANCE // 2) @single_phase def test_full_attestations_misc_balances(spec, state): attestations = prepare_state_with_full_attestations(spec, state) - pre_state = deepcopy(state) + pre_state = state.copy() yield from run_process_rewards_and_penalties(spec, state) attesting_indices = spec.get_unslashed_attesting_indices(state, attestations) assert len(attesting_indices) > 0 assert len(attesting_indices) != len(pre_state.validators) + assert any(v.effective_balance != spec.MAX_EFFECTIVE_BALANCE for v in state.validators) for index in range(len(pre_state.validators)): if index in attesting_indices: assert state.balances[index] > pre_state.balances[index] @@ -141,13 +140,21 @@ def test_full_attestations_misc_balances(spec, state): assert state.balances[index] < pre_state.balances[index] else: assert state.balances[index] == pre_state.balances[index] + # Check if base rewards are consistent with effective balance. + brs = {} + for index in attesting_indices: + br = spec.get_base_reward(state, index) + if br in brs: + assert brs[br] == state.validators[index].effective_balance + else: + brs[br] = state.validators[index].effective_balance @with_all_phases @spec_state_test def test_no_attestations_all_penalties(spec, state): next_epoch(spec, state) - pre_state = deepcopy(state) + pre_state = state.copy() assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1 @@ -173,8 +180,8 @@ def test_duplicate_attestation(spec, state): assert len(participants) > 0 - single_state = deepcopy(state) - dup_state = deepcopy(state) + single_state = state.copy() + dup_state = state.copy() inclusion_slot = state.slot + spec.MIN_ATTESTATION_INCLUSION_DELAY add_attestations_to_state(spec, single_state, [attestation], inclusion_slot) @@ -220,7 +227,7 @@ def test_attestations_some_slashed(spec, state): assert spec.compute_epoch_at_slot(state.slot) == spec.GENESIS_EPOCH + 1 assert len(state.previous_epoch_attestations) == spec.SLOTS_PER_EPOCH - pre_state = deepcopy(state) + pre_state = state.copy() yield from run_process_rewards_and_penalties(spec, state)