Skip to content

Commit

Permalink
Properly set pad_token_id
Browse files: browse the repository at this point in the history
  • Loading branch information
regisss committed Oct 20, 2024
1 parent 1f538e7 commit 1f1393b
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions examples/text-generation/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -525,16 +525,22 @@ def setup_tokenizer(args, model, assistant_model):
tokenizer.padding_side = "left"

if model.config.model_type == "llama":
# unwind broken decapoda-research config
model.generation_config.pad_token_id = 0
model.generation_config.bos_token_id = 1
model.generation_config.eos_token_id = 2
if model.generation_config.pad_token_id is None:
if isinstance(model.generation_config.eos_token_id, int):
model.generation_config.pad_token_id = model.generation_config.eos_token_id
elif isinstance(model.generation_config.eos_token_id, list):
model.generation_config.pad_token_id = model.generation_config.eos_token_id[0]
if assistant_model is not None:
assistant_model.generation_config.pad_token_id = 0
assistant_model.generation_config.bos_token_id = 1
assistant_model.generation_config.eos_token_id = 2
if assistant_model.generation_config.pad_token_id is None:
if isinstance(assistant_model.generation_config.eos_token_id, int):
assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id
elif isinstance(assistant_model.generation_config.eos_token_id, list):
assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id[0]
tokenizer.bos_token_id = model.generation_config.bos_token_id
tokenizer.eos_token_id = model.generation_config.eos_token_id
if isinstance(model.generation_config.eos_token_id, int):
tokenizer.eos_token_id = model.generation_config.eos_token_id
elif isinstance(model.generation_config.eos_token_id, list):
tokenizer.eos_token_id = model.generation_config.eos_token_id[0]
tokenizer.pad_token_id = model.generation_config.pad_token_id
tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id)
tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)
Expand Down

0 comments on commit 1f1393b

Please sign in to comment.