Online DPO and Online trainer refactor #1809
Conversation
```python
def batch_generation(
    model: torch.nn.Module,
    queries: torch.Tensor,
    local_rollout_forward_batch_size: int,
    pad_token_id: int,
    generation_config: dict,
):
    # Generate responses in chunks to bound peak memory, then concatenate.
    query_responses = []
    logitss = []
    for i in range(0, queries.shape[0], local_rollout_forward_batch_size):
        query = queries[i : i + local_rollout_forward_batch_size]
        query_response, logits = generate(
            model,
            query,
            pad_token_id,
            generation_config,
        )
        query_responses.append(query_response)
        logitss.append(logits)
    return torch.cat(query_responses, 0), torch.cat(logitss, 0)
```
Here I abstracted the `batch_generation` logic into a separate method that all the online trainers use. In the future, we can extend it to support vLLM generation.
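A simplified, dependency-free sketch of the chunking pattern (the `generate_fn` lambda is a hypothetical stand-in for the real `generate` helper, and plain lists stand in for tensors):

```python
def chunked_generation(generate_fn, queries, local_rollout_forward_batch_size):
    # Run generation in chunks of `local_rollout_forward_batch_size` to bound
    # peak memory, then concatenate the per-chunk results, mirroring the
    # structure of the `batch_generation` helper above.
    outputs = []
    for i in range(0, len(queries), local_rollout_forward_batch_size):
        chunk = queries[i : i + local_rollout_forward_batch_size]
        outputs.extend(generate_fn(chunk))
    return outputs

# Hypothetical generator: "responds" to each query by doubling it.
out = chunked_generation(lambda chunk: [q * 2 for q in chunk], list(range(7)), 3)
print(out)  # [0, 2, 4, 6, 8, 10, 12]
```

The chunk size trades throughput against memory; the concatenated result is identical regardless of the chunk size chosen.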
trl/trainer/online_dpo_trainer.py (Outdated)
```python
args.local_mini_batch_size = exact_div(
    args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
)
args.num_updates = args.total_episodes // args.batch_size
```
If `batch_size` doesn't exactly divide `total_episodes`, then the run doesn't actually go for `total_episodes`, and the logic below for `num_train_epochs` (and therefore other things like saving) ends up being wrong. I'd suggest renaming `num_updates` to something like `num_total_batches`, which better reflects its actual meaning, and using `exact_div` to make sure it divides correctly. Otherwise, you could support inexact division, but that seems like a hassle.
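For illustration, a sketch of what an `exact_div`-style check does (an assumed reimplementation of the spirit of TRL's helper, not its actual code):

```python
def exact_div(a, b, custom_error_message=""):
    # Integer division that raises if `b` does not divide `a` exactly,
    # surfacing configuration mistakes early instead of silently truncating.
    q = a // b
    if a != q * b:
        raise ValueError(f"{custom_error_message}, inexact division: {a} / {b}")
    return q

print(exact_div(512, 64))  # 8
```

With this check, `total_episodes = 1000` and `batch_size = 300` would fail fast at startup rather than silently running for 900 episodes.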
`exact_div` is probably too stringent. What do you think of `math.ceil(args.total_episodes / args.batch_size)`? `num_total_batches` is a nice suggestion.
That works, but then you need to trim the size of `queries` in the last batch.
Maybe that's too stringent.
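To illustrate the trade-off the thread is discussing, a small sketch of the `math.ceil` variant with last-batch trimming (the variable names are illustrative, not TRL's actual code):

```python
import math

total_episodes, batch_size = 1000, 300  # batch_size does not divide total_episodes
num_total_batches = math.ceil(total_episodes / batch_size)

consumed = 0
for _ in range(num_total_batches):
    # Trim the final batch so the run stops at exactly `total_episodes`.
    current_batch = min(batch_size, total_episodes - consumed)
    consumed += current_batch

print(num_total_batches, consumed)  # 4 1000
```

Note that `math.ceil` only helps if the division uses `/`; with floor division (`//`) the result is already an integer and the ceiling is a no-op.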
```python
self.beta = config.beta
self.loss_type = config.loss_type
```
Change from `config` to `args` for consistency.
Squashed commit of the following (deduplicated; the merged-in upstream history appeared several times verbatim in the original log):

- 890232f: update example overview (#1883) (Quentin Gallouédec, Jul 30 2024)
- 9929370: Move BCO to separate BCOTrainer with fixes (#1869) (Clara Pohland, Jul 28 2024)
- 6171cdd: Re-add BigBird Pegasus save/load test (#1882) (Quentin Gallouédec, Jul 28 2024)
- 33d2151: Re-add BigBird Pegasus save/load test (#1876) (Quentin Gallouédec, Jul 28 2024)
- 8bd2ab8: Refactor judges (#1856) (Quentin Gallouédec, Jul 28 2024)
- 82b07d6: Llama in modelling value head tests (#1878) (Quentin Gallouédec, Jul 26 2024)
- 72bf6c2: Skip BigBird save and load test until next transformers version (#1874) (Quentin Gallouédec, Jul 26 2024)
- 74e54b5: fix online dpo example (#1879) (Edward Beeching, Jul 26 2024)
- 3930973: Bug fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861) (Rishav Dash, Jul 25 2024)
- db8e09e: Import missing `setup_chat_format` (#1862) (Rishav Dash, Jul 25 2024)
- 1dae55f: add fsdp_qlora config and bnb_4bit_quant_storage (#1863) (elie, Jul 25 2024)
- c8cef79: arXiv to HF Papers (#1870) (Quentin Gallouédec, Jul 24 2024)
- 7dcf437: [online-DPO] online dpo cleanups (#1864) (Kashif Rasul, Jul 24 2024)
- 4e85bd7: Online DPO and Online trainer refactor (#1809) (Costa Huang, Jul 18 2024)
- c9d5636: rm token (#1852) (Quentin Gallouédec, Jul 18 2024)

PR-branch changes: fix vsft example commands; fix use_cache and get tokenizer from processor; rm unused AutoTokenizer; add section in doc; simplify script; use training args instead of config; fix doc; drop eval section; re-add BigBird.

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
This PR adds support for online DPO, building on @mnoukhov's PR, and also refactors the existing online trainers.