Skip to content

Commit

Permalink
Eng 248 policy tests speedup (#12248)
Browse files Browse the repository at this point in the history
* Cleaned up TED Policy test matrix by removing redundant, long-running tests

* removed loading in test cases that don't manipulate the policy object
  • Loading branch information
twerkmeister authored Apr 11, 2023
1 parent b894719 commit fb1a806
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 24 deletions.
88 changes: 79 additions & 9 deletions tests/core/policies/test_ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,14 +578,78 @@ def test_empty_featurizer_configs(
assert isinstance(featurizer.state_featurizer, state_featurizer)


class TestTEDPolicyMargin(TestTEDPolicy):
class TestTEDPolicyConfigurationOptions:
"""Helper class to skip redundant and long-running tests in subclasses."""

@pytest.mark.parametrize("should_finetune", [False])
@pytest.mark.skip()
def test_persist_and_load(
self,
trained_policy: Policy,
default_domain: Domain,
should_finetune: bool,
stories_path: Text,
model_storage: ModelStorage,
resource: Resource,
execution_context: ExecutionContext,
):
"""This takes long and does not need to be tested for every config change."""
pass

@pytest.mark.skip()
def test_train_model_checkpointing(
self, tmp_path: Path, tmp_path_factory: TempPathFactory
):
"""This takes long and does not need to be tested for every config change."""
pass

@pytest.mark.skip()
def test_doesnt_checkpoint_with_no_checkpointing(
self, tmp_path: Path, tmp_path_factory: TempPathFactory
):
"""This takes long and does not need to be tested for every config change."""
pass

@pytest.mark.skip()
def test_doesnt_checkpoint_with_zero_eval_num_examples(
self, tmp_path: Path, tmp_path_factory: TempPathFactory
):
"""This takes long and does not need to be tested for every config change."""

@pytest.mark.parametrize(
"should_finetune, epoch_override, expected_epoch_value",
[
(
True,
TEDPolicy.get_default_config()[EPOCHS] + 1,
TEDPolicy.get_default_config()[EPOCHS] + 1,
)
],
)
@pytest.mark.skip()
def test_epoch_override_when_loaded(
self,
trained_policy: TEDPolicy,
should_finetune: bool,
epoch_override: int,
expected_epoch_value: int,
resource: Resource,
model_storage: ModelStorage,
execution_context: ExecutionContext,
):
"""This takes long and does not need to be tested for every config change."""
pass


class TestTEDPolicyMargin(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
config_override = config_override or {}
return {
**TEDPolicy.get_default_config(),
LOSS_TYPE: "margin",
EPOCHS: 2,
**config_override,
}

Expand Down Expand Up @@ -619,7 +683,7 @@ def test_prediction_on_empty_tracker(
assert min(prediction.probabilities) >= -1.0


class TestTEDPolicyWithEval(TestTEDPolicy):
class TestTEDPolicyWithEval(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand All @@ -632,7 +696,7 @@ def _config(
}


class TestTEDPolicyNormalization(TestTEDPolicy):
class TestTEDPolicyNormalization(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand Down Expand Up @@ -662,7 +726,7 @@ def test_ranking_length_and_renormalization(
assert sum(predicted_probabilities) == pytest.approx(1)


class TestTEDPolicyLowRankingLength(TestTEDPolicy):
class TestTEDPolicyLowRankingLength(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand All @@ -673,7 +737,7 @@ def test_ranking_length(self, trained_policy: TEDPolicy):
assert trained_policy.config[RANKING_LENGTH] == 3


class TestTEDPolicyHighRankingLength(TestTEDPolicy):
class TestTEDPolicyHighRankingLength(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand All @@ -684,7 +748,9 @@ def test_ranking_length(self, trained_policy: TEDPolicy):
assert trained_policy.config[RANKING_LENGTH] == 11


class TestTEDPolicyWithStandardFeaturizer(TestTEDPolicy):
class TestTEDPolicyWithStandardFeaturizer(
TestTEDPolicyConfigurationOptions, TestTEDPolicy
):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand Down Expand Up @@ -733,7 +799,7 @@ def test_featurizer(
assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer)


class TestTEDPolicyWithMaxHistory(TestTEDPolicy):
class TestTEDPolicyWithMaxHistory(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand Down Expand Up @@ -763,7 +829,9 @@ def create_policy(
)


class TestTEDPolicyWithRelativeAttention(TestTEDPolicy):
class TestTEDPolicyWithRelativeAttention(
TestTEDPolicyConfigurationOptions, TestTEDPolicy
):
def _config(
self, config_override: Optional[Dict[Text, Any]] = None
) -> Dict[Text, Any]:
Expand All @@ -777,7 +845,9 @@ def _config(
}


class TestTEDPolicyWithRelativeAttentionMaxHistoryOne(TestTEDPolicy):
class TestTEDPolicyWithRelativeAttentionMaxHistoryOne(
TestTEDPolicyConfigurationOptions, TestTEDPolicy
):
max_history = 1

def _config(
Expand Down
21 changes: 6 additions & 15 deletions tests/core/policies/test_unexpected_intent_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,14 +635,11 @@ def test_skip_predictions_to_prevent_loop(
tmp_path: Path,
):
"""Skips predictions to prevent loop."""
loaded_policy = self.persist_and_load_policy(
trained_policy, model_storage, resource, execution_context
)
precomputations = None
tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
tracker.update_with_events(tracker_events, default_domain)
with caplog.at_level(logging.DEBUG):
prediction = loaded_policy.predict_action_probabilities(
prediction = trained_policy.predict_action_probabilities(
tracker, default_domain, precomputations
)

Expand All @@ -651,7 +648,7 @@ def test_skip_predictions_to_prevent_loop(
) == should_skip

if should_skip:
assert prediction.probabilities == loaded_policy._default_predictions(
assert prediction.probabilities == trained_policy._default_predictions(
default_domain
)

Expand Down Expand Up @@ -691,20 +688,17 @@ def test_skip_predictions_if_new_intent(
tracker_events: List[Event],
):
"""Skips predictions if there's a new intent created."""
loaded_policy = self.persist_and_load_policy(
trained_policy, model_storage, resource, execution_context
)
tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
tracker.update_with_events(tracker_events, default_domain)

with caplog.at_level(logging.DEBUG):
prediction = loaded_policy.predict_action_probabilities(
prediction = trained_policy.predict_action_probabilities(
tracker, default_domain, precomputations=None
)

assert "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text

assert prediction.probabilities == loaded_policy._default_predictions(
assert prediction.probabilities == trained_policy._default_predictions(
default_domain
)

Expand Down Expand Up @@ -779,20 +773,17 @@ def test_ignore_action_unlikely_intent(
tracker_events_without_action: List[Event],
tmp_path: Path,
):
loaded_policy = self.persist_and_load_policy(
trained_policy, model_storage, resource, execution_context
)
precomputations = None
tracker_with_action = DialogueStateTracker.from_events(
"test 1", evts=tracker_events_with_action
)
tracker_without_action = DialogueStateTracker.from_events(
"test 2", evts=tracker_events_without_action
)
prediction_with_action = loaded_policy.predict_action_probabilities(
prediction_with_action = trained_policy.predict_action_probabilities(
tracker_with_action, default_domain, precomputations
)
prediction_without_action = loaded_policy.predict_action_probabilities(
prediction_without_action = trained_policy.predict_action_probabilities(
tracker_without_action, default_domain, precomputations
)

Expand Down

0 comments on commit fb1a806

Please sign in to comment.