Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overlapping entities #1604

Merged
merged 18 commits into from
Jan 21, 2019
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions rasa_nlu/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from rasa_nlu import training_data, utils, config
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.extractors.crf_entity_extractor import CRFEntityExtractor
from rasa_nlu.model import Interpreter
from rasa_nlu.model import Trainer, TrainingData

Expand All @@ -41,8 +42,8 @@
def create_argument_parser():
import argparse
parser = argparse.ArgumentParser(
description='evaluate a Rasa NLU pipeline with cross '
'validation or on external data')
description='evaluate a Rasa NLU pipeline with cross '
'validation or on external data')

parser.add_argument('-d', '--data', required=True,
help="file containing training/evaluation data")
Expand Down Expand Up @@ -204,7 +205,7 @@ def drop_intents_below_freq(td, cutoff=5):
"""Remove intent groups with less than cutoff instances."""

logger.debug(
"Raw data intent examples: {}".format(len(td.intent_examples)))
"Raw data intent examples: {}".format(len(td.intent_examples)))
keep_examples = [ex
for ex in td.intent_examples
if td.examples_per_intent[ex.get("intent")] >= cutoff]
Expand Down Expand Up @@ -369,10 +370,10 @@ def evaluate_entities(targets,
for extractor in extractors:
merged_predictions = merge_labels(aligned_predictions, extractor)
merged_predictions = substitute_labels(
merged_predictions, "O", "no_entity")
merged_predictions, "O", "no_entity")
logger.info("Evaluation for entity extractor: {} ".format(extractor))
report, precision, f1, accuracy = get_evaluation_metrics(
merged_targets, merged_predictions)
merged_targets, merged_predictions)
log_evaluation_table(report, precision, f1, accuracy)
result[extractor] = {
"report": report,
Expand Down Expand Up @@ -462,24 +463,33 @@ def pick_best_entity_fit(token, candidates):
return candidates[best_fit]["entity"]


def determine_token_labels(token, entities):
def determine_token_labels(token, entities, extractors):
"""Determines the token label given entities that do not overlap.

:param token: a single token
:param entities: entities found by a single extractor
:return: entity type
Args:
token: a single token
entities: entities found by a single extractor
extractors: list of extractors
Returns:
entity type
"""

if len(entities) == 0:
return "O"

if do_entities_overlap(entities):
if not do_extractors_support_overlap(extractors) and \
do_entities_overlap(entities):
raise ValueError("The possible entities should not overlap")

candidates = find_intersecting_entites(token, entities)
return pick_best_entity_fit(token, candidates)


def do_extractors_support_overlap(extractors):
"""Checks if extractors support overlapping entities
"""
return extractors is None or not (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be extractors is None or CRFEntityExtractor.name not in extractors

CRFEntityExtractor.name in extractors)


def align_entity_predictions(targets, predictions, tokens, extractors):
"""Aligns entity predictions to the message tokens.

Expand All @@ -501,9 +511,10 @@ def align_entity_predictions(targets, predictions, tokens, extractors):
entities_by_extractors[p["extractor"]].append(p)
extractor_labels = defaultdict(list)
for t in tokens:
true_token_labels.append(determine_token_labels(t, targets))
true_token_labels.append(
determine_token_labels(t, targets, extractors))
for extractor, entities in entities_by_extractors.items():
extracted = determine_token_labels(t, entities)
extracted = determine_token_labels(t, entities, extractor)
extractor_labels[extractor].append(extracted)

return {"target_labels": true_token_labels,
Expand Down Expand Up @@ -569,10 +580,10 @@ def get_intent_predictions(targets, interpreter,
for e, target in zip(test_data.training_examples, targets):
res = interpreter.parse(e.text, only_output_properties=False)
intent_results.append(IntentEvaluationResult(
target,
extract_intent(res),
extract_message(res),
extract_confidence(res)))
target,
extract_intent(res),
extract_message(res),
extract_confidence(res)))

return intent_results

Expand Down Expand Up @@ -692,7 +703,7 @@ def run_evaluation(data_path, model,
if is_intent_classifier_present(interpreter):
intent_targets = get_intent_targets(test_data)
intent_results = get_intent_predictions(
intent_targets, interpreter, test_data)
intent_targets, interpreter, test_data)

logger.info("Intent evaluation results:")
result['intent_evaluation'] = evaluate_intents(intent_results,
Expand Down Expand Up @@ -894,7 +905,7 @@ def main():
data = training_data.load_data(cmdline_args.data)
data = drop_intents_below_freq(data, cutoff=5)
results, entity_results = run_cv_evaluation(
data, int(cmdline_args.folds), nlu_config)
data, int(cmdline_args.folds), nlu_config)
logger.info("CV evaluation (n={})".format(cmdline_args.folds))

if any(results):
Expand Down
27 changes: 23 additions & 4 deletions tests/base/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from rasa_nlu.evaluate import does_token_cross_borders
from rasa_nlu.evaluate import align_entity_predictions
from rasa_nlu.evaluate import determine_intersection
from rasa_nlu.evaluate import determine_token_labels
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.tokenizers import Token
from rasa_nlu import training_data, config
Expand All @@ -30,10 +31,10 @@
def duckling_interpreter(component_builder, tmpdir_factory):
conf = RasaNLUModelConfig({"pipeline": [{"name": "ner_duckling"}]})
return utilities.interpreter_for(
component_builder,
data="./data/examples/rasa/demo-rasa.json",
path=tmpdir_factory.mktemp("projects").strpath,
config=conf)
component_builder,
data="./data/examples/rasa/demo-rasa.json",
path=tmpdir_factory.mktemp("projects").strpath,
config=conf)


# Chinese Example
Expand Down Expand Up @@ -165,6 +166,23 @@ def test_entity_overlap():
assert not do_entities_overlap(EN_targets)


def test_determine_token_labels_throws_error():
with pytest.raises(ValueError):
determine_token_labels(CH_correct_segmentation,
[CH_correct_entity,
CH_wrong_entity], ["ner_crf"])


def test_determine_token_labels_no_extractors():
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity], None)


def test_determine_token_labels_with_extractors():
determine_token_labels(CH_correct_segmentation[0],
[CH_correct_entity, CH_wrong_entity], ["A", "B"])


def test_label_merging():
aligned_predictions = [
{"target_labels": ["O", "O"], "extractor_labels":
Expand Down Expand Up @@ -260,6 +278,7 @@ def test_evaluate_entities():
mock_extractors = ["A", "B"]
result = align_entity_predictions(EN_targets, EN_predicted,
EN_tokens, mock_extractors)

assert result == {
"target_labels": ["O", "O", "O", "O", "O", "O", "O", "O", "food",
"location", "location", "datetime"],
Expand Down