Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable union-attr mypy check and fix issues #10942

Merged
merged 23 commits into from
Mar 3, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixes for failing tests
  • Loading branch information
ancalita committed Feb 25, 2022
commit 9006a8cf638f93b96e2b5bad729fe7e146d391af
8 changes: 3 additions & 5 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,9 @@ def _collect_sources(
# in order to correctly output the names of the contradicting rules
rule_name = tracker.sender_id

if prediction_source is None:
return None

if prediction_source.startswith(DEFAULT_RULES) or prediction_source.startswith(
LOOP_RULES
if isinstance(prediction_source, str) and (
ancalita marked this conversation as resolved.
Show resolved Hide resolved
prediction_source.startswith(DEFAULT_RULES)
or prediction_source.startswith(LOOP_RULES)
):
# the real gold action contradict the one in the rules in this case
gold_action_name = predicted_action_name
Expand Down
12 changes: 5 additions & 7 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings as pywarnings
import typing
from collections import defaultdict, namedtuple
from typing import Any, Dict, List, Optional, Text, Tuple, cast
from typing import Any, Dict, List, Optional, Text, Tuple

from rasa import telemetry
from rasa.core.constants import (
Expand Down Expand Up @@ -434,7 +434,8 @@ def _create_data_generator(
from rasa.shared.core.generator import TrainingDataGenerator

tmp_domain_path = Path(tempfile.mkdtemp()) / "domain.yaml"
cast(Domain, agent.domain).persist(tmp_domain_path)
domain = agent.domain if isinstance(agent.domain, Domain) else Domain.empty()
ancalita marked this conversation as resolved.
Show resolved Hide resolved
domain.persist(tmp_domain_path)
test_data_importer = TrainingDataImporter.load_from_dict(
training_data_paths=[resource_name], domain_path=str(tmp_domain_path)
)
Expand Down Expand Up @@ -822,15 +823,12 @@ async def _predict_tracker_actions(
List[EntityEvaluationResult],
]:

processor = cast(MessageProcessor, agent.processor)
processor = agent.processor
tracker_eval_store = EvaluationStore()

events = list(tracker.events)

if not isinstance(agent.domain, Domain):
slots = []
else:
slots = agent.domain.slots
slots = agent.domain.slots if isinstance(agent.domain, Domain) else []
ancalita marked this conversation as resolved.
Show resolved Hide resolved

partial_tracker = DialogueStateTracker.from_events(
tracker.sender_id,
Expand Down
7 changes: 2 additions & 5 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ async def _predict_till_next_listen(
if result is None:
result = {}

predictions = result.get("scores") or []
predictions = result.get("scores", [])
if not predictions:
raise InvalidConfigException(
"Cannot continue as no action was predicted by the dialogue manager. "
Expand Down Expand Up @@ -1476,10 +1476,7 @@ async def record_messages(
)
return

if domain is None:
domain_intents = []
else:
domain_intents = domain.get("intents")
domain_intents = domain.get("intents", []) if isinstance(domain, dict) else []
ancalita marked this conversation as resolved.
Show resolved Hide resolved

intents = [next(iter(i)) for i in domain_intents]

Expand Down
11 changes: 7 additions & 4 deletions rasa/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,10 @@ def verify_utterances_in_stories(self, ignore_warnings: bool = True) -> bool:
if not isinstance(event, ActionExecuted):
continue

action_name = event.action_name if event.action_name is not None else ""
if not action_name.startswith(UTTER_PREFIX):
if not isinstance(event.action_name, str):
ancalita marked this conversation as resolved.
Show resolved Hide resolved
continue

if not event.action_name.startswith(UTTER_PREFIX):
# we are only interested in utter actions
continue

Expand Down Expand Up @@ -261,9 +263,10 @@ def verify_actions_in_stories_rules(self) -> bool:
if not isinstance(event, ActionExecuted):
continue

action_name = event.action_name if event.action_name is not None else ""
if not isinstance(event.action_name, str):
ancalita marked this conversation as resolved.
Show resolved Hide resolved
continue

if action_name.startswith("action_"):
if not event.action_name.startswith("action_"):
continue

if event.action_name in visited:
Expand Down