Skip to content

Commit

Permalink
protection
Browse files — browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Jan 8, 2025
1 parent a0fcc33 commit 07ad82b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
# avoid padding in front of tokens
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
if hasattr(tokenizer, '_pad_token') and tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
if tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug
if hasattr(tokenizer, '_unk_token') and tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug
stop_words_ids = [x[1:] if x[0] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids]
stop_words_ids = [x[:-1] if x[-1] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids]
if tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug
if hasattr(tokenizer, '_eos_token') and tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug
stop_words_ids = [x[:-1] if x[-1] == tokenizer.eos_token_id and len(x) > 1 else x for x in stop_words_ids]
if tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug
if hasattr(tokenizer, '_bos_token') and tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug
stop_words_ids = [x[1:] if x[0] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids]
stop_words_ids = [x[:-1] if x[-1] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids]
if base_model and t5_type(base_model) and hasattr(tokenizer, 'vocab'):
Expand Down
2 changes: 1 addition & 1 deletion src/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2e830cc79a0bb6a7044e0794fec0ba30f4063f0f"
__version__ = "a0fcc3344d53a834fe3cb5b26265aaeb84993b77"

0 comments on commit 07ad82b

Please sign in to comment.