add minibatching #153
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks a lot for this great addition! I left a few comments and questions as a first pass!
```python
mini_batch_data,
batch_size=self.config.mini_batch_size,
shuffle=True,
collate_fn=collator,
```
```suggestion
collate_fn=collator,
drop_last=True,
```
Maybe we can add this to avoid some corner cases such as the one described in a previous issue.
Sounds good, let's also raise a warning in that case so the user knows that a batch will be dropped.
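For illustration, a minimal sketch of what the combined suggestion could look like inside the trainer (names are reused from the diff above; the exact warning wording is an assumption, not the PR's code):

```python
import warnings
from torch.utils.data import DataLoader

remainder = len(mini_batch_data) % self.config.mini_batch_size
if remainder != 0:
    # drop_last=True would otherwise silently discard these samples.
    warnings.warn(
        f"{remainder} samples do not fill a full mini-batch of size "
        f"{self.config.mini_batch_size} and will be dropped."
    )

mini_batch_dataloader = DataLoader(
    mini_batch_data,
    batch_size=self.config.mini_batch_size,
    shuffle=True,
    collate_fn=collator,
    drop_last=True,  # avoid a final mini-batch smaller than mini_batch_size
)
```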
trl/trainer/ppo_trainer.py
```diff
- bs = self.config.batch_size
- fbs = self.config.forward_batch_size
+ bs = len(queries)
+ fbs = min(bs, self.config.forward_batch_size)
```
So is this for the case where the last batch has fewer instances than `mini_batch_size`, or the case where a user puts a `batch_size` that is smaller than `mini_batch_size` in the config? If it's the second case, we can maybe add a warning in the config; if it's the first case, since we have `drop_last=True` set here, I don't think we'll face it, but I am not sure.
It's for the case where `mini_batch_size` is smaller than `forward_batch_size` during the forward passes inside the mini-batch loop. I am also not quite happy with how we do it, actually.
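A hypothetical sketch of the pattern under discussion, assuming a Hugging Face-style model whose output has a `.logits` attribute; the helper name and signature are invented for illustration:

```python
import torch

def chunked_forward(model, input_ids: torch.Tensor, forward_batch_size: int):
    # Run the model in chunks of at most `forward_batch_size` rows.
    # min() clamps the chunk size when the (mini-)batch itself is smaller
    # than forward_batch_size, which is the corner case discussed above.
    bs = len(input_ids)
    fbs = min(bs, forward_batch_size)
    all_logits = []
    for i in range(0, bs, fbs):
        with torch.no_grad():
            out = model(input_ids[i : i + fbs])
        all_logits.append(out.logits)
    return torch.cat(all_logits, dim=0)
```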
Also, what about completely removing `forward_batch_size` from the config? I don't think this is a breaking change as the configs cannot be pushed to the Hub; we just need to update the examples accordingly. I believe this can be done in a follow-up PR too.
The breaking change actually also happens for users who currently use the library with `forward_batch_size`.
This solution makes a lot of sense, yes!
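A sketch of one possible deprecation shim for this; the field names, defaults, and the mapping of the old field onto `mini_batch_size` are assumptions for illustration, not the PR's actual code:

```python
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class PPOConfig:
    batch_size: int = 256
    mini_batch_size: int = 1
    # Deprecated field, kept temporarily so existing user code keeps working.
    forward_batch_size: Optional[int] = None

    def __post_init__(self):
        if self.forward_batch_size is not None:
            warnings.warn(
                "`forward_batch_size` is deprecated; use `mini_batch_size` "
                "instead.",
                DeprecationWarning,
            )
            self.mini_batch_size = self.forward_batch_size
```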
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Thanks a lot for your great work on this! 💯
Until now the PPO mini batch size has been hardcoded to 1. This PR aims to change that by refactoring the forward/backward passing logic.

In summary, this PR does the following things:

- `batched_forward_pass` returns a new `mask` which can be used to mask parts of the sequence to be ignored
- a `dataloader` with the `mini_batch_size` is used to sample from the current PPO batch
- in the `loss` method, all operations affected by masked parts of the sequence are replaced with masked ones (`masked_mean`, `masked_whiten`; see the sketch below)
- `compute_logits_vpred` is removed and `batched_forward_pass` is used for everything
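Since the summary mentions masked variants of standard ops, here is a minimal sketch of what `masked_mean` and `masked_whiten` could look like; this is an illustration under assumed signatures, not necessarily the library's exact implementation:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean computed only over entries where mask == 1.
    return (values * mask).sum() / mask.sum()

def masked_whiten(values: torch.Tensor, mask: torch.Tensor,
                  shift_mean: bool = True) -> torch.Tensor:
    # Whiten values using mean/variance estimated from unmasked entries only.
    mean = masked_mean(values, mask)
    var = masked_mean((values - mean) ** 2, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened = whitened + mean
    return whitened

# Example: statistics ignore the padded (masked-out) tail of each row.
values = torch.randn(4, 10)
mask = (torch.arange(10) < 7).float().expand(4, 10)
print(masked_mean(values, mask))
```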
W&B logs: