Skip to content

Commit ea17a3e

Browse files
author
Thomas Kaltenbrunner
committed
Merged PR 3136: improve qa matching if mutiple qa pairs have same number of entities
in this case subset on the training examples which have the same number of matches and run those through the embedding matcher. take the sample with the highest score as the match Related work items: #5576
2 parents ec73d21 + 30b8ae5 commit ea17a3e

File tree

5 files changed

+32
-16
lines changed

5 files changed

+32
-16
lines changed

src/embedding/chat_process.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,20 @@ async def chat_request(self, msg: ait_c.ChatRequestMessage):
6161
self.cls.update_w2v(vecs)
6262
yPred, yProbs = self.cls.predict(x_tokens_testset)
6363
if yProbs[0] < THRESHOLD or len(x_tokens_testset) < 3:
64-
matched_answer = self.entity_wrapper.match_entities(
64+
matched_answers = self.entity_wrapper.match_entities(
6565
msg.question)
66-
self.logger.info("matched_entities: {}".format(matched_answer))
67-
if matched_answer:
68-
self.logger.info("substituting {} for entity match {}".format(
69-
yPred, matched_answer))
70-
yPred = [matched_answer]
71-
yProbs = [ENTITY_MATCH_PROBA]
66+
self.logger.info("matched_entities: {}".format(matched_answers))
67+
if matched_answers:
68+
if len(matched_answers) > 1:
69+
train_idx = [e[0] for e in matched_answers]
70+
yPred, yProbs = self.cls.predict(x_tokens_testset, subset_idx=train_idx)
71+
self.logger.info("multiple entity matches {}; pick {}".format(
72+
matched_answers, yPred))
73+
else:
74+
self.logger.info("substituting {} for entity match {}".format(
75+
yPred, matched_answers))
76+
yPred = [matched_answers[0][1]]
77+
yProbs = [ENTITY_MATCH_PROBA]
7278
resp = ait_c.ChatResponseMessage(msg, yPred[0], float(yProbs[0]))
7379
return resp
7480

src/embedding/entity_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def match_entities(self, test_q):
7878
if e not in ['the'] and e in test_match:
7979
num_matches += 1
8080
if num_matches > max_matches:
81+
max_matches = num_matches
82+
matched_labels = [(i, self.train_labels[i])]
83+
elif num_matches == max_matches and max_matches > 0:
8184
matched_labels.append((i, self.train_labels[i]))
8285
return matched_labels
8386

src/embedding/tests/test_embedding_chat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async def mocked_chat(mocker, loop):
6464
'start': 8,
6565
'end': 18
6666
}]]
67+
6768
chat.entity_wrapper.train_labels = ["You said London today",
6869
"You said Paris Fred Bloggs"]
6970

@@ -127,7 +128,7 @@ async def test_chat_request_entity_no_match2(mocker, mocked_chat):
127128

128129
msg = ait_c.ChatRequestMessage("This question has entities London today in it", None, None, update_state=True)
129130
response = await mocked_chat.chat_request(msg)
130-
assert response.answer[0][1] == "You said London today"
131+
assert response.answer == "You said London today"
131132
assert response.score == embedding.chat_process.ENTITY_MATCH_PROBA
132133
assert response.topic_out is None
133134
assert response.history is None

src/embedding/tests/test_embedding_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ async def mocked_train(mocker, loop):
4545
training.entity_wrapper,
4646
"get_from_er_server",
4747
new=get_from_er_server)
48+
4849
training.entity_wrapper.train_entities = [
4950
[{
5051
'category': 'sys.places',
@@ -68,6 +69,7 @@ async def mocked_train(mocker, loop):
6869
'start': 8,
6970
'end': 18
7071
}]]
72+
7173
training.entity_wrapper.train_labels = ["You said London today",
7274
"You said Paris Fred Bloggs"]
7375

@@ -126,7 +128,6 @@ async def test_er_match_entities_2(mocked_train):
126128
assert matched_label[0][1] == "You said Paris Fred Bloggs"
127129

128130

129-
130131
async def test_train_success(mocked_train, mocker):
131132
DUMMY_AIID = "123456"
132133
DUMMY_TRAINING_DATA = """

src/embedding/text_classifier_class.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,26 @@ def fit(self, X, y):
110110
self.X = X
111111
self.classes = list(set(y))
112112

113-
def predict(self, X, scale_probas=False):
113+
def predict(self, X, scale_probas=False, subset_idx=None):
114+
if subset_idx:
115+
train_x = self.X_tfidf[subset_idx]
116+
train_y = self.y[subset_idx]
117+
else:
118+
train_x = self.X_tfidf
119+
train_y = self.y
114120
target_tfidf = self.vectorizer.transform(X)
115121
target_tfidf = target_tfidf - self.pca.components_[0]
116-
117122
# compute cosine similarity
118-
cossim = np.dot(target_tfidf, self.X_tfidf.T) / (
119-
np.outer(np.linalg.norm(target_tfidf, axis=1), np.linalg.norm(self.X_tfidf, axis=1)))
123+
cossim = np.dot(target_tfidf, train_x.T) / (
124+
np.outer(np.linalg.norm(target_tfidf, axis=1), np.linalg.norm(train_x, axis=1)))
120125
# self.logger.info("cossim: {}".format(cossim))
121126
cossim = np.where(cossim < 0., 0., cossim)
122-
127+
if subset_idx:
128+
self.logger.info("cossims: {}".format(cossim))
123129
# most similar vector is the predicted class
124130
preds = np.argmax(cossim, 1)
125-
preds = [self.y[i] for i in preds]
131+
preds = [train_y[i] for i in preds]
126132
probs = self.downscale_probas(np.max(cossim, axis=1))
127-
128133
return preds, list(probs)
129134

130135
def save_model(self, file_path: Path):

0 commit comments

Comments
 (0)