@@ -10,15 +10,17 @@ class StringMatch:
     def __init__(self, entity_wrapper):
         self.logger = logging.getLogger('string_match')
         self.train_data = None
-        self.tok_train = []
+        self.tok_train = None
         self.entity_wrapper = entity_wrapper
         self.stopword_size = 'small'
         self.filter_entities = 'False'
+        self.custom_ents_samples = None

     def load_train_data(self, file_path):
         with file_path.open('rb') as f:
-            self.train_data = dill.load(f)
-            self.tok_train = []
+            tmp = dill.load(f)
+            self.train_data = tmp[0]
+            self.tok_train = tmp[1]

     def save_train_data(self, data, file_name):
         if not isinstance(data, list):
@@ -28,14 +30,6 @@ def save_train_data(self, data, file_name):
         with open(file_name, 'wb') as f:
             dill.dump(data, f)

-    async def tokenize_train_data(self):
-        for q in self.train_data:
-            tok = await self.entity_wrapper.tokenize(
-                q[0],
-                filter_ents=self.filter_entities,
-                sw_size=self.stopword_size)
-            self.tok_train.append(tok)
-
     async def get_string_match(self, q, subset_idx=None,
                                all_larger_zero=False):
         self.logger.info("searching for word matches")
@@ -49,17 +43,9 @@ async def get_string_match(self, q, subset_idx=None,
         tok_q = await self.entity_wrapper.tokenize(
             q, filter_ents=self.filter_entities, sw_size=self.stopword_size)

-        # search for intent-like entities first
-        if "@" in q:
-            match_probas = [
-                1.0 if "@" in t[0] else 0.0 for t in self.train_data
-            ]
-        # otherwise do string match
-        else:
-            match_probas = [
-                self.__jaccard_similarity(tok_q, t)
-                if '@' not in ' '.join(t) else 0.0 for t in tok_train
-            ]
+        match_probas = [
+            self.__jaccard_similarity(tok_q, t) for t in tok_train
+        ]

         self.logger.info("match_probas: {}".format(match_probas))
         max_proba = max(match_probas)
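
For context on the new on-disk format: load_train_data() now unpacks a two-element sequence, so the trainer presumably tokenizes the questions up front and passes the pair (train_data, tok_train) to save_train_data(). A minimal round-trip sketch under that assumption; the file name and sample data below are hypothetical:

    import dill
    from pathlib import Path

    # Persist training questions together with their pre-tokenized forms,
    # matching the layout that the new load_train_data() unpacks via
    # tmp[0] and tmp[1]. Sample data is made up for illustration.
    train_data = [("how do I reset my password?", "faq_1"),
                  ("what is my balance?", "faq_2")]
    tok_train = [["reset", "password"], ["balance"]]

    with open("train.dill", "wb") as f:
        dill.dump([train_data, tok_train], f)  # any indexable pair works

    # Loading mirrors the body of the new load_train_data().
    with Path("train.dill").open("rb") as f:
        tmp = dill.load(f)
    loaded_train, loaded_tok = tmp[0], tmp[1]
    assert loaded_tok == tok_train

Persisting the tokenized form alongside the raw data is what makes the removed tokenize_train_data() coroutine unnecessary at load time.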
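__jaccard_similarity() is called but not defined in these hunks; assuming it computes standard Jaccard similarity over token sets, as its name suggests, the helper would look roughly like this sketch:

    def jaccard_similarity(tok_a, tok_b):
        """Ratio of shared tokens to total distinct tokens (0.0 if both empty).

        Sketch of what a helper like __jaccard_similarity presumably computes;
        the actual private method is not shown in this diff.
        """
        set_a, set_b = set(tok_a), set(tok_b)
        union = set_a | set_b
        if not union:
            return 0.0
        return len(set_a & set_b) / len(union)

    # e.g. jaccard_similarity(["reset", "password"], ["password", "change"])
    # -> 1/3, since one token is shared out of three distinct tokens

With the "@" intent short-circuit removed, this similarity is now applied uniformly to every tokenized training question.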