diff --git a/src/stopping.py b/src/stopping.py index 44c56bad7..1980448ba 100644 --- a/src/stopping.py +++ b/src/stopping.py @@ -180,14 +180,14 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model, stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids] stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0] # avoid padding in front of tokens - if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug + if hasattr(tokenizer, '_pad_token') and tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids] - if tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug + if hasattr(tokenizer, '_unk_token') and tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug stop_words_ids = [x[1:] if x[0] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids] stop_words_ids = [x[:-1] if x[-1] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids] - if tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug + if hasattr(tokenizer, '_eos_token') and tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug stop_words_ids = [x[:-1] if x[-1] == tokenizer.eos_token_id and len(x) > 1 else x for x in stop_words_ids] - if tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug + if hasattr(tokenizer, '_bos_token') and tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug stop_words_ids = [x[1:] if x[0] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids] stop_words_ids = [x[:-1] if x[-1] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids] if base_model and t5_type(base_model) and hasattr(tokenizer, 'vocab'): diff --git a/src/version.py b/src/version.py index 06beddbba..f5361e9a4 100644 --- a/src/version.py +++ b/src/version.py @@ -1 +1 @@ -__version__ = "2e830cc79a0bb6a7044e0794fec0ba30f4063f0f" +__version__ = "a0fcc3344d53a834fe3cb5b26265aaeb84993b77"