[QEff Finetune] Adding dataset padding changes #478
base: main

Conversation
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Force-pushed from b8104e1 to adebe02
Please generate the PPL numbers across different numbers of DDP devices and grad accum steps to make this change concrete.
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Please try to address the comments at the earliest so that this can be merged ASAP.
```python
)
if num_dummy_samples_per_batch > 0:
    num_dummy_samples += num_dummy_samples_per_batch
    loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch
```
Why train_config.val_batch_size?
Addressed in latest.
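For context, a minimal sketch of how a mean batch loss can be rescaled so dummy samples do not dilute it, assuming each dummy sample contributes zero loss. The helper name and the real-sample denominator are assumptions, not the PR's exact fix:

```python
def rescale_loss(loss, batch_size, num_dummy_samples_per_batch):
    # `loss` is the mean over all batch_size samples. Dummy samples are
    # assumed to contribute zero loss, so loss * batch_size is the sum over
    # the real samples; dividing by the real-sample count gives their mean.
    num_real_samples = batch_size - num_dummy_samples_per_batch
    if num_real_samples == 0:
        return loss  # batch is entirely padding; leave the loss untouched
    return loss * batch_size / num_real_samples
```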
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Force-pushed from 6dc7950 to 8d725ab
kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer) | ||
return kwargs | ||
|
||
|
||
def get_num_ddp_devices(): |
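For reference, a minimal sketch of what such a helper might look like; the fallback to 1 for non-distributed runs is an assumption, and the PR's actual implementation may determine the world size differently:

```python
import torch.distributed as dist


def get_num_ddp_devices():
    # Report the DDP world size; fall back to 1 when no process group has
    # been initialized (plain single-device runs).
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
```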
This is a generic function, not specific to dataset_utils. Ideally it should live in helper.py.
Done.
```python
)
if num_dummy_samples_per_batch > 0:
    num_dummy_samples += num_dummy_samples_per_batch
    loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch
```
Addressed in latest.
```python
num_replicas = 1
if train_config.enable_ddp:
    num_replicas = dist.get_world_size()
remainder = len(dataset) % (num_replicas * train_config.train_batch_size)
```
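To make the remainder computation concrete, here is a minimal sketch of the padding step, assuming the dataset behaves like a list of feature dicts and dummy labels are masked with -100; the dummy-sample construction is an assumption, not the PR's exact code:

```python
def pad_dataset(samples, num_replicas, batch_size):
    # Pad so len(samples) is a multiple of num_replicas * batch_size.
    # Dummy samples copy the first sample but mask every label with -100,
    # so the loss function ignores them entirely.
    world = num_replicas * batch_size
    remainder = len(samples) % world
    if remainder == 0:
        return samples
    dummy = dict(samples[0])
    dummy["labels"] = [-100] * len(dummy["labels"])
    return samples + [dummy] * (world - remainder)
```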
We are padding based on the train batch size. What happens in the case of the test dataset with valid_batch_size?
Done.
```python
if num_dummy_samples_per_batch > 0:
    num_dummy_samples += num_dummy_samples_per_batch
    loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch
```
Can we not generalize the logic in L492 to L502? The else case handles the bs > 1 use case, where 'labels' has shape [batch, seq]; the same logic should hold for the bs = 1 case, so why do we need a separate branch? The same can be done in the train case. Try to avoid unnecessary if conditions.
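A minimal sketch of the single code path this comment asks for, assuming the usual -100 ignore-index convention for masked labels; the function name, shapes, and convention are assumptions rather than the PR's code:

```python
import torch
import torch.nn.functional as F


def seq2seq_loss(logits, labels):
    # One code path for bs == 1 and bs > 1: labels is always [batch, seq],
    # and positions marked -100 (padding tokens or whole dummy samples)
    # are excluded via ignore_index, so no separate bs == 1 branch is
    # needed.
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),  # [batch * seq, vocab]
        labels.view(-1),                   # [batch * seq]
        ignore_index=-100,
    )
```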
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
…c to steps. Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
…h code before PR quic#478 Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Force-pushed from 1da94fd to 2239119
LGTM, good work in identifying all the corner cases and fixing those, Swati! :)
Padding the dataset with dummy samples (they won't contribute to total_loss) to make the number of samples a multiple of (degree of ddp * batch_size) in case of