77
88import ai_training .chat_process as ait_c
99from embedding .entity_wrapper import EntityWrapper
10- from embedding .string_match import StringMatch
1110from embedding .text_classifier_class import EmbeddingComparison
1211from embedding .word2vec_client import Word2VecClient
1312from embedding .svc_config import SvcConfig
14-
13+ from embedding . string_match import StringMatch
1514
# File names of the pickled artifacts this worker loads/saves
# (exact usage is outside this view).
MODEL_FILE = "model.pkl"
DATA_FILE = "data.pkl"
TRAIN_FILE = "train.pkl"

# Fixed probability reported for predictions produced by the entity
# matcher (see get_entity_match / get_string_match).
ENTITY_MATCH_PROBA = 0.7
# Minimum string-match probability required before a string-match
# prediction is preferred over the entity-match result.
STRING_PROBA_THRES = 0.45
2321
2422
2523def _get_logger ():
@@ -28,7 +26,6 @@ def _get_logger():
2826
2927
3028class EmbeddingChatProcessWorker (ait_c .ChatProcessWorkerABC ):
31-
3229 def __init__ (self , pool , asyncio_loop , aiohttp_client_session = None ):
3330 super ().__init__ (pool , asyncio_loop )
3431 self .chat_args = None
@@ -38,9 +35,10 @@ def __init__(self, pool, asyncio_loop, aiohttp_client_session=None):
3835 aiohttp_client_session = aiohttp .ClientSession ()
3936 self .aiohttp_client = aiohttp_client_session
4037 config = SvcConfig .get_instance ()
41- self .w2v_client = Word2VecClient (
42- config .w2v_server_url , self .aiohttp_client )
43- self .entity_wrapper = EntityWrapper (config .er_server_url , self .aiohttp_client )
38+ self .w2v_client = Word2VecClient (config .w2v_server_url ,
39+ self .aiohttp_client )
40+ self .entity_wrapper = EntityWrapper (config .er_server_url ,
41+ self .aiohttp_client )
4442 self .string_match = StringMatch (self .entity_wrapper )
4543 self .cls = None
4644
@@ -54,55 +52,116 @@ async def chat_request(self, msg: ait_c.ChatRequestMessage):
5452 if msg .update_state :
5553 await self .setup_chat_session ()
5654
55+ # tokenize
5756 x_tokens_testset = [
58- await self .entity_wrapper .tokenize (msg .question )
57+ await self .entity_wrapper .tokenize (msg .question , sw_size = 'xlarge' )
5958 ]
60- self .logger .info ("x_tokens_testset: {}" .format (x_tokens_testset ))
59+ self .logger .debug ("x_tokens_testset: {}" .format (x_tokens_testset ))
60+ self .logger .debug ("x_tokens_testset: {}" .format (
61+ len (x_tokens_testset [0 ])))
6162
6263 # get question entities
6364 msg_entities = await self .entity_wrapper .extract_entities (msg .question )
6465 self .logger .debug ("msg_entities: {}" .format (msg_entities ))
6566
6667 # get string match
6768 sm_pred , sm_prob = await self .get_string_match (msg , msg_entities , x_tokens_testset )
68- if sm_prob [0 ] > STRING_PROBA_THRES :
69- yPred = sm_pred
70- yProbs = sm_prob
69+
70+ # entity matcher
71+ er_pred , er_prob = await self .get_entity_match (msg , msg_entities , x_tokens_testset )
72+
73+ # if SM proba larger take that
74+ if sm_prob [0 ] > er_prob [0 ] and sm_prob [0 ] > STRING_PROBA_THRES :
75+ y_pred , y_prob = sm_pred , sm_prob
76+ self .logger .info ("sm wins: {}" .format (y_pred ))
77+ # otherwise take ER result if there is any
78+ elif er_prob [0 ] > 0. :
79+ y_pred , y_prob = er_pred , er_prob
80+ self .logger .info ("er wins: {}" .format (y_pred ))
81+ # if both ER and SM fail completely - EMB to the rescue!
82+ elif x_tokens_testset [0 ][0 ] != 'UNK' :
83+ y_pred , y_prob = await self .get_embedding_match (msg , msg_entities , x_tokens_testset )
84+ self .logger .info ("default emb: {}" .format (y_pred ))
7185 else :
72- unique_tokens = list (set ([w for l in x_tokens_testset for w in l ]))
73- unk_tokens = self .cls .get_unknown_words (unique_tokens )
74-
75- vecs = await self .w2v_client .get_vectors_for_words (unk_tokens )
76-
77- self .cls .update_w2v (vecs )
78- yPred , yProbs = self .cls .predict (x_tokens_testset )
79- if yProbs [0 ] < THRESHOLD or len (x_tokens_testset ) < 3 :
80- matched_answers = self .entity_wrapper .match_entities (
81- msg .question )
82- self .logger .info ("matched_entities: {}" .format (matched_answers ))
83- if matched_answers :
84- if len (matched_answers ) > 1 :
85- train_idx = [e [0 ] for e in matched_answers ]
86- yPred , yProbs = self .cls .predict (x_tokens_testset , subset_idx = train_idx )
87- self .logger .info ("multiple entity matches {}; pick {}" .format (
88- matched_answers , yPred ))
89- else :
90- self .logger .info ("substituting {} for entity match {}" .format (
91- yPred , matched_answers ))
92- yPred = [matched_answers [0 ][1 ]]
93- yProbs = [ENTITY_MATCH_PROBA ]
94- resp = ait_c .ChatResponseMessage (msg , yPred [0 ], float (yProbs [0 ]))
86+ y_pred = ["" ]
87+ y_prob = [0.0 ]
88+
89+ resp = ait_c .ChatResponseMessage (msg , y_pred [0 ], float (y_prob [0 ]))
9590 return resp
9691
92+ async def get_embedding_match (self , msg , msg_entities , x_tokens_testset ):
93+ # get new word embeddings
94+ unique_tokens = list (set ([w for l in x_tokens_testset for w in l ]))
95+ unk_tokens = self .cls .get_unknown_words (unique_tokens )
96+ if len (unk_tokens ) > 0 :
97+ unk_words = await self .w2v_client .get_unknown_words (unk_tokens )
98+ self .logger .debug ("unknown words: {}" .format (unk_words ))
99+ if len (unk_words ) > 0 :
100+ unk_tokens = [w for w in unk_tokens if w not in unk_words ]
101+ x_tokens_testset = [[w for w in s if w not in unk_words ]
102+ for s in x_tokens_testset ]
103+ if len (unk_tokens ) > 0 :
104+ vecs = await self .w2v_client .get_vectors_for_words (
105+ unk_tokens )
106+ self .cls .update_w2v (vecs )
107+ self .logger .debug ("final tok set: {}" .format (x_tokens_testset ))
108+ # get embedding match
109+ y_pred , y_prob = self .cls .predict (x_tokens_testset )
110+ y_prob = [max (0. , y_prob [0 ] - 0.15 )]
111+ return y_pred , y_prob
112+
113+ async def get_entity_match (self , msg , msg_entities , x_tokens_testset ):
114+ matched_answers = self .entity_wrapper .match_entities (
115+ msg .question , msg_entities )
116+ if len (matched_answers ) == 1 :
117+ er_pred = [matched_answers [0 ][1 ]]
118+ er_prob = [ENTITY_MATCH_PROBA ]
119+ elif len (matched_answers ) > 1 :
120+ er_idxs , _ = zip (* matched_answers )
121+ if not any (
122+ [self .string_match .train_data [i ][0 ] == 'UNK'
123+ for i in er_idxs ]):
124+ er_pred , er_prob = self .cls .predict (
125+ x_tokens_testset , subset_idx = er_idxs )
126+ er_prob = [ENTITY_MATCH_PROBA ] # min(0.99, er_prob[0])
127+
128+ self .logger .debug ("er_pred: {} er_prob: {}" .format (
129+ er_pred , er_prob ))
130+ else :
131+ er_pred = ['' ]
132+ er_prob = [0.0 ]
133+ else :
134+ er_pred , er_prob = ['' ], [0.0 ]
135+ return er_pred , er_prob
136+
97137 async def get_string_match (self , msg , msg_entities , x_tokens_testset ):
98138 # get string match
99139 sm_proba , sm_preds = await self .string_match .get_string_match (
100140 msg .question )
101141 if len (sm_preds ) > 1 :
102142 sm_idxs , _ = zip (* sm_preds )
103- sm_pred , sm_prob = self .cls .predict (
104- x_tokens_testset , subset_idx = sm_idxs )
105- sm_prob = [sm_proba ]
143+ self .logger .debug ("sm_idxs: {}" .format (sm_idxs ))
144+ matched_answers = self .entity_wrapper .match_entities (
145+ msg .question , msg_entities , subset_idxs = sm_idxs )
146+ if len (matched_answers ) == 1 :
147+ sm_pred = [matched_answers [0 ][1 ]]
148+ sm_prob = [ENTITY_MATCH_PROBA ]
149+ elif len (matched_answers ) > 1 :
150+ sm_idxs , _ = zip (* matched_answers )
151+ if not any ([
152+ self .string_match .train_data [i ][0 ] == 'UNK'
153+ for i in sm_idxs
154+ ]):
155+ sm_pred , sm_prob = self .cls .predict (
156+ x_tokens_testset , subset_idx = sm_idxs )
157+ sm_prob = [ENTITY_MATCH_PROBA + 0.1
158+ ] # min(0.99, sm_prob[0])
159+ else :
160+ sm_pred = ['' ]
161+ sm_prob = [0.0 ]
162+ else :
163+ sm_pred = ['' ]
164+ sm_prob = [0.0 ]
106165 elif len (sm_preds ) == 1 :
107166 sm_idxs , sm_pred , sm_prob = [sm_preds [0 ][0 ]], [sm_preds [0 ][1 ]], [
108167 sm_proba
0 commit comments