from typing import Callable, List, Union

from .choice import Choice
from .find_choices_options import FindChoicesOptions, FindValuesOptions
from .found_choice import FoundChoice
from .found_value import FoundValue
from .model_result import ModelResult
from .sorted_value import SortedValue
from .token import Token
from .tokenizer import Tokenizer


class Find:
1517 """ Contains methods for matching user input against a list of choices """
16-
18+
1719 @staticmethod
1820 def find_choices (
        utterance: str,
        choices: [Union[str, Choice]],
        options: FindChoicesOptions = None
    ):
2325 """ Matches user input against a list of choices """
2426
2527 if not choices :
2628 raise TypeError ('Find: choices cannot be None. Must be a [str] or [Choice].' )
27-
29+
2830 opt = options if options else FindChoicesOptions ()
2931
3032 # Normalize list of choices
31- choices_list = [ Choice (value = choice ) if isinstance (choice , str ) else choice for choice in choices ]
33+ choices_list = [Choice (value = choice ) if isinstance (choice , str ) else choice for choice in choices ]
3234
3335 # Build up full list of synonyms to search over.
3436 # - Each entry in the list contains the index of the choice it belongs to which will later be
        #   used to map the search results back to their choice.
        synonyms: [SortedValue] = []

        for index in range(len(choices_list)):
            choice = choices_list[index]

            if not opt.no_value:
                synonyms.append(SortedValue(value=choice.value, index=index))

            if (
                getattr(choice, 'action', False) and
                getattr(choice.action, 'title', False) and
                not opt.no_value
            ):
                synonyms.append(SortedValue(value=choice.action.title, index=index))

            if choice.synonyms is not None:
                for synonym in choice.synonyms:
                    synonyms.append(SortedValue(value=synonym, index=index))

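        # Illustrative example of the data built above: a Choice(value='red',
        # synonyms=['crimson', 'scarlet']) at index 0 contributes SortedValue
        # entries for 'red', 'crimson' and 'scarlet', each carrying index=0 so a
        # match on any synonym can be mapped back to that choice.
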
        def found_choice_constructor(value_model: ModelResult) -> ModelResult:
            choice = choices_list[value_model.resolution.index]

            return ModelResult(
                start=value_model.start,
                end=value_model.end,
                type_name='choice',
                text=value_model.text,
                resolution=FoundChoice(
                    value=choice.value,
                    index=value_model.resolution.index,
                    score=value_model.resolution.score,
                    synonym=value_model.resolution.value
                )
            )

        # Find synonyms in utterance and map back to their choices_list
        return list(map(found_choice_constructor, Find.find_values(utterance, synonyms, options)))

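    # Usage sketch (illustrative; the utterance and choices below are made up):
    #
    #     results = Find.find_choices('the red one please', ['red', 'green', 'blue'])
    #     top = results[0].resolution  # a FoundChoice whose value should be 'red'
    #
    # Each returned ModelResult carries the character span of the match in the
    # utterance and a FoundChoice resolution identifying the matched choice.
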
    @staticmethod
    def find_values(
        utterance: str,
        values: List[SortedValue],
        options: FindValuesOptions = None
    ):
        # Sort values in descending order by length, so that the longest value is searched over first.
        sorted_values = sorted(
            values,
            key=lambda sorted_val: len(sorted_val.value),
            reverse=True
        )

        # Search for each value within the utterance.
        matches: [ModelResult] = []
        opt = options if options else FindValuesOptions()
        tokenizer: Callable[[str, str], List[Token]] = opt.tokenizer if opt.tokenizer else Tokenizer.default_tokenizer
        tokens = tokenizer(utterance, opt.locale)
        max_distance = opt.max_token_distance if opt.max_token_distance is not None else 2

        for i in range(len(sorted_values)):
            entry = sorted_values[i]

            # Tokenize the value being searched for, then scan the utterance,
            # restarting the search after each match so repeated values are found.
            start_pos = 0
            searched_tokens = tokenizer(entry.value.strip(), opt.locale)

            while start_pos < len(tokens):
                match = Find._match_value(
                    tokens,
                    max_distance,
                    opt,
                    entry.index,
                    entry.value,
                    searched_tokens,
                    start_pos
                )

                if match is not None:
                    start_pos = match.end + 1
                    matches.append(match)
                else:
                    break

        # Sort matches by score descending
        sorted_matches = sorted(
            matches,
            key=lambda model_result: model_result.resolution.score,
            reverse=True
        )

        # Filter out duplicate matching indexes and overlapping characters
        # - Matches are processed best-score first; a match is dropped if its choice
        #   index was already matched or if any of its tokens were already used.
        results: [ModelResult] = []
        found_indexes = set()
        used_tokens = set()

        for match in sorted_matches:
            # Apply filters
            add = match.resolution.index not in found_indexes

            for i in range(match.start, match.end + 1):
                if i in used_tokens:
                    add = False
                    break

            # Add to results
            if add:
                # Update filter info
                found_indexes.add(match.resolution.index)

                for i in range(match.start, match.end + 1):
                    used_tokens.add(i)

                # Translate start & end and populate text field
                match.start = tokens[match.start].start
                match.end = tokens[match.end].end
                match.text = utterance[match.start: match.end + 1]
                results.append(match)

        # Return the results sorted by position in the utterance
        return sorted(results, key=lambda model_result: model_result.start)

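    # Behaviour sketch (illustrative): for values ['last one', 'last'] and the
    # utterance 'pick the last one', 'last one' is searched first because longer
    # values sort first, and the shorter 'last' is then dropped by the overlap
    # filter since its token is already used by the 'last one' match.
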
    @staticmethod
    def _match_value(
        source_tokens: List[Token],
        max_distance: int,
        options: FindValuesOptions,
        index: int,
        value: str,
        searched_tokens: List[Token],
        start_pos: int
    ) -> Union[ModelResult, None]:
        # Match value to utterance and calculate total deviation.
        # - The tokens are matched in order so "second last" will match in
        #   "the second from last one" but not in "the last from the second one".
        # - The total deviation is the number of tokens that had to be skipped over
        #   while matching, so in the example above it would be 1 (the word "from").
        matched = 0
        total_deviation = 0
        start = -1
        end = -1

        for token in searched_tokens:
            # Find the position of the token in the utterance, searching forward
            # from the end of any previous match.
            pos = Find._index_of_token(source_tokens, token, start_pos)

            if pos >= 0:
                # Calculate the number of tokens skipped to reach this one.
                distance = pos - start_pos if start_pos > 0 else 0

                if distance <= max_distance:
                    # Update the count of matched tokens and move the search
                    # position past the current token.
                    matched += 1
                    total_deviation += distance
                    start_pos = pos + 1

                    # Update start & end position that will track the span of the utterance that's matched.
                    if start < 0:
                        start = pos

                    end = pos

        # Calculate score and format result
        # - The start & end positions and the results text field will be corrected by the caller.
        result: ModelResult = None

        if (
            matched > 0 and
            (matched == len(searched_tokens) or options.allow_partial_matches)
        ):
            # Percentage of tokens matched. If matching "second last" in
            # "the second from last one" the completeness would be 1.0 since
            # all of the searched tokens were found.
            completeness = matched / len(searched_tokens)

            # Accuracy of the match. Accuracy is reduced by every token that had
            # to be skipped over during the match.
            accuracy = matched / (matched + total_deviation)

            # The final score is the completeness multiplied by the accuracy.
            score = completeness * accuracy

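            # Worked example (illustrative, assuming the scoring above): matching
            # "second last" against "the second from last one" gives matched = 2 and
            # total_deviation = 1 ("from" is skipped), so completeness = 2 / 2 = 1.0,
            # accuracy = 2 / (2 + 1) ≈ 0.67 and score ≈ 0.67.
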
            # Format result
            result = ModelResult(
                text='',
                start=start,
                end=end,
                type_name='value',
                resolution=FoundValue(
                    value=value,
                    index=index,
                    score=score
                )
            )

        return result

    @staticmethod
    def _index_of_token(
        tokens: List[Token],
        token: Token,
        start_pos: int
    ) -> int:
        for i in range(start_pos, len(tokens)):
            if tokens[i].normalized == token.normalized:
                return i

        return -1
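
    # Note (illustrative): _index_of_token compares tokens by their normalized form,
    # so a searched token and an utterance token match whenever the tokenizer
    # normalizes them to the same text (e.g. 'Red' and 'red' with a lower-casing
    # tokenizer).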