
Commit 041383f

Author: Thomas Kaltenbrunner
Merged PR 3308: save tokenised training data during training; minor optimisations
Related work items: #5885

2 parents b845236 + 8cab38c · commit 041383f

File tree

5 files changed: +32 −31 lines changed


src/embedding/chat_process.py

Lines changed: 11 additions & 2 deletions
@@ -2,7 +2,7 @@

 import logging
 from pathlib import Path
-
+import time
 import aiohttp

 import ai_training.chat_process as ait_c
@@ -53,22 +53,30 @@ async def chat_request(self, msg: ait_c.ChatRequestMessage):
         await self.setup_chat_session()

         # tokenize
+        t_start = time.time()
         x_tokens_testset = [
             await self.entity_wrapper.tokenize(msg.question, sw_size='xlarge')
         ]
         self.logger.debug("x_tokens_testset: {}".format(x_tokens_testset))
         self.logger.debug("x_tokens_testset: {}".format(
             len(x_tokens_testset[0])))
+        self.logger.debug("tokenizing: {}s".format(time.time() - t_start))

         # get question entities
+        t_start = time.time()
         msg_entities = await self.entity_wrapper.extract_entities(msg.question)
         self.logger.debug("msg_entities: {}".format(msg_entities))
+        self.logger.debug("msg_entities: {}s".format(time.time() - t_start))

         # get string match
+        t_start = time.time()
         sm_pred, sm_prob = await self.get_string_match(msg, msg_entities, x_tokens_testset)
+        self.logger.debug("string_match: {}s".format(time.time() - t_start))

         # entity matcher
+        t_start = time.time()
         er_pred, er_prob = await self.get_entity_match(msg, msg_entities, x_tokens_testset)
+        self.logger.debug("entity_match: {}s".format(time.time() - t_start))

         # if SM proba larger take that
         if sm_prob[0] > er_prob[0] and sm_prob[0] > STRING_PROBA_THRES:
@@ -80,8 +88,10 @@ async def chat_request(self, msg: ait_c.ChatRequestMessage):
             self.logger.info("er wins: {}".format(y_pred))
         # if both ER and SM fail completely - EMB to the rescue!
         elif x_tokens_testset[0][0] != 'UNK':
+            t_start = time.time()
             y_pred, y_prob = await self.get_embedding_match(msg, msg_entities, x_tokens_testset)
             self.logger.info("default emb: {}".format(y_pred))
+            self.logger.debug("embedding: {}s".format(time.time() - t_start))
         else:
             y_pred = [""]
             y_prob = [0.0]
@@ -178,4 +188,3 @@ async def setup_chat_session(self):
         self.cls.load_model(ai_path / MODEL_FILE)
         self.entity_wrapper.load_data(ai_path / DATA_FILE)
         self.string_match.load_train_data(ai_path / TRAIN_FILE)
-        await self.string_match.tokenize_train_data()
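Note: this diff times each stage with a repeated t_start = time.time() / logger.debug(...) pair. A small context manager could fold that pattern into one reusable helper. The sketch below is illustrative only (log_duration is not part of this codebase); it mirrors the "<stage>: <seconds>s" log format used above.

import logging
import time
from contextlib import contextmanager

@contextmanager
def log_duration(logger, label):
    # Time the enclosed block and log it in the same
    # "<label>: <seconds>s" format as the diff above.
    t_start = time.time()
    try:
        yield
    finally:
        logger.debug("{}: {}s".format(label, time.time() - t_start))

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger('chat_process')

# Usage: one `with` block replaces the paired t_start/debug lines per stage.
with log_duration(logger, "tokenizing"):
    tokens = "This is a question".lower().split()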

src/embedding/entity_wrapper.py

Lines changed: 2 additions & 5 deletions
@@ -119,11 +119,8 @@ def find_matches(self, train_ents, test_match):
         for i, tr_ents in train_ents:
             num_matches = 0
             self.logger.debug("train sample ents: {}".format(tr_ents))
-            for ent in tr_ents:
-                tmp_ent = self.split_entities(ent)
-                for e in tmp_ent:
-                    if e not in ['the'] and e in test_match:
-                        num_matches += 1
+            num_matches += sum(1 if e not in ['the'] and e in test_match else 0
+                               for ent in tr_ents for e in self.split_entities(ent))
             if num_matches > max_matches:
                 max_matches = num_matches
                 matched_labels = [(i, self.train_labels[i])]
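Note: the rewritten num_matches accumulation is behaviour-preserving; the generator expression visits exactly the same (ent, e) pairs as the old nested loops. A standalone check, with a simplified stand-in for split_entities (assumed here to split an entity string into lowercase words):

def split_entities(ent):
    # Stand-in for EntityWrapper.split_entities (assumed behaviour).
    return ent.lower().split()

tr_ents = ["New York", "the Thames"]
test_match = ["new", "york", "thames"]

# Old form: explicit nested loops.
num_matches_loop = 0
for ent in tr_ents:
    for e in split_entities(ent):
        if e not in ['the'] and e in test_match:
            num_matches_loop += 1

# New form from the diff: a single generator expression.
num_matches_gen = sum(1 if e not in ['the'] and e in test_match else 0
                      for ent in tr_ents for e in split_entities(ent))

assert num_matches_loop == num_matches_gen == 3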

src/embedding/string_match.py

Lines changed: 8 additions & 22 deletions
@@ -10,15 +10,17 @@ class StringMatch:
     def __init__(self, entity_wrapper):
         self.logger = logging.getLogger('string_match')
         self.train_data = None
-        self.tok_train = []
+        self.tok_train = None
         self.entity_wrapper = entity_wrapper
         self.stopword_size = 'small'
         self.filter_entities = 'False'
+        self.custom_ents_samples = None

     def load_train_data(self, file_path):
         with file_path.open('rb') as f:
-            self.train_data = dill.load(f)
-        self.tok_train = []
+            tmp = dill.load(f)
+        self.train_data = tmp[0]
+        self.tok_train = tmp[1]

     def save_train_data(self, data, file_name):
         if not isinstance(data, list):
@@ -28,14 +30,6 @@ def save_train_data(self, data, file_name):
         with open(file_name, 'wb') as f:
             dill.dump(data, f)

-    async def tokenize_train_data(self):
-        for q in self.train_data:
-            tok = await self.entity_wrapper.tokenize(
-                q[0],
-                filter_ents=self.filter_entities,
-                sw_size=self.stopword_size)
-            self.tok_train.append(tok)
-
     async def get_string_match(self, q, subset_idx=None,
                                all_larger_zero=False):
         self.logger.info("searching for word matches")
@@ -49,17 +43,9 @@ async def get_string_match(self, q, subset_idx=None,
         tok_q = await self.entity_wrapper.tokenize(
             q, filter_ents=self.filter_entities, sw_size=self.stopword_size)

-        # search for intent-like entities first
-        if "@" in q:
-            match_probas = [
-                1.0 if "@" in t[0] else 0.0 for t in self.train_data
-            ]
-        # otherwise do string match
-        else:
-            match_probas = [
-                self.__jaccard_similarity(tok_q, t)
-                if '@' not in ' '.join(t) else 0.0 for t in tok_train
-            ]
+        match_probas = [
+            self.__jaccard_similarity(tok_q, t) for t in tok_train
+        ]

         self.logger.info("match_probas: {}".format(match_probas))
         max_proba = max(match_probas)
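Note: with the intent-entity branch removed, every score now comes from self.__jaccard_similarity(tok_q, t). That method is not shown in this diff; the sketch below assumes the usual set-based definition (intersection over union) purely for illustration.

def jaccard_similarity(tokens_a, tokens_b):
    # Intersection-over-union of the two token sets: 1.0 for
    # identical sets, 0.0 when nothing is shared.
    set_a, set_b = set(tokens_a), set(tokens_b)
    if not set_a and not set_b:
        return 0.0
    return len(set_a & set_b) / len(set_a | set_b)

tok_q = ["this", "be", "perfect", "string", "match"]
tok_train = [["this", "be", "london", "today"],
             ["this", "be", "perfect", "string", "match"]]
match_probas = [jaccard_similarity(tok_q, t) for t in tok_train]
print(match_probas)  # the exact match scores 1.0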

src/embedding/tests/test_embedding_chat.py

Lines changed: 5 additions & 0 deletions
@@ -99,6 +99,11 @@ async def mocked_chat(mocker, loop):
         ("This is London today for entity match", "entity wins with London today"),
         ("This is a perfect string match", "string wins"),
         ("This is the question for embedding word1 word2", "embedding wins")]
+    chat.string_match.tok_train = [
+        ["this", "be", "london", "today", "for", "entity", "match"],
+        ["this", "be", "perfect", "string", "match"],
+        ["this", "be", "question", "for", "embedding", "word1", "word2"]
+    ]
     # mock out the load methods
     mocker.patch("embedding.text_classifier_class.EmbeddingComparison.load_model")
     mocker.patch.object(chat.entity_wrapper, "load_data")

src/embedding/training_process.py

Lines changed: 6 additions & 2 deletions
@@ -97,8 +97,6 @@ async def train(self, msg, topic: ait.Topic, callback):
         temp_data_file = tempdir_path / DATA_FILE
         temp_train_file = tempdir_path / TRAIN_FILE

-        self.string_match.save_train_data(q_and_a, temp_train_file)
-
         self.logger.info("Extracting entities...")
         q_entities, a_entities = [], []
         for question, answer in q_and_a:
@@ -114,10 +112,16 @@ async def train(self, msg, topic: ait.Topic, callback):
             "Entities saved to {}, tokenizing...".format(temp_data_file))

         x_tokens = []
+        x_tokens_save = []
         for question in x:
             tokens = await self.entity_wrapper.tokenize(question, sw_size='xlarge')
             x_tokens.append(tokens)
+            tokens = await self.entity_wrapper.tokenize(question,
+                                                        sw_size='small',
+                                                        filter_ents='False')
+            x_tokens_save.append(tokens)
         self.report_progress(0.3)
+        self.string_match.save_train_data([q_and_a, x_tokens_save], temp_train_file)

         x_tokens_set = list(set([w for l in x_tokens for w in l]))
