Fix issue with stop words in DeterministicIntentParser (#789)
adrienball authored Apr 26, 2019 · 1 parent 7fb7208 · commit b8466e7
Showing 5 changed files with 277 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.
## [Unreleased]
### Fixed
- Raise an error when using unknown intents in intents filter [#788](https://github.com/snipsco/snips-nlu/pull/788)
+- Fix issue with stop words in `DeterministicIntentParser` [#789](https://github.com/snipsco/snips-nlu/pull/789)

## [0.19.5]
### Added
13 changes: 13 additions & 0 deletions snips_nlu/dataset/utils.py
@@ -1,6 +1,7 @@
from __future__ import unicode_literals

from future.utils import iteritems, itervalues
+from snips_nlu_utils import normalize
from yaml import Loader, SafeLoader

from snips_nlu.constants import (
@@ -41,6 +42,18 @@ def extract_intent_entities(dataset, entity_filter=None):
return intent_entities


+def extract_entity_values(dataset, apply_normalization):
+    entities_per_intent = {intent: set() for intent in dataset[INTENTS]}
+    intent_entities = extract_intent_entities(dataset)
+    for intent, entities in iteritems(intent_entities):
+        for entity in entities:
+            entity_values = set(dataset[ENTITIES][entity][UTTERANCES])
+            if apply_normalization:
+                entity_values = {normalize(v) for v in entity_values}
+            entities_per_intent[intent].update(entity_values)
+    return entities_per_intent


def get_text_from_chunks(chunks):
return "".join(chunk[TEXT] for chunk in chunks)

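For context, a minimal sketch of how the new helper can be called, mirroring the unit test added at the bottom of this commit (the intent, entity, and values here are illustrative):

import io

from snips_nlu.dataset import Dataset, validate_and_format_dataset
from snips_nlu.dataset.utils import extract_entity_values

# Toy dataset with one intent and one custom entity (illustrative names)
intent_yaml = io.StringIO("""
---
type: intent
name: turnLightOn
utterances:
  - turn the light on in the [room](kitchen)""")
room_yaml = io.StringIO("""
type: entity
name: room
values:
  - kitchen
  - garage""")

dataset = Dataset.from_yaml_files("en", [intent_yaml, room_yaml]).json
dataset = validate_and_format_dataset(dataset)

# Maps each intent to the normalized values of every entity it references
print(extract_entity_values(dataset, apply_normalization=True))
# e.g. {"turnLightOn": {"kitchen", "garage"}}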
64 changes: 50 additions & 14 deletions snips_nlu/intent_parser/deterministic_intent_parser.py
@@ -22,6 +22,7 @@
RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES,
RES_PROBA)
from snips_nlu.dataset import validate_and_format_dataset
+from snips_nlu.dataset.utils import extract_entity_values
from snips_nlu.entity_parser.builtin_entity_parser import is_builtin_entity
from snips_nlu.exceptions import IntentNotFoundError, LoadingError
from snips_nlu.intent_parser.intent_parser import IntentParser
@@ -55,10 +56,11 @@ def __init__(self, config=None, **shared):
self._language = None
self._slot_names_to_entities = None
self._group_names_to_slot_names = None
+        self._stop_words = None
+        self._stop_words_whitelist = None
self.slot_names_to_group_names = None
self.regexes_per_intent = None
self.entity_scopes = None
-        self.stop_words = None

@property
def language(self):
@@ -68,12 +70,12 @@ def language(self):
def language(self, value):
self._language = value
if value is None:
-            self.stop_words = None
+            self._stop_words = None
else:
if self.config.ignore_stop_words:
-                self.stop_words = get_stop_words(self.resources)
+                self._stop_words = get_stop_words(self.resources)
else:
-                self.stop_words = set()
+                self._stop_words = set()

@property
def slot_names_to_entities(self):
@@ -142,13 +144,15 @@ def fit(self, dataset, force_retrain=True):
self.slot_names_to_entities = get_slot_name_mappings(dataset)
self.group_names_to_slot_names = _get_group_names_to_slot_names(
self.slot_names_to_entities)
+        self._stop_words_whitelist = _get_stop_words_whitelist(
+            dataset, self._stop_words)

# Do not use ambiguous patterns that appear in more than one intent
all_patterns = set()
ambiguous_patterns = set()
intent_patterns = dict()
for intent_name, intent in iteritems(dataset[INTENTS]):
-            patterns = self._generate_patterns(intent[UTTERANCES],
+            patterns = self._generate_patterns(intent_name, intent[UTTERANCES],
entity_placeholders)
patterns = [p for p in patterns
if len(p) < self.config.max_pattern_length]
@@ -221,7 +225,6 @@ def placeholder_fn(entity_name):
return _get_entity_name_placeholder(entity_name, self.language)

results = []
-        cleaned_text = self._preprocess_text(text)

for intent, entity_scope in iteritems(self.entity_scopes):
if intents is not None and intent not in intents:
@@ -233,7 +236,9 @@ def placeholder_fn(entity_name):
all_entities = builtin_entities + custom_entities
mapping, processed_text = replace_entities_with_placeholders(
text, all_entities, placeholder_fn=placeholder_fn)
-            cleaned_processed_text = self._preprocess_text(processed_text)
+            cleaned_text = self._preprocess_text(text, intent)
+            cleaned_processed_text = self._preprocess_text(processed_text,
+                                                           intent)
for regex in self.regexes_per_intent[intent]:
res = self._get_matching_result(text, cleaned_processed_text,
regex, intent, mapping)
@@ -300,14 +305,19 @@ def get_slots(self, text, intent):
slots = []
return slots

-    def _preprocess_text(self, string):
+    def _get_intent_stop_words(self, intent):
+        whitelist = self._stop_words_whitelist.get(intent, set())
+        return self._stop_words.difference(whitelist)
+
+    def _preprocess_text(self, string, intent):
"""Replaces stop words and characters that are tokenized out by
whitespaces"""
tokens = tokenize(string, self.language)
current_idx = 0
cleaned_string = ""
+        stop_words = self._get_intent_stop_words(intent)
for token in tokens:
-            if self.stop_words and normalize_token(token) in self.stop_words:
+            if stop_words and normalize_token(token) in stop_words:
token.value = "".join(" " for _ in range(len(token.value)))
prefix_length = token.start - current_idx
cleaned_string += "".join((" " for _ in range(prefix_length)))
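The space padding is the point of this method: a stripped stop word is replaced by an equal number of spaces rather than deleted, so the character offsets of the remaining tokens, and hence the match ranges reported on the original text, stay intact. A standalone illustration of the same idea, using a toy regex tokenizer rather than the snips-nlu one:

import re

def blank_stop_words(text, stop_words):
    # Replace each stop-word token with spaces of the same length,
    # preserving the indices of everything around it
    def blank(match):
        word = match.group(0)
        return " " * len(word) if word.lower() in stop_words else word
    return re.sub(r"\w+", blank, text)

cleaned = blank_stop_words("turn the light on", {"the"})
print(repr(cleaned))  # 'turn     light on'
print(len(cleaned) == len("turn the light on"))  # True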
@@ -352,18 +362,21 @@ def _get_matching_result(self, text, processed_text, regex, intent,
key=lambda s: s[RES_MATCH_RANGE][START])
return extraction_result(parsed_intent, parsed_slots)

-    def _generate_patterns(self, intent_utterances, entity_placeholders):
+    def _generate_patterns(self, intent, intent_utterances,
+                           entity_placeholders):
unique_patterns = set()
patterns = []
+        stop_words = self._get_intent_stop_words(intent)
for utterance in intent_utterances:
pattern = self._utterance_to_pattern(
-                utterance, entity_placeholders)
+                utterance, stop_words, entity_placeholders)
if pattern not in unique_patterns:
unique_patterns.add(pattern)
patterns.append(pattern)
return patterns

-    def _utterance_to_pattern(self, utterance, entity_placeholders):
+    def _utterance_to_pattern(self, utterance, stop_words,
+                              entity_placeholders):
slot_names_count = defaultdict(int)
pattern = []
for chunk in utterance[DATA]:
@@ -379,7 +392,7 @@ def _utterance_to_pattern(self, utterance, entity_placeholders):
else:
tokens = tokenize_light(chunk[TEXT], self.language)
pattern += [regex_escape(t.lower()) for t in tokens
-                            if normalize(t) not in self.stop_words]
+                            if normalize(t) not in stop_words]

pattern = r"^%s%s%s$" % (WHITESPACE_PATTERN,
WHITESPACE_PATTERN.join(pattern),
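For intuition, the resulting regex for a slot-free utterance such as "turn the light on", with "the" removed as a stop word, looks roughly like ^\s*turn\s*light\s*on\s*$. A hypothetical rendering of the construction above, assuming WHITESPACE_PATTERN is r"\s*" (its definition is not part of this diff) and using re.escape in place of the library's regex_escape:

import re

WHITESPACE_PATTERN = r"\s*"  # assumed value, not shown in this diff

tokens = ["turn", "light", "on"]  # "the" already dropped as a stop word
pattern = r"^%s%s%s$" % (WHITESPACE_PATTERN,
                         WHITESPACE_PATTERN.join(re.escape(t) for t in tokens),
                         WHITESPACE_PATTERN)

# The cleaned text, with "the" blanked out to spaces, still matches
print(bool(re.match(pattern, "turn     light on")))  # True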
@@ -417,12 +430,18 @@ def from_path(cls, path, **shared):

def to_dict(self):
"""Returns a json-serializable dict"""
+        stop_words_whitelist = None
+        if self._stop_words_whitelist is not None:
+            stop_words_whitelist = {
+                intent: sorted(values)
+                for intent, values in iteritems(self._stop_words_whitelist)}
        return {
            "config": self.config.to_dict(),
            "language_code": self.language,
            "patterns": self.patterns,
            "group_names_to_slot_names": self.group_names_to_slot_names,
-            "slot_names_to_entities": self.slot_names_to_entities
+            "slot_names_to_entities": self.slot_names_to_entities,
+            "stop_words_whitelist": stop_words_whitelist
        }
}

@classmethod
@@ -439,6 +458,12 @@ def from_dict(cls, unit_dict, **shared):
parser.group_names_to_slot_names = unit_dict[
"group_names_to_slot_names"]
parser.slot_names_to_entities = unit_dict["slot_names_to_entities"]
+        if parser.fitted:
+            whitelist = unit_dict.get("stop_words_whitelist", dict())
+            # pylint:disable=protected-access
+            parser._stop_words_whitelist = {
+                intent: set(values) for intent, values in iteritems(whitelist)}
+            # pylint:enable=protected-access
return parser
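Note the asymmetry in the (de)serialization above: to_dict stores each whitelist as a sorted list, which keeps the serialized output deterministic, while from_dict turns it back into a set. A small sketch of that round trip (the field name is the one in this diff; the intent and values are illustrative, and other fields are elided):

serialized = {"stop_words_whitelist": {"turnLightOn": ["on"]}}

# from_dict rebuilds sets from the sorted lists stored by to_dict
whitelist = {intent: set(values)
             for intent, values in serialized["stop_words_whitelist"].items()}
print(whitelist)  # {'turnLightOn': {'on'}}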


@@ -487,3 +512,14 @@ def sort_key_fn(slot):
def _get_entity_name_placeholder(entity_label, language):
return "%%%s%%" % "".join(
tokenize_light(entity_label, language)).upper()


+def _get_stop_words_whitelist(dataset, stop_words):
+    entity_values_per_intent = extract_entity_values(
+        dataset, apply_normalization=True)
+    stop_words_whitelist = dict()
+    for intent, entity_values in iteritems(entity_values_per_intent):
+        whitelist = stop_words.intersection(entity_values)
+        if whitelist:
+            stop_words_whitelist[intent] = whitelist
+    return stop_words_whitelist
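This last helper is the core of the fix: any stop word that also occurs as an entity value for a given intent is whitelisted for that intent, and _get_intent_stop_words then subtracts the whitelist before stop words are stripped from patterns and input text. A self-contained sketch of the same set logic, with illustrative stop words and entity values:

stop_words = {"the", "a", "on"}  # illustrative, not the real resource list
entity_values_per_intent = {
    "turnLightOn": {"kitchen", "on"},  # "on" is both a stop word and a value
    "setLightColor": {"blue", "red"},
}

stop_words_whitelist = {}
for intent, entity_values in entity_values_per_intent.items():
    whitelist = stop_words.intersection(entity_values)
    if whitelist:
        stop_words_whitelist[intent] = whitelist

print(stop_words_whitelist)  # {'turnLightOn': {'on'}}
# Effective stop words for turnLightOn become {"the", "a"}: "on" is kept,
# so a legitimate entity value is no longer erased from utterances or patterns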
60 changes: 60 additions & 0 deletions snips_nlu/tests/test_dataset_utils.py
@@ -0,0 +1,60 @@
from __future__ import unicode_literals

import io
from unittest import TestCase

from snips_nlu.dataset import Dataset, validate_and_format_dataset
from snips_nlu.dataset.utils import extract_entity_values


class TestDatasetUtils(TestCase):
def test_should_extract_entity_values(self):
# Given
set_light_color_yaml = io.StringIO("""
---
type: intent
name: setLightColor
utterances:
- set the lights to [color](blue)
- change the light to [color](yellow) in the [room](bedroom)""")

turn_light_on_yaml = io.StringIO("""
---
type: intent
name: turnLightOn
utterances:
- turn the light on in the [room](kitchen)
- turn the [room](bathroom)'s lights on""")

color_yaml = io.StringIO("""
type: entity
name: color
values:
- [blue, cyan]
- red""")

room_yaml = io.StringIO("""
type: entity
name: room
values:
- garage
- [living room, main room]""")

dataset_files = [set_light_color_yaml, turn_light_on_yaml, color_yaml,
room_yaml]
dataset = Dataset.from_yaml_files("en", dataset_files).json
dataset = validate_and_format_dataset(dataset)

# When
entity_values = extract_entity_values(dataset,
apply_normalization=True)

# Then
expected_values = {
"setLightColor": {"blue", "yellow", "cyan", "red", "bedroom",
"garage", "living room", "main room", "kitchen",
"bathroom"},
"turnLightOn": {"bedroom", "garage", "living room", "main room",
"kitchen", "bathroom"}
}
self.assertDictEqual(expected_values, entity_values)
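End to end, the behavior this commit fixes can be sketched with the public engine API, whose first intent parser is the deterministic one. This assumes the English resources are installed (python -m snips_nlu download en); the dataset is illustrative, as is the assumption that "on" appears in the English stop-word list:

import io

from snips_nlu import SnipsNLUEngine
from snips_nlu.dataset import Dataset, validate_and_format_dataset

light_mode_yaml = io.StringIO("""
---
type: intent
name: setLightMode
utterances:
  - set the lights to [mode](on)
  - set the lights to [mode](off)""")
mode_yaml = io.StringIO("""
type: entity
name: mode
values:
  - "on"
  - "off"
""")

dataset = Dataset.from_yaml_files("en", [light_mode_yaml, mode_yaml]).json
dataset = validate_and_format_dataset(dataset)

engine = SnipsNLUEngine().fit(dataset)
# Before this fix, a stop word like "on" could be stripped even when it was
# a legitimate entity value, breaking the deterministic parser's match;
# with the per-intent whitelist the slot should now be preserved
print(engine.parse("set the lights to on"))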
[Diff for the fifth changed file did not load]
