generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds VLM Training support to SFTTrainer + VSFT script #1518
Merged
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
43b78e2
adds option to skip dataset preparation in SFTTrainer
edbeeching d98b66d
before changing the template
edbeeching 142727a
adds support for new schema
edbeeching c683d89
a few fixes to data collator to support new schema
edbeeching 4c0457d
updates args
edbeeching 2bec041
precommit
edbeeching 57d1ad8
adds sys prompt to chat template and other fixes
edbeeching 26f84e8
updates template, fixes collator for multiple images
edbeeching 2e7ad90
precommit
edbeeching 8a55580
rename vsft to vstf_llava
edbeeching 9d4d732
adding integration tests
edbeeching c6a7c6f
adds integration test for vsft
edbeeching 798799a
precommit
edbeeching c65c43d
adds back chat template
edbeeching dad4eda
docs
edbeeching 89e268e
typo
edbeeching 1d81cc8
adds eval, precommit
edbeeching 35dd6a3
adds peft launch args
edbeeching 46acce9
Merge branch 'main' into vlm-sft-support
edbeeching 09fe59c
Merge branch 'main' into vlm-sft-support
edbeeching da06efa
formatting
edbeeching 36f942a
fixes no deps tests by checking if PIL lib exists
edbeeching 9c1bf9a
Update __init__.py
younesbelkada File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# flake8: noqa | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
# regular: | ||
python examples/scripts/vsft.py \ | ||
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \ | ||
--report_to="wandb" \ | ||
--learning_rate=1.4e-5 \ | ||
--per_device_train_batch_size=8 \ | ||
--gradient_accumulation_steps=1 \ | ||
--output_dir="data/vsft-llava-1.5-7b-hf" \ | ||
--logging_steps=5 \ | ||
--num_train_epochs=1 \ | ||
--push_to_hub \ | ||
--gradient_checkpointing \ | ||
--remove_unused_columns=False \ | ||
--torch_dtype=float16 \ | ||
--fp16=True \ | ||
--dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \ | ||
|
||
# peft: | ||
python examples/scripts/vsft.py \ | ||
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \ | ||
--report_to="wandb" \ | ||
--learning_rate=1.4e-5 \ | ||
--per_device_train_batch_size=8 \ | ||
--gradient_accumulation_steps=1 \ | ||
--output_dir="data/vsft-llava-1.5-7b-hf" \ | ||
--logging_steps=5 \ | ||
--num_train_epochs=1 \ | ||
--push_to_hub \ | ||
--gradient_checkpointing \ | ||
--remove_unused_columns=False \ | ||
--torch_dtype=float16 \ | ||
--fp16=True \ | ||
--dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \ | ||
--use_peft=True \ | ||
--lora_r=64 \ | ||
--lora_alpha=16 \ | ||
--lora_target_modules=all-linear" | ||
|
||
# evaluation: | ||
|
||
To evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git | ||
then run: | ||
accelerate launch --num_processes=8 -m lmms_eval \ | ||
--model llava_hf \ | ||
--model_args pretrained=llava-hf/llava-1.5-7b-hf \ | ||
--tasks mmbench \ | ||
--batch_size 1 \ | ||
--output_path ./logs/ \ | ||
--log_sample | ||
""" | ||
import logging | ||
import os | ||
from contextlib import nullcontext | ||
|
||
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) | ||
|
||
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser | ||
|
||
if TRL_USE_RICH: | ||
init_zero_verbose() | ||
FORMAT = "%(message)s" | ||
|
||
from rich.console import Console | ||
from rich.logging import RichHandler | ||
|
||
import torch | ||
from datasets import load_dataset | ||
|
||
from tqdm.rich import tqdm | ||
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration | ||
|
||
from trl import ( | ||
ModelConfig, | ||
RichProgressCallback, | ||
SFTTrainer, | ||
get_peft_config, | ||
get_quantization_config, | ||
get_kbit_device_map, | ||
) | ||
|
||
tqdm.pandas() | ||
|
||
if TRL_USE_RICH: | ||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig)) | ||
args, training_args, model_config = parser.parse_args_and_config() | ||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) | ||
# Force use our print callback | ||
if TRL_USE_RICH: | ||
training_args.disable_tqdm = True | ||
console = Console() | ||
|
||
################ | ||
# Model, Tokenizer & Processor | ||
################ | ||
LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}""" | ||
|
||
torch_dtype = ( | ||
model_config.torch_dtype | ||
if model_config.torch_dtype in ["auto", None] | ||
else getattr(torch, model_config.torch_dtype) | ||
) | ||
quantization_config = get_quantization_config(model_config) | ||
model_kwargs = dict( | ||
revision=model_config.model_revision, | ||
trust_remote_code=model_config.trust_remote_code, | ||
attn_implementation=model_config.attn_implementation, | ||
torch_dtype=torch_dtype, | ||
device_map=get_kbit_device_map() if quantization_config is not None else None, | ||
quantization_config=quantization_config, | ||
) | ||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True) | ||
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE | ||
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path) | ||
processor.tokenizer = tokenizer | ||
|
||
model = LlavaForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs) | ||
|
||
################ | ||
# Create a data collator to encode text and image pairs | ||
################ | ||
|
||
class LLavaDataCollator: | ||
def __init__(self, processor): | ||
self.processor = processor | ||
|
||
def __call__(self, examples): | ||
texts = [] | ||
images = [] | ||
for example in examples: | ||
if len(example["images"]) > 1: | ||
raise ValueError("This collator only supports one image per example") | ||
messages = example["messages"] | ||
text = self.processor.tokenizer.apply_chat_template( | ||
messages, tokenize=False, add_generation_prompt=False | ||
) | ||
texts.append(text) | ||
images.append(example["images"][0]) | ||
|
||
batch = self.processor(texts, images, return_tensors="pt", padding=True) | ||
|
||
labels = batch["input_ids"].clone() | ||
if self.processor.tokenizer.pad_token_id is not None: | ||
labels[labels == self.processor.tokenizer.pad_token_id] = -100 | ||
batch["labels"] = labels | ||
|
||
return batch | ||
|
||
data_collator = LLavaDataCollator(processor) | ||
|
||
################ | ||
# Dataset | ||
################ | ||
raw_datasets = load_dataset(args.dataset_name) | ||
train_dataset = raw_datasets["train"] | ||
eval_dataset = raw_datasets["test"] | ||
|
||
################ | ||
# Optional rich context managers | ||
############### | ||
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...") | ||
save_context = ( | ||
nullcontext() | ||
if not TRL_USE_RICH | ||
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}") | ||
) | ||
|
||
################ | ||
# Training | ||
################ | ||
with init_context: | ||
trainer = SFTTrainer( | ||
model=model, | ||
args=training_args, | ||
train_dataset=train_dataset, | ||
eval_dataset=eval_dataset, | ||
dataset_text_field="text", # need a dummy field | ||
tokenizer=tokenizer, | ||
peft_config=get_peft_config(model_config), | ||
callbacks=[RichProgressCallback] if TRL_USE_RICH else None, | ||
data_collator=data_collator, | ||
dataset_kwargs={"skip_prepare_dataset": True}, | ||
) | ||
|
||
trainer.train() | ||
|
||
with save_context: | ||
trainer.save_model(training_args.output_dir) | ||
trainer.push_to_hub() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.