Adds VLM Training support to SFTTrainer + VSFT script #1518
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very clean implementation @edbeeching! Just a few nits and a few unit/integration tests needed, and this should be good to merge.
trl/trainer/sft_trainer.py (Outdated)

```diff
 ):
     if dataset is None:
         raise ValueError("The dataset should not be None")

     # check if torch dataset / dataloader and do nothing
-    if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
+    if skip_prepare_dataset or isinstance(
```
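For context, this escape hatch is surfaced through the trainer's new `dataset_kwargs` argument. A minimal sketch of how a caller opts out of the text-only preparation (`model`, `training_args`, `raw_datasets`, `tokenizer`, and `data_collator` are assumed from the surrounding script setup; a collator is sketched later in this thread):

```python
# Sketch only: the surrounding objects are assumed to be set up as in the
# accompanying vsft script; only the dataset_kwargs wiring is the point here.
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    dataset_text_field="text",  # dummy field; ignored once preparation is skipped
    tokenizer=tokenizer,
    data_collator=data_collator,  # multimodal collator supplied by the caller
    dataset_kwargs={"skip_prepare_dataset": True},
)
trainer.train()
```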
Let's add an integration test for this case under the existing SFTTrainer tests, along with an example training a tiny random llava model.
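A rough sketch of what such a test could look like; the tiny checkpoint id and the dataset slice are assumptions, not the repo's actual fixtures, and the tokenizer is assumed to already carry a chat template:

```python
# Hypothetical integration test: the model id below is an assumed tiny-random
# fixture, not necessarily one that exists under trl-internal-testing.
from datasets import load_dataset
from transformers import AutoProcessor, LlavaForConditionalGeneration, TrainingArguments
from trl import SFTTrainer

def test_sft_trainer_tiny_llava(tmp_path):
    model_id = "trl-internal-testing/tiny-random-LlavaForConditionalGeneration"  # assumed fixture
    model = LlavaForConditionalGeneration.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)

    dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train[:4]")

    def collate_fn(examples):
        # Render each conversation with the chat template and batch text + image
        # pairs through the processor.
        texts = [
            processor.tokenizer.apply_chat_template(ex["messages"], tokenize=False)
            for ex in examples
        ]
        images = [ex["images"][0] for ex in examples]
        batch = processor(texts, images, return_tensors="pt", padding=True)
        batch["labels"] = batch["input_ids"].clone()
        return batch

    trainer = SFTTrainer(
        model=model,
        args=TrainingArguments(
            output_dir=str(tmp_path),
            max_steps=2,
            per_device_train_batch_size=2,
            remove_unused_columns=False,  # keep "messages"/"images" for the collator
        ),
        train_dataset=dataset,
        dataset_text_field="text",  # dummy field
        tokenizer=processor.tokenizer,
        data_collator=collate_fn,
        dataset_kwargs={"skip_prepare_dataset": True},
    )
    trainer.train()
```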
Looking great!
```
--use_peft \
--lora_r=64 \
--lora_alpha=16
"""
```
""" | |
# to evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git | |
# then run: | |
accelerate launch --num_processes=8 -m lmms_eval \ | |
--model llava_hf \ | |
--model_args pretrained=llava-hf/llava-1.5-7b-hf \ | |
--tasks mmbench \ | |
--batch_size 1 \ | |
--output_path ./logs/ \ | |
--log_sample | |
""" |
Thanks! We should be good to go after fixing the merge conflict on main! 🚀
* adds option to skip dataset preparation in SFTTrainer
* before changing the template
* adds support for new schema
* a few fixes to data collator to support new schema
* updates args
* precommit
* adds sys prompt to chat template and other fixes
* updates template, fixes collator for multiple images
* precommit
* rename vsft to vstf_llava
* adding integration tests
* adds integration test for vsft
* precommit
* adds back chat template
* docs
* typo
* adds eval, precommit
* adds peft launch args
* formatting
* fixes no deps tests by checking if PIL lib exists
* Update __init__.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
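The collator work listed above (new schema, multiple images) boils down to batching paired text/image examples through the processor. A sketch of the shape it takes, with pad tokens masked from the loss as a common choice; details may differ from the merged script:

```python
# Sketch of the multimodal data collator described in the commit list above;
# pad tokens are masked out of the labels so they do not contribute to the loss.
class LlavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts, images = [], []
        for example in examples:
            # Render the conversation with the tokenizer's chat template.
            texts.append(
                self.processor.tokenizer.apply_chat_template(
                    example["messages"], tokenize=False, add_generation_prompt=False
                )
            )
            images.append(example["images"][0])  # first image per example

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
```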
Adds vsft.py to train the llava1.5 model with an instruct dataset.

TODO:

Example usage:
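A plausible invocation, assuming the script lives under examples/scripts/ and trains on the llava-instruct mix referenced in this PR; the path and hyperparameters are illustrative, with the PEFT flags taken from the diff above:

```sh
# Hypothetical launch command; script path, dataset name, and hyperparameters
# are illustrative, not taken verbatim from the merged PR.
python examples/scripts/vsft_llava.py \
    --model_name_or_path="llava-hf/llava-1.5-7b-hf" \
    --dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
    --per_device_train_batch_size=8 \
    --gradient_accumulation_steps=1 \
    --learning_rate=1.4e-5 \
    --output_dir="vsft-llava-1.5-7b-hf" \
    --use_peft \
    --lora_r=64 \
    --lora_alpha=16
```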