Skip to content

Commit

Permalink
Merge pull request #773 from snipsco/fix/ambiguous-intents
Browse files Browse the repository at this point in the history
Improve handling of ambiguous utterances in DeterministicIntentParser
  • Loading branch information
adrienball authored Mar 25, 2019
2 parents 0ddd8c1 + f26f53e commit 045d57e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
17 changes: 14 additions & 3 deletions snips_nlu/intent_parser/deterministic_intent_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from snips_nlu.constants import (
DATA, END, ENTITIES, ENTITY,
INTENTS, LANGUAGE, RES_INTENT, RES_INTENT_NAME,
RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES)
RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES,
RES_PROBA)
from snips_nlu.dataset import validate_and_format_dataset
from snips_nlu.entity_parser.builtin_entity_parser import is_builtin_entity
from snips_nlu.exceptions import IntentNotFoundError, LoadingError
Expand Down Expand Up @@ -198,6 +199,9 @@ def parse(self, text, intents=None, top_n=None):
if top_intents:
intent = top_intents[0][RES_INTENT]
slots = top_intents[0][RES_SLOTS]
if intent[RES_PROBA] < 1.0:
# return an empty result (no intent) in case of ambiguity
return empty_result(text, probability=1.0)
return parsing_result(text, intent, slots)
return empty_result(text, probability=1.0)
return self._parse_top_intents(text, top_n=top_n, intents=intents)
Expand Down Expand Up @@ -239,8 +243,15 @@ def placeholder_fn(entity_name):
if res is not None:
results.append(res)
break
if len(results) == top_n:
return results

confidence_score = 1.
if results:
confidence_score = 1. / float(len(results))

results = results[:top_n]

for res in results:
res[RES_INTENT][RES_PROBA] = confidence_score

return results

Expand Down
58 changes: 46 additions & 12 deletions snips_nlu/tests/test_deterministic_intent_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,34 @@ def test_should_parse_top_intents(self):
type: intent
name: intent1
utterances:
- hello world
- meeting tomorrow
---
type: intent
name: intent2
utterances:
- foo bar""")
- meeting [time:snips/datetime](today)""")
dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
parser = DeterministicIntentParser().fit(dataset)
text = "hello world"
text = "meeting tomorrow"

# When
results = parser.parse(text, top_n=3)

# Then
expected_intent = intent_classification_result(
intent_name="intent1", probability=1.0)
expected_results = [extraction_result(expected_intent, [])]
slot = {
"entity": "snips/datetime",
"range": {"end": 16, "start": 8},
"slotName": "time",
"value": "tomorrow"
}
expected_results = [
extraction_result(intent_classification_result(
intent_name="intent1", probability=0.5), []),
extraction_result(intent_classification_result(
intent_name="intent2", probability=0.5), [slot])
]
results = sorted(results, key=lambda r: r[RES_INTENT][RES_INTENT_NAME])
self.assertEqual(expected_results, results)

@patch("snips_nlu.intent_parser.deterministic_intent_parser"
Expand Down Expand Up @@ -225,6 +235,30 @@ def test_should_ignore_ambiguous_utterances(self):
# Then
self.assertEqual(empty_result(text, 1.0), res)

def test_should_ignore_subtly_ambiguous_utterances(self):
# Checks that parse() yields an empty result when the input can match two
# intents in different ways: word-for-word against the intent_1 utterance,
# and against the intent_2 utterance whose second token is a
# snips/datetime slot.
# Given
dataset_stream = io.StringIO("""
---
type: intent
name: intent_1
utterances:
- meeting tomorrow
---
type: intent
name: intent_2
utterances:
- meeting [time:snips/datetime](today)""")
dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
parser = DeterministicIntentParser().fit(dataset)
# Same text as the intent_1 utterance; "tomorrow" is presumably also
# resolved as a snips/datetime value, making intent_2 a candidate too.
text = "meeting tomorrow"

# When
# No top_n is passed, so a single parsing result is requested.
res = parser.parse(text)

# Then
# Expect the empty result for this text, built with probability 1.0,
# rather than either of the two candidate intents.
self.assertEqual(empty_result(text, 1.0), res)

def test_should_not_parse_when_not_fitted(self):
# Given
parser = DeterministicIntentParser()
Expand Down Expand Up @@ -565,7 +599,7 @@ def test_should_parse_naughty_strings(self):
- this is [slot2:entity2](second_entity)""")
dataset = Dataset.from_yaml_files("en", [dataset_stream]).json
naughty_strings_path = TEST_PATH / "resources" / "naughty_strings.txt"
with naughty_strings_path.open(encoding='utf8') as f:
with naughty_strings_path.open(encoding="utf8") as f:
naughty_strings = [line.strip("\n") for line in f.readlines()]

# When
Expand All @@ -579,7 +613,7 @@ def test_should_parse_naughty_strings(self):
def test_should_fit_with_naughty_strings_no_tags(self):
# Given
naughty_strings_path = TEST_PATH / "resources" / "naughty_strings.txt"
with naughty_strings_path.open(encoding='utf8') as f:
with naughty_strings_path.open(encoding="utf8") as f:
naughty_strings = [line.strip("\n") for line in f.readlines()]

utterances = [{DATA: [{TEXT: naughty_string}]} for naughty_string in
Expand Down Expand Up @@ -635,13 +669,13 @@ def test_should_fit_and_parse_with_non_ascii_tags(self):
parsing = parser.parse("string0")

expected_slot = {
'entity': 'non_ascìi_entïty',
'range': {
"entity": "non_ascìi_entïty",
"range": {
"start": 0,
"end": 7
},
'slotName': u'non_ascìi_slöt',
'value': u'string0'
"slotName": u"non_ascìi_slöt",
"value": u"string0"
}
intent_name = parsing[RES_INTENT][RES_INTENT_NAME]
self.assertEqual("naughty_intent", intent_name)
Expand Down

0 comments on commit 045d57e

Please sign in to comment.