Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds VLM Training support to SFTTrainer + VSFT script #1518

Merged
merged 23 commits into from
Apr 11, 2024

Conversation

edbeeching
Copy link
Collaborator

@edbeeching edbeeching commented Apr 10, 2024

  • Modifies the SFTTrainer so that it can be used with a VLM SFT dataset
  • Adds an example script vsft.py to train the llava1.5 model with an instruct dataset

TODO:

  • Test PEFT support
  • Run full training
  • Add example to docs
  • add tests

Example usage:

python examples/scripts/vsft.py \
    --model_name_or_path="llava-hf/llava-1.5-7b-hf" \
    --report_to="wandb" \
    --learning_rate=1.4e-5 \
    --per_device_train_batch_size=8 \
    --gradient_accumulation_steps=1 \
    --output_dir="data/vsft-llava-1.5-7b-hf" \
    --logging_steps=5 \
    --num_train_epochs=1 \
    --push_to_hub \
    --gradient_checkpointing \
    --remove_unused_columns=False \
    --torch_dtype=float16 \
    --fp16=True \ 
    --dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@edbeeching edbeeching requested a review from lewtun April 10, 2024 14:35
@edbeeching edbeeching marked this pull request as ready for review April 10, 2024 14:37
Copy link
Member

@lewtun lewtun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very clean implementation @edbeeching ! Just a few nits and a few unit/integration tests needed and this should be good to merge

examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
):
if dataset is None:
raise ValueError("The dataset should not be None")

# check if torch dataset / dataloader and do nothing
if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
if skip_prepare_dataset or isinstance(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add an integration test for this case under the existing SFTTrainer tests, along with an example training a tiny random llava model

examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great!

examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
examples/scripts/vsft.py Outdated Show resolved Hide resolved
--use_peft \
--lora_r=64 \
--lora_alpha=16
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
# to evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
# then run:
accelerate launch --num_processes=8 -m lmms_eval \
--model llava_hf \
--model_args pretrained=llava-hf/llava-1.5-7b-hf \
--tasks mmbench \
--batch_size 1 \
--output_path ./logs/ \
--log_sample
"""

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks ! We should be good to go after fixing the merge conflict on main ! 🚀

@younesbelkada younesbelkada merged commit 346c99d into main Apr 11, 2024
9 checks passed
@younesbelkada younesbelkada deleted the vlm-sft-support branch April 11, 2024 13:36
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* adds option to skip dataset preparation in SFTTrainer

* before changing the template

* adds support for new schema

* a few fixes to data collator to support new schema

* updates args

* precommit

* adds sys prompt to chat template and other fixes

* updates template, fixes collator for multiple images

* precommit

* rename vsft to vstf_llava

* adding integration tests

* adds integration test for vsft

* precommit

* adds back chat template

* docs

* typo

* adds eval, precommit

* adds peft launch args

* formatting

* fixes no deps tests by checking if PIL lib exists

* Update __init__.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants