Skip to content

Commit

Permalink
Markers/wrap up (RasaHQ#10062)
Browse files Browse the repository at this point in the history
* connect stats and cli; rm unused types from marker.py; ...

* rm unused schemata

* try to improve docstr/help

* Fix stats CLI command inversion

* fix help test (nargs=? -> stats file prefix is optional)

Co-authored-by: Matthew Summers <m.summers@rasa.com>
  • Loading branch information
ka-bu and usc-m authored Nov 2, 2021
1 parent 9142bd4 commit fd465bb
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 146 deletions.
13 changes: 8 additions & 5 deletions rasa/cli/arguments/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,20 @@ def set_markers_arguments(parser: argparse.ArgumentParser) -> None:

stats.add_argument(
"--no-stats",
default=False,
action="store_true",
action="store_false",
dest="stats",
help="Do not compute summary statistics.",
)

stats.add_argument(
"--stats-file",
default="stats.csv",
"--stats-file-prefix",
default="stats",
nargs="?",
type=str,
help="The filename to write out computed summary statistics.",
help="The common file prefix of the files where we write out the compute "
"statistics. More precisely, the file prefix must consist of a common "
"path plus a common file prefix, to which suffixes `-overall.csv` and "
"`-per-session.csv` will be added automatically.",
)

add_endpoint_param(
Expand Down
49 changes: 29 additions & 20 deletions rasa/cli/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import argparse
from rasa.shared.core.domain import Domain
from typing import List, Text, Optional
from pathlib import Path

from rasa.core.utils import AvailableEndpoints
from rasa.core.tracker_store import TrackerStore
from rasa.core.evaluation.marker_tracker_loader import MarkerTrackerLoader
from rasa.core.evaluation.marker_base import Marker

from rasa.shared.core.domain import Domain
from rasa.cli import SubParsersAction
import rasa.cli.arguments.evaluate as arguments
import rasa.shared.utils.cli
import os.path

STATS_OVERALL_SUFFIX = "-overall.csv"
STATS_SESSION_SUFFIX = "-per-session.csv"


def add_subparser(
Expand Down Expand Up @@ -82,7 +84,7 @@ def _run_markers_cli(args: argparse.Namespace) -> None:
seed = args.seed if "seed" in args else None
count = args.count if "count" in args else None

stats_file = args.stats_file if "stats_file" in args and args.stats else None
stats_file_prefix = args.stats_file_prefix if args.stats else None

_run_markers(
seed,
Expand All @@ -92,7 +94,7 @@ def _run_markers_cli(args: argparse.Namespace) -> None:
args.strategy,
args.config,
args.output_filename,
stats_file,
stats_file_prefix,
)


Expand All @@ -104,7 +106,7 @@ def _run_markers(
strategy: Text,
config: Text,
output_filename: Text,
stats_file: Optional[Text] = None,
stats_file_prefix: Optional[Path] = None,
) -> None:
"""Run markers algorithm over specified config and tracker store.
Expand All @@ -120,30 +122,37 @@ def _run_markers(
strategy: Strategy to use when selecting trackers to extract from.
config: Path to the markers definition file to use.
output_filename: Path to write out the extracted markers.
stats_file: (Optional) Path to write out statistics about the extracted
markers.
stats_file_prefix: (Optional) A prefix used to create paths where files with
statistics on the marker extraction results will be written.
It must consists of the path to the where those files should be stored
and the common file prefix, e.g. '<path-to-stats-folder>/statistics'.
Statistics derived from all marker extractions will be stored in
'<path-to-stats-folder>/statistics-overall.csv', while the statistics
computed per session will be stored in
'<path-to-stats-folder>/statistics-per-session.csv'.
"""
if os.path.exists(output_filename):
rasa.shared.utils.cli.print_error_and_exit(
"A file with the output filename already exists"
)

if stats_file and os.path.exists(stats_file):
rasa.shared.utils.cli.print_error_and_exit(
"A file with the stats filename already exists"
)

domain = Domain.load(domain_path) if domain_path else None
markers = Marker.from_path(config)

if domain and not markers.validate_against_domain(domain):
rasa.shared.utils.cli.print_error_and_exit(
"Validation errors were found in the markers definition. "
"Please see errors listed above and fix before running again."
)

tracker_loader = _create_tracker_loader(endpoint_config, strategy, count, seed)
markers.export_markers(tracker_loader.load(), output_filename, stats_file)

def _append_suffix(path: Optional[Path], suffix: Text) -> Optional[Path]:
return path.parent / (path.name + suffix) if path else None

try:
markers.evaluate_trackers(
trackers=tracker_loader.load(),
output_file=output_filename,
session_stats_file=_append_suffix(stats_file_prefix, STATS_SESSION_SUFFIX),
overall_stats_file=_append_suffix(stats_file_prefix, STATS_OVERALL_SUFFIX),
)
except (FileExistsError, NotADirectoryError) as e:
rasa.shared.utils.cli.print_error_and_exit(message=str(e))


def _create_tracker_loader(
Expand Down
113 changes: 61 additions & 52 deletions rasa/core/evaluation/marker_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.utils.io import WriteRow
from typing import (
Dict,
Iterator,
Expand All @@ -29,6 +26,9 @@
from rasa.shared.data import is_likely_yaml_file
from rasa.shared.exceptions import InvalidConfigException, RasaException
from rasa.shared.core.events import ActionExecuted, UserUttered, Event
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.utils.io import WriteRow

import logging
import csv
Expand Down Expand Up @@ -122,17 +122,6 @@ def _register_tag_class(
cls.marker_class_to_tag[marker_class] = positive_tag


# We allow multiple atomic markers to be grouped under the same tag e.g.
# 'slot_set: ["slot_a", "slot_b"]' (see `AtomicMarkers` / `CompoundMarkers`),
# which is why this config maps to a list of texts or just one text:
ConditionConfigList = Dict[Text, Union[Text, List[Text]]]
# Compound markers can be nested:
OperatorConfig = Dict[Text, List[Union["OperatorConfig", ConditionConfigList]]]
# In case no compound operator is defined, "and" is used by default. Hence,
# a marker config can also just consist of the config for a condition:
MarkerConfig = Union[ConditionConfigList, OperatorConfig]


class InvalidMarkerConfig(RasaException):
"""Exception that can be raised when the config for a marker is not valid."""

Expand All @@ -145,6 +134,11 @@ class EventMetaData:
preceding_user_turns: int


# We evaluate markers separately against every session and extract, for every marker
# that we want to evaluate, the meta data of the respective relevant events where the
# marker applies.
SessionEvaluation = Dict[Text, List[EventMetaData]]

T = TypeVar("T")


Expand Down Expand Up @@ -279,7 +273,7 @@ def validate_against_domain(self, domain: Domain) -> bool:

def evaluate_events(
self, events: List[Event], recursive: bool = False
) -> List[Dict[Text, List[EventMetaData]]]:
) -> List[SessionEvaluation]:
"""Resets the marker, tracks all events, and collects some information.
The collected information includes:
Expand Down Expand Up @@ -578,44 +572,70 @@ def from_config(config: Any, name: Optional[Text] = None) -> Marker:

return marker

def export_markers(
def evaluate_trackers(
self,
tracker_loader: Iterator[Optional[DialogueStateTracker]],
output_file: Text,
stats_file: Optional[Text] = None,
trackers: Iterator[Optional[DialogueStateTracker]],
output_file: Path,
session_stats_file: Optional[Path] = None,
overall_stats_file: Optional[Path] = None,
) -> None:
"""Collect markers for each dialogue in each tracker loaded.
Args:
tracker_loader: The tracker loader to use to select trackers for marker
extraction.
trackers: An iterator over the trackers from which we want to extract
markers.
output_file: Path to write out the extracted markers.
stats_file: (Optional) Path to write out statistics about the extracted
markers.
"""
processed_trackers = {}
session_stats_file: (Optional) Path to write out statistics about the
extracted markers for each session separately.
overall_stats_file: (Optional) Path to write out statistics about the
markers extracted from all session data.
for tracker in tracker_loader:
Raises:
`FileExistsError` if any of the specified files already exists
`NotADirectoryError` if any of the specified files is supposed to be
contained in a directory that does not exist
"""
# Check files and folders before doing the costly swipe over the trackers:
for path in [session_stats_file, overall_stats_file, output_file]:
if path is not None and path.is_file():
raise FileExistsError(f"Expected that no file {path} already exists.")
if path is not None and not path.parent.is_dir():
raise NotADirectoryError(f"Expected directory {path.parent} to exist.")

# Apply marker to each session stored in each tracker and save the results.
processed_trackers: Dict[Text, List[SessionEvaluation]] = {}
for tracker in trackers:
if tracker:
tracker_result = self.evaluate_events(tracker.events)
processed_trackers[tracker.sender_id] = tracker_result

Marker._save_results(output_file, processed_trackers)

if stats_file:
Marker._compute_stats(stats_file, processed_trackers)
# Compute and write statistics if requested.
if session_stats_file or overall_stats_file:
from rasa.core.evaluation.marker_stats import MarkerStatistics

stats = MarkerStatistics()
for sender_id, tracker_result in processed_trackers.items():
for session_idx, session_result in enumerate(tracker_result):
stats.process(
sender_id=sender_id,
session_idx=session_idx,
meta_data_on_relevant_events_per_marker=session_result,
)
if overall_stats_file:
stats.overall_statistic_to_csv(path=overall_stats_file)
if session_stats_file:
stats.per_session_statistics_to_csv(path=session_stats_file)

@staticmethod
def _save_results(
path: Text, results: Dict[Text, List[Dict[Text, EventMetaData]]]
) -> None:
def _save_results(path: Path, results: Dict[Text, List[SessionEvaluation]]) -> None:
"""Save extracted marker results as CSV to specified path.
Args:
path: Path to write out the extracted markers.
results: Extracted markers from a selection of trackers.
"""
with open(path, "w") as f:
with path.open(mode="w") as f:
table_writer = csv.writer(f)
table_writer.writerow(
[
Expand All @@ -626,39 +646,28 @@ def _save_results(
"num_preceding_user_turns",
]
)
for sender_id, dialogues in results.items():
for session_idx, session in enumerate(dialogues):
for sender_id, results_per_session in results.items():
for session_idx, session_result in enumerate(results_per_session):
Marker._write_relevant_events(
table_writer, sender_id, session_idx, session
table_writer, sender_id, session_idx, session_result
)

@staticmethod
def _write_relevant_events(
writer: WriteRow,
sender_id: Text,
session_idx: int,
session: Dict[Text, EventMetaData],
writer: WriteRow, sender_id: Text, session_idx: int, session: SessionEvaluation,
) -> None:
for marker_name, marker_metadata in session.items():
for metadata in marker_metadata:
for marker_name, meta_data_per_relevant_event in session.items():
for event_meta_data in meta_data_per_relevant_event:
writer.writerow(
[
sender_id,
str(session_idx),
marker_name,
metadata.idx,
metadata.preceding_user_turns,
str(event_meta_data.idx),
str(event_meta_data.preceding_user_turns),
]
)

@staticmethod
def _compute_stats(
out_file: Text, results: List[Union[Text, Dict[Text, EventMetaData]]]
) -> None:
"""Compute stats over extracted marker data."""
# TODO: Figure out how this is done
pass


class OperatorMarker(Marker, ABC):
"""Combines several markers into one."""
Expand Down
58 changes: 0 additions & 58 deletions rasa/shared/utils/schemas/markers.py

This file was deleted.

Loading

0 comments on commit fd465bb

Please sign in to comment.