Skip to content

Commit b845236

Browse files
author
Thomas Kaltenbrunner
committed
Merged PR 3278: improve qa-matcher
- add string match
- rewrite the logic for choosing the answer among the different qa-matcher components (emb, entity-match, string-match)
- improve entity matcher
- add and revise unittests
- add option to run embedding on a subset of training cases
- add a default of empty string + 0.0 probability when no matches are found (catches many random sentences such as "how are you", ...)
Related work items: #5540, #5691, #5692, #5777, #5794
2 parents d3aa606 + ca93ab5 commit b845236

File tree

12 files changed

+406
-152
lines changed

12 files changed

+406
-152
lines changed

.idea/embedding.iml

Lines changed: 0 additions & 11 deletions
This file was deleted.

.idea/misc.xml

Lines changed: 0 additions & 4 deletions
This file was deleted.

.idea/modules.xml

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/docker-compose.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ services:
99
volumes:
1010
- ../../word2vec/src/word2vec:/src/word2vec
1111
- ../../word2vec/src/datasets:/datasets
12+
# - /mnt/storage/thomas/word_vectors/:/datasets
1213
ports:
1314
- 10002:9090
1415
environment:
@@ -37,5 +38,5 @@ services:
3738
- AI_TRAIN_CAPACITY=1
3839
- EMB_SERVER_PORT=9090
3940
# - W2V_SERVER_URL=http://ai-word2vec:9090
40-
# - W2V_SERVER_URL=http://dev-gpu1.hutoma.ai:9090
41+
# - W2V_SERVER_URL=http://10.8.0.26:9090
4142
- W2V_SERVER_URL=http://10.181.0.4:30100

src/embedding/chat_process.py

Lines changed: 98 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,17 @@
77

88
import ai_training.chat_process as ait_c
99
from embedding.entity_wrapper import EntityWrapper
10-
from embedding.string_match import StringMatch
1110
from embedding.text_classifier_class import EmbeddingComparison
1211
from embedding.word2vec_client import Word2VecClient
1312
from embedding.svc_config import SvcConfig
14-
13+
from embedding.string_match import StringMatch
1514

1615
MODEL_FILE = "model.pkl"
1716
DATA_FILE = "data.pkl"
1817
TRAIN_FILE = "train.pkl"
1918

20-
THRESHOLD = 0.5
2119
ENTITY_MATCH_PROBA = 0.7
22-
STRING_PROBA_THRES = 0.8
20+
STRING_PROBA_THRES = 0.45
2321

2422

2523
def _get_logger():
@@ -28,7 +26,6 @@ def _get_logger():
2826

2927

3028
class EmbeddingChatProcessWorker(ait_c.ChatProcessWorkerABC):
31-
3229
def __init__(self, pool, asyncio_loop, aiohttp_client_session=None):
3330
super().__init__(pool, asyncio_loop)
3431
self.chat_args = None
@@ -38,9 +35,10 @@ def __init__(self, pool, asyncio_loop, aiohttp_client_session=None):
3835
aiohttp_client_session = aiohttp.ClientSession()
3936
self.aiohttp_client = aiohttp_client_session
4037
config = SvcConfig.get_instance()
41-
self.w2v_client = Word2VecClient(
42-
config.w2v_server_url, self.aiohttp_client)
43-
self.entity_wrapper = EntityWrapper(config.er_server_url, self.aiohttp_client)
38+
self.w2v_client = Word2VecClient(config.w2v_server_url,
39+
self.aiohttp_client)
40+
self.entity_wrapper = EntityWrapper(config.er_server_url,
41+
self.aiohttp_client)
4442
self.string_match = StringMatch(self.entity_wrapper)
4543
self.cls = None
4644

@@ -54,55 +52,116 @@ async def chat_request(self, msg: ait_c.ChatRequestMessage):
5452
if msg.update_state:
5553
await self.setup_chat_session()
5654

55+
# tokenize
5756
x_tokens_testset = [
58-
await self.entity_wrapper.tokenize(msg.question)
57+
await self.entity_wrapper.tokenize(msg.question, sw_size='xlarge')
5958
]
60-
self.logger.info("x_tokens_testset: {}".format(x_tokens_testset))
59+
self.logger.debug("x_tokens_testset: {}".format(x_tokens_testset))
60+
self.logger.debug("x_tokens_testset: {}".format(
61+
len(x_tokens_testset[0])))
6162

6263
# get question entities
6364
msg_entities = await self.entity_wrapper.extract_entities(msg.question)
6465
self.logger.debug("msg_entities: {}".format(msg_entities))
6566

6667
# get string match
6768
sm_pred, sm_prob = await self.get_string_match(msg, msg_entities, x_tokens_testset)
68-
if sm_prob[0] > STRING_PROBA_THRES:
69-
yPred = sm_pred
70-
yProbs = sm_prob
69+
70+
# entity matcher
71+
er_pred, er_prob = await self.get_entity_match(msg, msg_entities, x_tokens_testset)
72+
73+
# if SM proba larger take that
74+
if sm_prob[0] > er_prob[0] and sm_prob[0] > STRING_PROBA_THRES:
75+
y_pred, y_prob = sm_pred, sm_prob
76+
self.logger.info("sm wins: {}".format(y_pred))
77+
# otherwise take ER result if there is any
78+
elif er_prob[0] > 0.:
79+
y_pred, y_prob = er_pred, er_prob
80+
self.logger.info("er wins: {}".format(y_pred))
81+
# if both ER and SM fail completely - EMB to the rescue!
82+
elif x_tokens_testset[0][0] != 'UNK':
83+
y_pred, y_prob = await self.get_embedding_match(msg, msg_entities, x_tokens_testset)
84+
self.logger.info("default emb: {}".format(y_pred))
7185
else:
72-
unique_tokens = list(set([w for l in x_tokens_testset for w in l]))
73-
unk_tokens = self.cls.get_unknown_words(unique_tokens)
74-
75-
vecs = await self.w2v_client.get_vectors_for_words(unk_tokens)
76-
77-
self.cls.update_w2v(vecs)
78-
yPred, yProbs = self.cls.predict(x_tokens_testset)
79-
if yProbs[0] < THRESHOLD or len(x_tokens_testset) < 3:
80-
matched_answers = self.entity_wrapper.match_entities(
81-
msg.question)
82-
self.logger.info("matched_entities: {}".format(matched_answers))
83-
if matched_answers:
84-
if len(matched_answers) > 1:
85-
train_idx = [e[0] for e in matched_answers]
86-
yPred, yProbs = self.cls.predict(x_tokens_testset, subset_idx=train_idx)
87-
self.logger.info("multiple entity matches {}; pick {}".format(
88-
matched_answers, yPred))
89-
else:
90-
self.logger.info("substituting {} for entity match {}".format(
91-
yPred, matched_answers))
92-
yPred = [matched_answers[0][1]]
93-
yProbs = [ENTITY_MATCH_PROBA]
94-
resp = ait_c.ChatResponseMessage(msg, yPred[0], float(yProbs[0]))
86+
y_pred = [""]
87+
y_prob = [0.0]
88+
89+
resp = ait_c.ChatResponseMessage(msg, y_pred[0], float(y_prob[0]))
9590
return resp
9691

92+
async def get_embedding_match(self, msg, msg_entities, x_tokens_testset):
93+
# get new word embeddings
94+
unique_tokens = list(set([w for l in x_tokens_testset for w in l]))
95+
unk_tokens = self.cls.get_unknown_words(unique_tokens)
96+
if len(unk_tokens) > 0:
97+
unk_words = await self.w2v_client.get_unknown_words(unk_tokens)
98+
self.logger.debug("unknown words: {}".format(unk_words))
99+
if len(unk_words) > 0:
100+
unk_tokens = [w for w in unk_tokens if w not in unk_words]
101+
x_tokens_testset = [[w for w in s if w not in unk_words]
102+
for s in x_tokens_testset]
103+
if len(unk_tokens) > 0:
104+
vecs = await self.w2v_client.get_vectors_for_words(
105+
unk_tokens)
106+
self.cls.update_w2v(vecs)
107+
self.logger.debug("final tok set: {}".format(x_tokens_testset))
108+
# get embedding match
109+
y_pred, y_prob = self.cls.predict(x_tokens_testset)
110+
y_prob = [max(0., y_prob[0] - 0.15)]
111+
return y_pred, y_prob
112+
113+
async def get_entity_match(self, msg, msg_entities, x_tokens_testset):
114+
matched_answers = self.entity_wrapper.match_entities(
115+
msg.question, msg_entities)
116+
if len(matched_answers) == 1:
117+
er_pred = [matched_answers[0][1]]
118+
er_prob = [ENTITY_MATCH_PROBA]
119+
elif len(matched_answers) > 1:
120+
er_idxs, _ = zip(*matched_answers)
121+
if not any(
122+
[self.string_match.train_data[i][0] == 'UNK'
123+
for i in er_idxs]):
124+
er_pred, er_prob = self.cls.predict(
125+
x_tokens_testset, subset_idx=er_idxs)
126+
er_prob = [ENTITY_MATCH_PROBA] # min(0.99, er_prob[0])
127+
128+
self.logger.debug("er_pred: {} er_prob: {}".format(
129+
er_pred, er_prob))
130+
else:
131+
er_pred = ['']
132+
er_prob = [0.0]
133+
else:
134+
er_pred, er_prob = [''], [0.0]
135+
return er_pred, er_prob
136+
97137
async def get_string_match(self, msg, msg_entities, x_tokens_testset):
98138
# get string match
99139
sm_proba, sm_preds = await self.string_match.get_string_match(
100140
msg.question)
101141
if len(sm_preds) > 1:
102142
sm_idxs, _ = zip(*sm_preds)
103-
sm_pred, sm_prob = self.cls.predict(
104-
x_tokens_testset, subset_idx=sm_idxs)
105-
sm_prob = [sm_proba]
143+
self.logger.debug("sm_idxs: {}".format(sm_idxs))
144+
matched_answers = self.entity_wrapper.match_entities(
145+
msg.question, msg_entities, subset_idxs=sm_idxs)
146+
if len(matched_answers) == 1:
147+
sm_pred = [matched_answers[0][1]]
148+
sm_prob = [ENTITY_MATCH_PROBA]
149+
elif len(matched_answers) > 1:
150+
sm_idxs, _ = zip(*matched_answers)
151+
if not any([
152+
self.string_match.train_data[i][0] == 'UNK'
153+
for i in sm_idxs
154+
]):
155+
sm_pred, sm_prob = self.cls.predict(
156+
x_tokens_testset, subset_idx=sm_idxs)
157+
sm_prob = [ENTITY_MATCH_PROBA + 0.1
158+
] # min(0.99, sm_prob[0])
159+
else:
160+
sm_pred = ['']
161+
sm_prob = [0.0]
162+
else:
163+
sm_pred = ['']
164+
sm_prob = [0.0]
106165
elif len(sm_preds) == 1:
107166
sm_idxs, sm_pred, sm_prob = [sm_preds[0][0]], [sm_preds[0][1]], [
108167
sm_proba

0 commit comments

Comments (0)