-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Markers/wrap up #10062
Markers/wrap up #10062
Changes from 3 commits
9889224
d01ff03
f33f65c
b8346d0
fadaac4
6435800
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,6 @@ | ||
from __future__ import annotations | ||
import os | ||
from abc import ABC, abstractmethod | ||
from rasa.shared.core.domain import Domain | ||
from rasa.shared.core.trackers import DialogueStateTracker | ||
from rasa.utils.io import WriteRow | ||
from typing import ( | ||
Dict, | ||
Iterator, | ||
|
@@ -29,6 +26,9 @@ | |
from rasa.shared.data import is_likely_yaml_file | ||
from rasa.shared.exceptions import InvalidConfigException, RasaException | ||
from rasa.shared.core.events import ActionExecuted, UserUttered, Event | ||
from rasa.shared.core.domain import Domain | ||
from rasa.shared.core.trackers import DialogueStateTracker | ||
from rasa.utils.io import WriteRow | ||
|
||
import logging | ||
import csv | ||
|
@@ -122,17 +122,6 @@ def _register_tag_class( | |
cls.marker_class_to_tag[marker_class] = positive_tag | ||
|
||
|
||
# We allow multiple atomic markers to be grouped under the same tag e.g. | ||
# 'slot_set: ["slot_a", "slot_b"]' (see `AtomicMarkers` / `CompoundMarkers`), | ||
# which is why this config maps to a list of texts or just one text: | ||
ConditionConfigList = Dict[Text, Union[Text, List[Text]]] | ||
# Compound markers can be nested: | ||
OperatorConfig = Dict[Text, List[Union["OperatorConfig", ConditionConfigList]]] | ||
# In case no compound operator is defined, "and" is used by default. Hence, | ||
# a marker config can also just consist of the config for a condition: | ||
MarkerConfig = Union[ConditionConfigList, OperatorConfig] | ||
|
||
|
||
class InvalidMarkerConfig(RasaException): | ||
"""Exception that can be raised when the config for a marker is not valid.""" | ||
|
||
|
@@ -145,6 +134,11 @@ class EventMetaData: | |
preceding_user_turns: int | ||
|
||
|
||
# We evaluate markers separately against every session and extract, for every marker | ||
# that we want to evaluate, the meta data of the respective relevant events where the | ||
# marker applies. | ||
SessionEvaluation = Dict[Text, List[EventMetaData]] | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
|
@@ -279,7 +273,7 @@ def validate_against_domain(self, domain: Domain) -> bool: | |
|
||
def evaluate_events( | ||
self, events: List[Event], recursive: bool = False | ||
) -> List[Dict[Text, List[EventMetaData]]]: | ||
) -> List[SessionEvaluation]: | ||
"""Resets the marker, tracks all events, and collects some information. | ||
|
||
The collected information includes: | ||
|
@@ -578,44 +572,70 @@ def from_config(config: Any, name: Optional[Text] = None) -> Marker: | |
|
||
return marker | ||
|
||
def export_markers( | ||
def evaluate_trackers( | ||
self, | ||
tracker_loader: Iterator[Optional[DialogueStateTracker]], | ||
output_file: Text, | ||
stats_file: Optional[Text] = None, | ||
trackers: Iterator[Optional[DialogueStateTracker]], | ||
output_file: Path, | ||
session_stats_file: Optional[Path] = None, | ||
overall_stats_file: Optional[Path] = None, | ||
) -> None: | ||
"""Collect markers for each dialogue in each tracker loaded. | ||
|
||
Args: | ||
tracker_loader: The tracker loader to use to select trackers for marker | ||
extraction. | ||
trackers: An iterator over the trackers from which we want to extract | ||
markers. | ||
output_file: Path to write out the extracted markers. | ||
stats_file: (Optional) Path to write out statistics about the extracted | ||
markers. | ||
""" | ||
processed_trackers = {} | ||
session_stats_file: (Optional) Path to write out statistics about the | ||
extracted markers for each session separately. | ||
overall_stats_file: (Optional) Path to write out statistics about the | ||
markers extracted from all session data. | ||
|
||
for tracker in tracker_loader: | ||
Raises: | ||
`FileExistsError` if any of the specified files already exists | ||
`NotADirectoryError` if any of the specified files is supposed to be | ||
contained in a directory that does not exist | ||
""" | ||
# Check files and folders before doing the costly swipe over the trackers: | ||
for path in [session_stats_file, overall_stats_file, output_file]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So I put this in the CLI because I figured it would let us have a flag to allow the user to force an overwrite. Do we want to bother adding that feature right now, or put it on the list of future enhancements? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I like the idea :D (I had also put an overwrite parameter on the stats :D) - but since it wasn't in the list of requirements, let's just ignore it for now. |
||
if path is not None and path.is_file(): | ||
raise FileExistsError(f"Expected that no file {path} already exists.") | ||
if path is not None and not path.parent.is_dir(): | ||
raise NotADirectoryError(f"Expected directory {path.parent} to exist.") | ||
|
||
# Apply marker to each session stored in each tracker and save the results. | ||
processed_trackers: Dict[Text, List[SessionEvaluation]] = {} | ||
for tracker in trackers: | ||
if tracker: | ||
tracker_result = self.evaluate_events(tracker.events) | ||
processed_trackers[tracker.sender_id] = tracker_result | ||
|
||
Marker._save_results(output_file, processed_trackers) | ||
|
||
if stats_file: | ||
Marker._compute_stats(stats_file, processed_trackers) | ||
# Compute and write statistics if requested. | ||
if session_stats_file or overall_stats_file: | ||
from rasa.core.evaluation.marker_stats import MarkerStatistics | ||
|
||
stats = MarkerStatistics() | ||
for sender_id, tracker_result in processed_trackers.items(): | ||
for session_idx, session_result in enumerate(tracker_result): | ||
stats.process( | ||
sender_id=sender_id, | ||
session_idx=session_idx, | ||
meta_data_on_relevant_events_per_marker=session_result, | ||
) | ||
if overall_stats_file: | ||
stats.overall_statistic_to_csv(path=overall_stats_file) | ||
if session_stats_file: | ||
stats.per_session_statistics_to_csv(path=session_stats_file) | ||
|
||
@staticmethod | ||
def _save_results( | ||
path: Text, results: Dict[Text, List[Dict[Text, EventMetaData]]] | ||
) -> None: | ||
def _save_results(path: Path, results: Dict[Text, List[SessionEvaluation]]) -> None: | ||
"""Save extracted marker results as CSV to specified path. | ||
|
||
Args: | ||
path: Path to write out the extracted markers. | ||
results: Extracted markers from a selection of trackers. | ||
""" | ||
with open(path, "w") as f: | ||
with path.open(mode="w") as f: | ||
table_writer = csv.writer(f) | ||
table_writer.writerow( | ||
[ | ||
|
@@ -626,39 +646,28 @@ def _save_results( | |
"num_preceding_user_turns", | ||
] | ||
) | ||
for sender_id, dialogues in results.items(): | ||
for session_idx, session in enumerate(dialogues): | ||
for sender_id, results_per_session in results.items(): | ||
for session_idx, session_result in enumerate(results_per_session): | ||
Marker._write_relevant_events( | ||
table_writer, sender_id, session_idx, session | ||
table_writer, sender_id, session_idx, session_result | ||
) | ||
|
||
@staticmethod | ||
def _write_relevant_events( | ||
writer: WriteRow, | ||
sender_id: Text, | ||
session_idx: int, | ||
session: Dict[Text, EventMetaData], | ||
writer: WriteRow, sender_id: Text, session_idx: int, session: SessionEvaluation, | ||
) -> None: | ||
for marker_name, marker_metadata in session.items(): | ||
for metadata in marker_metadata: | ||
for marker_name, meta_data_per_relevant_event in session.items(): | ||
for event_meta_data in meta_data_per_relevant_event: | ||
writer.writerow( | ||
[ | ||
sender_id, | ||
str(session_idx), | ||
marker_name, | ||
metadata.idx, | ||
metadata.preceding_user_turns, | ||
str(event_meta_data.idx), | ||
str(event_meta_data.preceding_user_turns), | ||
] | ||
) | ||
|
||
@staticmethod | ||
def _compute_stats( | ||
out_file: Text, results: List[Union[Text, Dict[Text, EventMetaData]]] | ||
) -> None: | ||
"""Compute stats over extracted marker data.""" | ||
# TODO: Figure out how this is done | ||
pass | ||
|
||
|
||
class OperatorMarker(Marker, ABC): | ||
"""Combines several markers into one.""" | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat 😎