Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add minibatching #153

Merged
merged 20 commits into from
Feb 23, 2023
Merged

add minibatching #153

merged 20 commits into from
Feb 23, 2023

Conversation

lvwerra
Copy link
Member

@lvwerra lvwerra commented Feb 16, 2023

Until now the PPO mini batch size has been hardcoded to 1. This PR aims to change it by refactoring the forward/backward passing logic.

In summary this PR does the following things:

  1. The batched_forward_pass returns new a mask which can be used to mask parts of the sequence to be ignored
  2. enable mini-batching of PPO by creating a small dataloader with the mini_batch_size to sample from the current PPO batch
  3. In the loss method we replace all operations affected by masked parts of the sequence with masked ones (masked_mean, masked_whiten)
  4. remove compute_logits_vpred and use batched_forward_pass for everything
  5. extend testing and refactor it (i don't think we need subfolders for the 3 test files we have)

W&B logs:

@lvwerra lvwerra marked this pull request as ready for review February 21, 2023 17:01
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Feb 21, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this great addition! I left few comments and questions as a first pass!

mini_batch_data,
batch_size=self.config.mini_batch_size,
shuffle=True,
collate_fn=collator,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
collate_fn=collator,
collate_fn=collator,
drop_last=True,

Maybe we can add this to avoid some corner-cases such as the one described on a previous issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, let's also set a warning if that's the case so the user knows that a batch will be dropped.

trl/trainer/ppo_trainer.py Outdated Show resolved Hide resolved
bs = self.config.batch_size
fbs = self.config.forward_batch_size
bs = len(queries)
fbs = min(bs, self.config.forward_batch_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is the case where the last element has less instances than the mini_batch_size or the case a users put a batch_size that is smaller than mini_batch_size on the config? If it's the second case we can maybe add a warning on the config, if the first case since we have drop_last=True set here I don't think we'll face this case but I am not sure

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's for the case where mini_batch_size is smaller than forward_batch_size during the forward passes inside the minibatch loop. I am also not quite happy with how we do it actually.

trl/trainer/ppo_trainer.py Show resolved Hide resolved
trl/trainer/ppo_trainer.py Outdated Show resolved Hide resolved
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what about completely removing forward_batch_size from the config? I don't think this is a breaking change as the configs cannot be pushed on the Hub, just need to update the examples accodingly. I believe this can be done on a follow up PR too

@lvwerra
Copy link
Member Author

lvwerra commented Feb 22, 2023

The breaking change actually also happens for users who currently use the library with forward_batch_size. What do you think about setting it default to None and overwrite mini_batch_size if it's set to another value with a warning that it affects now also the mini_batch_size if set to a value?

@younesbelkada
Copy link
Contributor

This solution makes a lot of sense yes!

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@lvwerra
Copy link
Member Author

lvwerra commented Feb 22, 2023

Deprecated forward_batch_size: feel free to have a look!

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for your great work on this! 💯

@lvwerra lvwerra merged commit f1300ec into main Feb 23, 2023
@lvwerra lvwerra deleted the mini-batching branch February 23, 2023 14:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants