Skip to content

Commit

Permalink
VSFT hotfix - adds gen prompt to template and processor to hub (#1532)
Browse files Browse the repository at this point in the history
* adds gen prompt to template and processor to hub

* fixes hub model id, removes Path
  • Loading branch information
edbeeching authored Apr 12, 2024
1 parent 363369a commit 1c0d8bc
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/scripts/vsft_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from rich.logging import RichHandler

import torch
from accelerate import Accelerator
from datasets import load_dataset

from tqdm.rich import tqdm
Expand Down Expand Up @@ -111,7 +112,7 @@
################
# Model, Tokenizer & Processor
################
# Jinja chat template for LLaVA-style conversations.
#
# Rendering rules (all visible in the template below):
#   - A fixed system preamble is emitted once at the start.
#   - Each message is prefixed with "USER: " or "ASSISTANT: " based on its role.
#   - Message content is a list of items: {"type": "text", "text": ...} items are
#     emitted verbatim; {"type": "image"} items become the literal "<image>" token.
#   - User turns are terminated with a single space; assistant turns with the
#     tokenizer's `eos_token`.
#   - `add_generation_prompt` defaults to false when the caller does not define
#     it; when true, a trailing "ASSISTANT: " is appended so the model can be
#     prompted to generate the next assistant turn.
LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

torch_dtype = (
model_config.torch_dtype
Expand Down Expand Up @@ -205,3 +206,5 @@ def __call__(self, examples):
with save_context:
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()
if Accelerator().is_main_process:
processor.push_to_hub(training_args.hub_model_id)

0 comments on commit 1c0d8bc

Please sign in to comment.