diff --git a/tests/core/policies/test_ted_policy.py b/tests/core/policies/test_ted_policy.py
index a31b2d8f9c2a..d5f4c016a97c 100644
--- a/tests/core/policies/test_ted_policy.py
+++ b/tests/core/policies/test_ted_policy.py
@@ -578,7 +578,70 @@ def test_empty_featurizer_configs(
     assert isinstance(featurizer.state_featurizer, state_featurizer)


-class TestTEDPolicyMargin(TestTEDPolicy):
+class TestTEDPolicyConfigurationOptions:
+    """Helper class to skip redundant and long-running tests in subclasses."""
+
+    @pytest.mark.parametrize("should_finetune", [False])
+    @pytest.mark.skip()
+    def test_persist_and_load(
+        self,
+        trained_policy: Policy,
+        default_domain: Domain,
+        should_finetune: bool,
+        stories_path: Text,
+        model_storage: ModelStorage,
+        resource: Resource,
+        execution_context: ExecutionContext,
+    ):
+        """This takes a long time and does not need to run for every config change."""
+        pass
+
+    @pytest.mark.skip()
+    def test_train_model_checkpointing(
+        self, tmp_path: Path, tmp_path_factory: TempPathFactory
+    ):
+        """This takes a long time and does not need to run for every config change."""
+        pass
+
+    @pytest.mark.skip()
+    def test_doesnt_checkpoint_with_no_checkpointing(
+        self, tmp_path: Path, tmp_path_factory: TempPathFactory
+    ):
+        """This takes a long time and does not need to run for every config change."""
+        pass
+
+    @pytest.mark.skip()
+    def test_doesnt_checkpoint_with_zero_eval_num_examples(
+        self, tmp_path: Path, tmp_path_factory: TempPathFactory
+    ):
+        """This takes a long time and does not need to run for every config change."""
+
+    @pytest.mark.parametrize(
+        "should_finetune, epoch_override, expected_epoch_value",
+        [
+            (
+                True,
+                TEDPolicy.get_default_config()[EPOCHS] + 1,
+                TEDPolicy.get_default_config()[EPOCHS] + 1,
+            )
+        ],
+    )
+    @pytest.mark.skip()
+    def test_epoch_override_when_loaded(
+        self,
+        trained_policy: TEDPolicy,
+        should_finetune: bool,
+        epoch_override: int,
+        expected_epoch_value: int,
+        resource: Resource,
+        model_storage: ModelStorage,
+        execution_context: ExecutionContext,
+    ):
+        """This takes a long time and does not need to run for every config change."""
+        pass
+
+
+class TestTEDPolicyMargin(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -586,6 +649,7 @@ def _config(
         return {
             **TEDPolicy.get_default_config(),
             LOSS_TYPE: "margin",
+            EPOCHS: 2,
             **config_override,
         }

@@ -619,7 +683,7 @@ def test_prediction_on_empty_tracker(
         assert min(prediction.probabilities) >= -1.0


-class TestTEDPolicyWithEval(TestTEDPolicy):
+class TestTEDPolicyWithEval(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -632,7 +696,7 @@ def _config(
         }


-class TestTEDPolicyNormalization(TestTEDPolicy):
+class TestTEDPolicyNormalization(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -662,7 +726,7 @@ def test_ranking_length_and_renormalization(
         assert sum(predicted_probabilities) == pytest.approx(1)


-class TestTEDPolicyLowRankingLength(TestTEDPolicy):
+class TestTEDPolicyLowRankingLength(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -673,7 +737,7 @@ def test_ranking_length(self, trained_policy: TEDPolicy):
         assert trained_policy.config[RANKING_LENGTH] == 3


-class TestTEDPolicyHighRankingLength(TestTEDPolicy):
+class TestTEDPolicyHighRankingLength(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -684,7 +748,9 @@ def test_ranking_length(self, trained_policy: TEDPolicy):
         assert trained_policy.config[RANKING_LENGTH] == 11


-class TestTEDPolicyWithStandardFeaturizer(TestTEDPolicy):
+class TestTEDPolicyWithStandardFeaturizer(
+    TestTEDPolicyConfigurationOptions, TestTEDPolicy
+):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -733,7 +799,7 @@ def test_featurizer(
         assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer)


-class TestTEDPolicyWithMaxHistory(TestTEDPolicy):
+class TestTEDPolicyWithMaxHistory(TestTEDPolicyConfigurationOptions, TestTEDPolicy):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -763,7 +829,9 @@ def create_policy(
     )


-class TestTEDPolicyWithRelativeAttention(TestTEDPolicy):
+class TestTEDPolicyWithRelativeAttention(
+    TestTEDPolicyConfigurationOptions, TestTEDPolicy
+):
     def _config(
         self, config_override: Optional[Dict[Text, Any]] = None
     ) -> Dict[Text, Any]:
@@ -777,7 +845,9 @@ def _config(
         }


-class TestTEDPolicyWithRelativeAttentionMaxHistoryOne(TestTEDPolicy):
+class TestTEDPolicyWithRelativeAttentionMaxHistoryOne(
+    TestTEDPolicyConfigurationOptions, TestTEDPolicy
+):
     max_history = 1

     def _config(
diff --git a/tests/core/policies/test_unexpected_intent_policy.py b/tests/core/policies/test_unexpected_intent_policy.py
index 08cca3564317..c7e6d0c572c8 100644
--- a/tests/core/policies/test_unexpected_intent_policy.py
+++ b/tests/core/policies/test_unexpected_intent_policy.py
@@ -635,14 +635,11 @@ def test_skip_predictions_to_prevent_loop(
         tmp_path: Path,
     ):
         """Skips predictions to prevent loop."""
-        loaded_policy = self.persist_and_load_policy(
-            trained_policy, model_storage, resource, execution_context
-        )
         precomputations = None
         tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
         tracker.update_with_events(tracker_events, default_domain)
         with caplog.at_level(logging.DEBUG):
-            prediction = loaded_policy.predict_action_probabilities(
+            prediction = trained_policy.predict_action_probabilities(
                 tracker, default_domain, precomputations
             )

@@ -651,7 +648,7 @@
         ) == should_skip

         if should_skip:
-            assert prediction.probabilities == loaded_policy._default_predictions(
+            assert prediction.probabilities == trained_policy._default_predictions(
                 default_domain
             )

@@ -691,20 +688,17 @@ def test_skip_predictions_if_new_intent(
         tracker_events: List[Event],
     ):
         """Skips predictions if there's a new intent created."""
-        loaded_policy = self.persist_and_load_policy(
-            trained_policy, model_storage, resource, execution_context
-        )
         tracker = DialogueStateTracker(sender_id="init", slots=default_domain.slots)
         tracker.update_with_events(tracker_events, default_domain)

         with caplog.at_level(logging.DEBUG):
-            prediction = loaded_policy.predict_action_probabilities(
+            prediction = trained_policy.predict_action_probabilities(
                 tracker, default_domain, precomputations=None
             )

         assert "Skipping predictions for UnexpecTEDIntentPolicy" in caplog.text

-        assert prediction.probabilities == loaded_policy._default_predictions(
+        assert prediction.probabilities == trained_policy._default_predictions(
             default_domain
         )

@@ -779,9 +773,6 @@ def test_ignore_action_unlikely_intent(
         tracker_events_without_action: List[Event],
         tmp_path: Path,
     ):
-        loaded_policy = self.persist_and_load_policy(
-            trained_policy, model_storage, resource, execution_context
-        )
         precomputations = None
         tracker_with_action = DialogueStateTracker.from_events(
             "test 1", evts=tracker_events_with_action
@@ -789,10 +780,10 @@
         tracker_without_action = DialogueStateTracker.from_events(
             "test 2", evts=tracker_events_without_action
         )
-        prediction_with_action = loaded_policy.predict_action_probabilities(
+        prediction_with_action = trained_policy.predict_action_probabilities(
            tracker_with_action, default_domain, precomputations
         )
-        prediction_without_action = loaded_policy.predict_action_probabilities(
+        prediction_without_action = trained_policy.predict_action_probabilities(
             tracker_without_action, default_domain, precomputations
         )
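
Note on the base-class order in the first file: each configuration subclass lists `TestTEDPolicyConfigurationOptions` before `TestTEDPolicy`, so Python's method resolution order finds the `@pytest.mark.skip()` overrides first and pytest collects the slow tests as skipped for every configuration variant, while the plain `TestTEDPolicy` suite still runs them once. A minimal sketch of the same pattern follows; every class and test name in it is hypothetical, not taken from the Rasa test suite.

    import pytest


    class BaseSuite:
        """Stands in for a full suite such as TestTEDPolicy."""

        def test_fast(self):
            assert True

        def test_slow(self):
            # Placeholder for a long-running test, e.g. persist-and-load.
            assert True


    class SkipSlowTests:
        """Mixin that masks slow tests when listed first in the bases."""

        @pytest.mark.skip(reason="redundant for config-only variants")
        def test_slow(self):
            """Overrides BaseSuite.test_slow; collected but skipped."""


    class TestFullSuite(BaseSuite):
        """Runs everything, including test_slow."""


    class TestConfigVariant(SkipSlowTests, BaseSuite):
        """Inherits test_fast; resolves test_slow to the skipped override.

        SkipSlowTests must precede BaseSuite here, otherwise the MRO
        picks up BaseSuite.test_slow and the slow test runs again.
        """

Running `pytest -rs` on such a module reports `test_slow` as skipped under `TestConfigVariant` and passed under `TestFullSuite`. The `@pytest.mark.parametrize("should_finetune", [False])` retained on the skipped `test_persist_and_load` above presumably just keeps the override's signature satisfiable; the mark is inert once the test is skipped.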
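
The second file's tests lean on pytest's built-in `caplog` fixture: `caplog.at_level(logging.DEBUG)` temporarily lowers the capture threshold so the policy's debug message lands in `caplog.text`. A self-contained sketch of that pattern, using a hypothetical logger and message in place of the real "Skipping predictions for UnexpecTEDIntentPolicy" one:

    import logging

    logger = logging.getLogger("sketch.policy")


    def maybe_predict(skip: bool) -> None:
        # Stand-in for predict_action_probabilities, which logs at
        # DEBUG level when it skips a prediction for the tracker.
        if skip:
            logger.debug("Skipping predictions")


    def test_skip_message_is_logged(caplog) -> None:
        # at_level captures DEBUG records only inside the block, even
        # if the globally configured logging level is higher.
        with caplog.at_level(logging.DEBUG):
            maybe_predict(skip=True)
        assert "Skipping predictions" in caplog.text


    def test_no_skip_message_without_skip(caplog) -> None:
        with caplog.at_level(logging.DEBUG):
            maybe_predict(skip=False)
        assert "Skipping predictions" not in caplog.text

Switching the assertions from `loaded_policy` to `trained_policy` does not change this logging check; it only drops a persist-and-load round trip per test, behavior that is covered elsewhere by the dedicated `test_persist_and_load`.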