Skip to content

Commit

Permalink
Remove intent's syntetic ID construct and use name (RasaHQ#10013)
Browse files Browse the repository at this point in the history
* Remove intent's syntetic ID construct and use name

Fixes RasaHQ#8974

Intent IDs were basically hashed from intent's labels (aka names?) and
were in number format. If we change this to a string to make sure
scientific notation does not lose precision, we essentially can use the
label in place. Intent labels are unique and much more human readable.
Does the job in a simpler way.
  • Loading branch information
Tayfun Sen authored Nov 2, 2021
1 parent 18ca40c commit d2eb3a7
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 154 deletions.
6 changes: 6 additions & 0 deletions changelog/8974.removal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Intent IDs sent with events (to kafka and elsewhere) have been removed, intent
names can be used instead (or if numerical values are needed for backwards
compatibility, one can also hash the names to get previous ID values, ie.
`hash(intent_name)` is the old ID values). Intent IDs have been removed because
they were providing no extra value and integers that large were problematic for
some event broker implementations.
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ event_broker:
sasl_username: username
sasl_password: password
sasl_mechanism: PLAIN
convert_intent_id_to_string: True
18 changes: 0 additions & 18 deletions rasa/core/brokers/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(
ssl_check_hostname: bool = False,
security_protocol: Text = "SASL_PLAINTEXT",
loglevel: Union[int, Text] = logging.ERROR,
convert_intent_id_to_string: bool = False,
**kwargs: Any,
) -> None:
"""Kafka event broker.
Expand Down Expand Up @@ -69,8 +68,6 @@ def __init__(
security_protocol: Protocol used to communicate with brokers.
Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL.
loglevel: Logging level of the kafka logger.
convert_intent_id_to_string: Optional flag to configure whether intent ID's
are converted from an integer to a string.
"""
import kafka

Expand All @@ -87,7 +84,6 @@ def __init__(
self.ssl_certfile = ssl_certfile
self.ssl_keyfile = ssl_keyfile
self.ssl_check_hostname = ssl_check_hostname
self.convert_intent_id_to_string = convert_intent_id_to_string

logging.getLogger("kafka").setLevel(loglevel)

Expand All @@ -110,8 +106,6 @@ def publish(
retry_delay_in_seconds: float = 5,
) -> None:
"""Publishes events."""
if self.convert_intent_id_to_string:
event = self._convert_intent_id_to_string(event)
if self.producer is None:
self._create_producer()
connected = self.producer.bootstrap_connected()
Expand Down Expand Up @@ -205,17 +199,5 @@ def _publish(self, event: Dict[Text, Any]) -> None:
)
self.producer.send(self.topic, value=event, key=partition_key)

def _convert_intent_id_to_string(self, event: Dict[Text, Any]) -> Dict[Text, Any]:
if event.get("event", "") == "user" and "id" in event.get("parse_data", {}).get(
"intent", {}
):
event["parse_data"]["intent"]["id"] = str(
event["parse_data"]["intent"]["id"]
)
for idx, parse_data in enumerate(event["parse_data"]["intent_ranking"]):
parse_data["id"] = str(parse_data["id"])
event["parse_data"]["intent_ranking"][idx] = parse_data
return event

def _close(self) -> None:
self.producer.close()
9 changes: 2 additions & 7 deletions rasa/nlu/classifiers/diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def _predict_label(
self, predict_out: Optional[Dict[Text, tf.Tensor]]
) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
"""Predicts the intent of the provided message."""
label: Dict[Text, Any] = {"name": None, "id": None, "confidence": 0.0}
label: Dict[Text, Any] = {"name": None, "confidence": 0.0}
label_ranking = []

if predict_out is None:
Expand All @@ -962,18 +962,13 @@ def _predict_label(
casted_message_sim: List[float] = message_sim.tolist() # np.float to float
top_label_idx = ranked_label_indices[0]
label = {
"id": hash(self.index_label_id_mapping[top_label_idx]),
"name": self.index_label_id_mapping[top_label_idx],
"confidence": casted_message_sim[top_label_idx],
}

ranking = [(idx, casted_message_sim[idx]) for idx in ranked_label_indices]
label_ranking = [
{
"id": hash(self.index_label_id_mapping[label_idx]),
"name": self.index_label_id_mapping[label_idx],
"confidence": score,
}
{"name": self.index_label_id_mapping[label_idx], "confidence": score,}
for label_idx, score in ranking
]

Expand Down
5 changes: 2 additions & 3 deletions rasa/nlu/selectors/response_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,12 @@ def _resolve_intent_response_key(

# First check if the predicted label was the key itself
search_key = util.template_key_to_intent_response_key(key)
if hash(search_key) == label.get("id"):
if search_key == label.get("name"):
return search_key

# Otherwise loop over the responses to check if the text has a direct match
for response in responses:
if hash(response.get(TEXT, "")) == label.get("id"):
if response.get(TEXT, "") == label.get("name"):
return search_key
return None

Expand Down Expand Up @@ -578,7 +578,6 @@ def process(self, messages: List[Message]) -> List[Message]:
)
prediction_dict = {
RESPONSE_SELECTOR_PREDICTION_KEY: {
"id": top_label["id"],
RESPONSE_SELECTOR_RESPONSES_KEY: label_responses,
PREDICTED_CONFIDENCE_KEY: top_label[PREDICTED_CONFIDENCE_KEY],
INTENT_RESPONSE_KEY: label_intent_response_key,
Expand Down
6 changes: 1 addition & 5 deletions rasa/shared/utils/schemas/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@

INTENT = {
"type": "object",
"properties": {
"name": {"type": "string"},
"confidence": {"type": "number"},
"id": {"type": "number"},
},
"properties": {"name": {"type": "string"}, "confidence": {"type": "number"},},
}

RESPONSE_SCHEMA = {
Expand Down
54 changes: 9 additions & 45 deletions tests/core/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,7 @@ async def test_remote_action_utterances_with_none_values(
UserUttered(
text="hello",
parse_data={
"intent": {
"id": -4389344335148575888,
"name": "greet",
"confidence": 0.9604260921478271,
},
"intent": {"name": "greet", "confidence": 0.9604260921478271,},
"entities": [
{"entity": "city", "value": "London"},
{"entity": "count", "value": 1},
Expand All @@ -385,46 +381,14 @@ async def test_remote_action_utterances_with_none_values(
"message_id": "3f4c04602a4947098c574b107d3ccc50",
"metadata": {},
"intent_ranking": [
{
"id": -4389344335148575888,
"name": "greet",
"confidence": 0.9604260921478271,
},
{
"id": 7180145986630405383,
"name": "goodbye",
"confidence": 0.01835782080888748,
},
{
"id": 4246019067232216572,
"name": "deny",
"confidence": 0.011255578137934208,
},
{
"id": -4048707801696782560,
"name": "bot_challenge",
"confidence": 0.004019865766167641,
},
{
"id": -5942619264156239037,
"name": "affirm",
"confidence": 0.002524246694520116,
},
{
"id": 677880322645240870,
"name": "mood_great",
"confidence": 0.002214624546468258,
},
{
"id": -5973454296286367554,
"name": "chitchat",
"confidence": 0.0009614597074687481,
},
{
"id": -4598562678335233249,
"name": "mood_unhappy",
"confidence": 0.00024030178610701114,
},
{"name": "greet", "confidence": 0.9604260921478271,},
{"name": "goodbye", "confidence": 0.01835782080888748,},
{"name": "deny", "confidence": 0.011255578137934208,},
{"name": "bot_challenge", "confidence": 0.004019865766167641,},
{"name": "affirm", "confidence": 0.002524246694520116,},
{"name": "mood_great", "confidence": 0.002214624546468258,},
{"name": "chitchat", "confidence": 0.0009614597074687481,},
{"name": "mood_unhappy", "confidence": 0.00024030178610701114,},
],
"response_selector": {
"all_retrieval_intents": [],
Expand Down
49 changes: 0 additions & 49 deletions tests/core/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ async def test_kafka_broker_from_config():
topic="topic",
partition_by_sender=True,
security_protocol="SASL_PLAINTEXT",
convert_intent_id_to_string=True,
)

assert actual.url == expected.url
Expand All @@ -271,54 +270,6 @@ async def test_kafka_broker_from_config():
assert actual.sasl_mechanism == expected.sasl_mechanism
assert actual.topic == expected.topic
assert actual.partition_by_sender == expected.partition_by_sender
assert actual.convert_intent_id_to_string == expected.convert_intent_id_to_string


async def test_kafka_broker_convert_intent_id_to_string():
user_event = {
"timestamp": 1517821726.200036,
"metadata": {},
"parse_data": {
"entities": [],
"intent": {"confidence": 0.54, "name": "greet", "id": 7703045398849936579},
"message_id": "987654321",
"metadata": {},
"text": "/greet",
"intent_ranking": [
{"confidence": 0.54, "name": "greet", "id": 7703045398849936579},
{"confidence": 0.31, "name": "goodbye", "id": -5127945386715371244},
{"confidence": 0.15, "name": "default", "id": 1699173715362944540},
],
},
"event": "user",
"text": "/greet",
"input_channel": "rest",
"message_id": "987654321",
}
actual = KafkaEventBroker(
"localhost",
sasl_username="username",
sasl_password="password",
sasl_mechanism="PLAIN",
topic="topic",
partition_by_sender=True,
security_protocol="SASL_PLAINTEXT",
convert_intent_id_to_string=True,
)

converted_user_event = actual._convert_intent_id_to_string(user_event)
intent_ranking = user_event["parse_data"]["intent_ranking"]
converted_intent_ranking = converted_user_event["parse_data"]["intent_ranking"]

assert converted_user_event["parse_data"]["intent"]["id"] == str(
user_event["parse_data"]["intent"]["id"]
)
assert all(
converted_parse_data["id"] == str(parse_data["id"])
for parse_data, converted_parse_data in zip(
intent_ranking, converted_intent_ranking
)
)


@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions tests/core/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ async def inner(file_name: Path, ignore_action_unlikely_intent: bool) -> Agent:
config = textwrap.dedent(
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
pipeline: []
policies:
- name: RulePolicy
restrict_rules: false
Expand Down
22 changes: 9 additions & 13 deletions tests/shared/core/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def test_event_executed_comparison(
UserUttered(
text="hello",
parse_data={
"intent": {"id": 2, "name": "greet", "confidence": 0.9604260921478271,},
"intent": {"name": "greet", "confidence": 0.9604260921478271,},
"entities": [
{"entity": "city", "value": "London"},
{"entity": "count", "value": 1},
Expand All @@ -701,18 +701,14 @@ def test_event_executed_comparison(
"message_id": "3f4c04602a4947098c574b107d3ccc50",
"metadata": {},
"intent_ranking": [
{"id": 2, "name": "greet", "confidence": 0.9604260921478271,},
{"id": 1, "name": "goodbye", "confidence": 0.01835782080888748,},
{"id": 0, "name": "deny", "confidence": 0.011255578137934208,},
{"id": 3, "name": "bot_challenge", "confidence": 0.004019865766167641,},
{"id": 4, "name": "affirm", "confidence": 0.002524246694520116,},
{"id": 5, "name": "mood_great", "confidence": 0.002214624546468258,},
{"id": 6, "name": "chitchat", "confidence": 0.0009614597074687481,},
{
"id": 7,
"name": "mood_unhappy",
"confidence": 0.00024030178610701114,
},
{"name": "greet", "confidence": 0.9604260921478271,},
{"name": "goodbye", "confidence": 0.01835782080888748,},
{"name": "deny", "confidence": 0.011255578137934208,},
{"name": "bot_challenge", "confidence": 0.004019865766167641,},
{"name": "affirm", "confidence": 0.002524246694520116,},
{"name": "mood_great", "confidence": 0.002214624546468258,},
{"name": "chitchat", "confidence": 0.0009614597074687481,},
{"name": "mood_unhappy", "confidence": 0.00024030178610701114,},
],
"response_selector": {
"all_retrieval_intents": [],
Expand Down
14 changes: 5 additions & 9 deletions tests/shared/core/test_slot_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_slot_mapping_entity_is_desired(slot_name: Text, expected: bool):
tracker = DialogueStateTracker("test_id", slots=domain.slots)
event = UserUttered(
text="I'm travelling to Vancouver.",
intent={"id": 1, "name": "inform", "confidence": 0.9604260921478271},
intent={"name": "inform", "confidence": 0.9604260921478271},
entities=[{"entity": "GPE", "value": "Vancouver", "role": "destination"}],
)
tracker.update(event, domain)
Expand All @@ -30,11 +30,7 @@ def test_slot_mapping_intent_is_desired(domain: Domain):
tracker = DialogueStateTracker("sender_id_test", slots=domain.slots)
event1 = UserUttered(
text="I'd like to book a restaurant for 2 people.",
intent={
"id": 1,
"name": "request_restaurant",
"confidence": 0.9604260921478271,
},
intent={"name": "request_restaurant", "confidence": 0.9604260921478271,},
entities=[{"entity": "number", "value": 2}],
)
tracker.update(event1, domain)
Expand All @@ -45,7 +41,7 @@ def test_slot_mapping_intent_is_desired(domain: Domain):

event2 = UserUttered(
text="Yes, 2 please",
intent={"id": 2, "name": "affirm", "confidence": 0.9604260921478271},
intent={"name": "affirm", "confidence": 0.9604260921478271},
entities=[{"entity": "number", "value": 2}],
)
tracker.update(event2, domain)
Expand All @@ -56,7 +52,7 @@ def test_slot_mapping_intent_is_desired(domain: Domain):

event3 = UserUttered(
text="Yes, please",
intent={"id": 3, "name": "affirm", "confidence": 0.9604260921478271},
intent={"name": "affirm", "confidence": 0.9604260921478271},
entities=[],
)
tracker.update(event3, domain)
Expand Down Expand Up @@ -95,7 +91,7 @@ def test_slot_mappings_ignored_intents_during_active_loop():
event1 = ActiveLoop("restaurant_form")
event2 = UserUttered(
text="The weather is sunny today",
intent={"id": 4, "name": "chitchat", "confidence": 0.9604260921478271},
intent={"name": "chitchat", "confidence": 0.9604260921478271},
entities=[],
)
tracker.update_with_events([event1, event2], domain)
Expand Down
8 changes: 4 additions & 4 deletions tests/shared/core/test_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,7 +1464,7 @@ def test_tracker_unique_fingerprint(domain: Domain):
event1 = UserUttered(
text="hello",
parse_data={
"intent": {"id": 2, "name": "greet", "confidence": 0.9604260921478271},
"intent": {"name": "greet", "confidence": 0.9604260921478271},
"entities": [
{"entity": "city", "value": "London"},
{"entity": "count", "value": 1},
Expand All @@ -1473,9 +1473,9 @@ def test_tracker_unique_fingerprint(domain: Domain):
"message_id": "3f4c04602a4947098c574b107d3ccc59",
"metadata": {},
"intent_ranking": [
{"id": 2, "name": "greet", "confidence": 0.9604260921478271},
{"id": 1, "name": "goodbye", "confidence": 0.01835782080888748},
{"id": 0, "name": "deny", "confidence": 0.011255578137934208},
{"name": "greet", "confidence": 0.9604260921478271},
{"name": "goodbye", "confidence": 0.01835782080888748},
{"name": "deny", "confidence": 0.011255578137934208},
],
},
)
Expand Down

0 comments on commit d2eb3a7

Please sign in to comment.