@@ -119,6 +119,7 @@ def _set_special_token(self, typ: str, tid: Any) -> None:
119119 logger .warning (f'Special token type { typ } , id { tid } out of range, must be under { self .n_vocab } - skipping' )
120120
121121 def _try_load_from_tokenizer_json (self , path : Path ) -> bool :
122+ tokenizer = None
122123 tokenizer_file = path / 'tokenizer.json'
123124 if tokenizer_file .is_file ():
124125 with open (tokenizer_file , encoding = 'utf-8' ) as f :
@@ -152,11 +153,87 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
152153 added_tokens = tokenizer .get ('added_tokens' , {})
153154 else :
154155 added_tokens = {}
156+ tokenizer_config = None
155157 tokenizer_config_file = path / 'tokenizer_config.json'
156- if not tokenizer_config_file .is_file ():
158+ if tokenizer_config_file .is_file ():
159+ with open (tokenizer_config_file , encoding = 'utf-8' ) as f :
160+ tokenizer_config = json .load (f )
161+ if tokenizer :
162+ special_bos = (tokenizer_config or {}).get ('bos_token' )
163+ special_cls = (tokenizer_config or {}).get ('cls_token' )
164+ special_eos = (tokenizer_config or {}).get ('eos_token' )
165+ special_sep = (tokenizer_config or {}).get ('sep_token' )
166+ if not special_bos and special_cls and tokenizer_config :
167+ tokenizer_config ['bos_token' ] = special_bos = special_cls
168+ if not special_eos and special_sep and tokenizer_config :
169+ tokenizer_config ['eos_token' ] = special_eos = special_sep
170+ post_processor = tokenizer .get ('post_processor' , {})
171+ for processor in post_processor .get ('processors' , [post_processor ]):
172+ if processor .get ('type' ) == 'RobertaProcessing' :
173+ self .add_special_token ['bos' ] = True
174+ self .add_special_token ['eos' ] = True
175+ self .add_special_token ['sep' ] = True
176+ if not special_cls and tokenizer_config :
177+ special_cls = processor .get ('cls' , [special_bos ])[0 ]
178+ tokenizer_config ['cls_token' ] = special_cls
179+ if not special_sep and tokenizer_config :
180+ special_sep = processor .get ('sep' , [special_eos ])[0 ]
181+ tokenizer_config ['sep_token' ] = special_sep
182+ continue
183+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184+ # Only works with simple templates, **will** get it wrong on unusual sequences
185+ if processor .get ('type' ) == 'TemplateProcessing' :
186+ tmpl_single = processor .get ('single' , [])
187+ tmpl_pair = processor .get ('pair' , [])
188+ special_first = None
189+ special_last = None
190+ if len (tmpl_single ) > 1 :
191+ if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
192+ if not tokenizer_config :
193+ special_bos = special_first
194+ self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
195+ if special_first not in (special_bos , special_cls ):
196+ logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
197+ if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
198+ if not tokenizer_config :
199+ special_eos = special_last
200+ self .add_special_token ['eos' ] = True if special_last == special_eos else False
201+ if special_last != special_eos :
202+ logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
203+ if tmpl_pair :
204+ seq_start = 1 if tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
205+ seq_stop = - 1 if tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
206+ if seq_start == 0 or seq_stop is None :
207+ logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
208+ if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
209+ tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
210+ tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
211+ if tmpl_a != 'A' or tmpl_b != 'B' :
212+ logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
213+ # A [sep] [eos] B
214+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
215+ add_sep = False
216+ if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
217+ if special_entry in (special_sep , special_eos ) and not special_last :
218+ add_sep = True
219+ if special_entry not in (special_sep , special_eos ):
220+ logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
221+ else :
222+ logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
223+ if len (tmpl_pair ) == 2 :
224+ if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
225+ if special_entry in (special_sep , special_eos ):
226+ add_sep = True
227+ if special_entry not in (special_sep , special_eos ):
228+ logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
229+ else :
230+ logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
231+ self .add_special_token ['sep' ] = add_sep
232+ if add_sep and not special_sep and tokenizer_config :
233+ tokenizer_config ['sep_token' ] = special_eos
234+ continue
235+ if not tokenizer_config :
157236 return True
158- with open (tokenizer_config_file , encoding = 'utf-8' ) as f :
159- tokenizer_config = json .load (f )
160237 chat_template_alt = None
161238 chat_template_file = path / 'chat_template.json'
162239 if chat_template_file .is_file ():
0 commit comments