@@ -120,11 +120,11 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         self._set_start_from(input_ids)

         input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
-        partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+        partial_output, remainder_bytes = self._get_partial_outputs(input_ids)[0]

-        res, skip = self._parse_partial_code(
+        res, skip = self._parse_partial_output(
             idx=0,
-            partial_code=partial_code,
+            partial_output=partial_output,
             remainder_bytes=remainder_bytes,
             accepted_generation=False
         )
@@ -142,7 +142,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         is_valid = self.dfa_mask_store.is_valid_prefix(res)

         if is_valid:
-            self._update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_output, 0, res)

         return is_valid

@@ -163,11 +163,11 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
             torch.FloatTensor: The masked scores.
         """
         self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
-        partial_codes = self._get_partial_codes(input_ids)
+        partial_outputs = self._get_partial_outputs(input_ids)

-        for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
+        for idx, (partial_output, remainder_bytes) in enumerate(partial_outputs):
             # 1. Parsing
-            res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+            res, skip = self._parse_partial_output(idx, partial_output, remainder_bytes, accepted_generation=True)
             if skip: continue

             # 2. Computing the accept mask
@@ -187,23 +187,29 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->

         return scores

-    def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+    def _parse_partial_output(
+            self,
+            idx: int,
+            partial_output: str,
+            remainder_bytes: bytes,
+            accepted_generation=True
+        ) -> tuple[ParseResult, bool]:
         """
         Parse the partial code and return the result.
         """
         skip = False
         res = None

         try:
-            res = self.inc_parser.get_acceptable_next_terminals(partial_code)
+            res = self.inc_parser.get_acceptable_next_terminals(partial_output)

             if len(remainder_bytes) > 0:
                 res.remainder_state = RemainderState.INCOMPLETE
                 res.remainder = res.remainder.encode('utf-8') + remainder_bytes
             else:
                 res.remainder = res.remainder.encode('utf-8')

-            self._update_valid_state(partial_code, idx, res)
+            self._update_valid_state(partial_output, idx, res)
         except Exception as e:
             if self.dev_mode == True and accepted_generation:
                 logger.info("-" * 50)
@@ -213,45 +219,45 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
             elif self.parse_failed == False and accepted_generation:
                 self.parse_failed = True
                 logger.info("-" * 50)
-                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
+                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_output}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
                 logger.info("-" * 50)
                 skip = True
         return res, skip

-    def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
+    def _get_partial_outputs(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
         """
         Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
         """
         output = []
         for idx in range(len(input_ids)):
             if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True)
                     )
             else:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx].tolist(), skip_special_tokens=True)
                     )
-            output.append((partial_code, remainder_bytes))
+            output.append((partial_output, remainder_bytes))
         return output

-    def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
+    def _update_valid_state(self, partial_output: str, idx: int, r: ParseResult):
         """
         This a simple heuristic to cut off the generated output at the end of the function.
         TODO: Put this under a flag to enable/disable this heuristic.
         """
         if idx < len(self.function_ends):
             if r.function_end: # If the function end is not None, then the last valid state is the function end
                 if self.function_ends[idx] is None: self.function_ends[idx] = []
-                self.function_ends[idx].append(len(partial_code) - len(r.remainder))
+                self.function_ends[idx].append(len(partial_output) - len(r.remainder))

         if idx < len(self.last_valid_state):
             for accept_seq in r.accept_sequences:
                 # 'EOF' is special terminal since $END does not work with python
                 if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
-                    self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
+                    self.last_valid_state[idx] = len(partial_output) - len(r.remainder)

     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
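The hunk ends at the `_bytes_to_string` signature, whose body is not shown here. Its signature and the `_get_partial_outputs` docstring indicate that it splits decoded bytes into a valid UTF-8 string plus the remainder bytes of any incomplete trailing character. A minimal sketch of that splitting behaviour, assuming a standalone helper built on the standard `codecs` incremental decoder (not the repository's actual implementation):

```python
import codecs

def bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
    """Split a byte sequence into its longest decodable UTF-8 prefix and the
    trailing bytes of any incomplete multi-byte character (the remainder)."""
    decoder = codecs.getincrementaldecoder('utf-8')()
    # final=False makes the decoder buffer a truncated code point at the end
    # instead of raising UnicodeDecodeError on it.
    text = decoder.decode(byte_sequence, final=False)
    remainder_bytes = decoder.getstate()[0]  # bytes still buffered by the decoder
    return text, remainder_bytes

# Example: b'h\xc3' ends midway through 'é' (encoded as b'\xc3\xa9'),
# so only 'h' decodes and b'\xc3' is returned as the remainder.
assert bytes_to_string(b'h\xc3') == ('h', b'\xc3')
```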