[QEff Finetune] Adding dataset padding changes #478


Open · wants to merge 11 commits into base: main

Conversation

@quic-swatia (Contributor) commented Jun 24, 2025

Padding the dataset with dummy samples (they won't contribute to total_loss) to make the number of samples a multiple of (DDP degree * batch_size) in case of:

  1. Fine-tuning through DDP
  2. train_batch_size > 1 or val_batch_size > 0

A sketch of the padding idea is shown below.
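
For context, the general idea can be sketched as follows. This is illustrative only: the names `DummyDataset` and `pad_dataset` are not the PR's actual helpers, and it assumes each sample is a dict with a `labels` list where `-100` marks positions ignored by the loss.

```python
import torch.distributed as dist
from torch.utils.data import ConcatDataset, Dataset


class DummyDataset(Dataset):
    """Hypothetical helper: repeats one real sample with fully masked labels
    (-100), so the padded samples add nothing to total_loss."""

    def __init__(self, template_sample, count):
        self.sample = dict(template_sample)
        self.sample["labels"] = [-100] * len(self.sample["labels"])
        self.count = count

    def __len__(self):
        return self.count

    def __getitem__(self, idx):
        return self.sample


def pad_dataset(dataset, batch_size):
    # Pad so len(dataset) is a multiple of (number of DDP replicas * batch_size).
    num_replicas = dist.get_world_size() if dist.is_initialized() else 1
    remainder = len(dataset) % (num_replicas * batch_size)
    if remainder == 0:
        return dataset, 0
    num_dummy_samples = num_replicas * batch_size - remainder
    padded = ConcatDataset([dataset, DummyDataset(dataset[0], num_dummy_samples)])
    return padded, num_dummy_samples
```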

Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
@quic-meetkuma (Contributor) left a comment

Please generate the perplexity (ppl) numbers across different DDP device counts and gradient accumulation steps to make this change concrete.

Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
@quic-swatia self-assigned this Jun 27, 2025
Swati Allabadi added 2 commits June 27, 2025 18:27
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
@quic-meetkuma (Contributor) left a comment

Please try to address the comments as soon as possible so that this can be merged.

)
if num_dummy_samples_per_batch > 0:
    num_dummy_samples += num_dummy_samples_per_batch
    loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch

Why train_config.val_batch_size?

Addressed in latest.
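
For context, when a batch of size B contains D dummy samples whose loss terms are zero, the usual correction is to rescale the batch-mean loss back to a mean over the B - D real samples. A sketch under that assumption (the helper name is illustrative, not the code this PR ends up with):

```python
def rescale_loss_for_dummy_samples(loss_mean, batch_size, num_dummy_samples):
    # loss_mean was averaged over batch_size samples, but num_dummy_samples of
    # them contributed zero loss; recover the mean over the real samples only.
    num_real_samples = batch_size - num_dummy_samples
    if num_real_samples == 0:
        return loss_mean  # batch is entirely padding; nothing to rescale
    return loss_mean * batch_size / num_real_samples
```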

Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
    return kwargs


def get_num_ddp_devices():

This is a generic function, not specific to dataset_utils. Ideally it should live in helper.py.

@quic-swatia (Author): Done.
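
For reference, such a helper is usually a thin wrapper around torch.distributed or the launcher's WORLD_SIZE environment variable; a sketch of a typical body (not necessarily what this PR uses):

```python
import os

import torch.distributed as dist


def get_num_ddp_devices():
    # Use the initialized process group if there is one, otherwise fall back
    # to the launcher's WORLD_SIZE variable, defaulting to a single device.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return int(os.environ.get("WORLD_SIZE", 1))
```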

)
if num_dummy_samples_per_batch > 0:
    num_dummy_samples += num_dummy_samples_per_batch
    loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch
Addressed in latest.

num_replicas = 1
if train_config.enable_ddp:
    num_replicas = dist.get_world_size()
remainder = len(dataset) % (num_replicas * train_config.train_batch_size)

We are padding based on the train batch size. What happens in the case of the test dataset with valid_batch_size?

@quic-swatia (Author): Done.
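
One way to address this is to compute the padding per split, using that split's own batch size. A sketch with a hypothetical helper (the helper name and the `train_dataset` / `eval_dataset` names are illustrative):

```python
def get_num_dummy_samples(dataset_len, batch_size, num_replicas):
    # Number of dummy samples needed so dataset_len becomes a multiple of
    # (num_replicas * batch_size).
    step = num_replicas * batch_size
    remainder = dataset_len % step
    return 0 if remainder == 0 else step - remainder


# Each split would be padded with its own batch size, e.g.:
# train: get_num_dummy_samples(len(train_dataset), train_config.train_batch_size, n)
# val:   get_num_dummy_samples(len(eval_dataset), train_config.val_batch_size, n)
```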

if num_dummy_samples_per_batch > 0:
    num_dummy_samples += num_dummy_samples_per_batch
    loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch


Can we not generalize the logic in L492 to L502?

I mean, the else case handles the bs > 1 use case with 'labels' of shape [batch, seq]. The same logic should hold for the bs = 1 case, so why do we need separate logic?

The same can be done in the train case. Try to avoid unnecessary if conditions.
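
For illustration, a single masked-loss path that treats labels as [batch, seq] works for both bs = 1 and bs > 1; a sketch assuming dummy or padded positions are labelled -100 (not this PR's exact code):

```python
import torch.nn.functional as F


def token_loss(logits, labels, ignore_index=-100):
    # Single code path for bs = 1 and bs > 1: labels keep shape [batch, seq],
    # and positions equal to ignore_index (including fully masked dummy
    # samples) are excluded from the average.
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=ignore_index,
    )
```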

quic-swatia and others added 4 commits July 2, 2025 13:41
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
…c to steps.

Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
quic-swatia pushed a commit to quic-swatia/efficient-transformers that referenced this pull request Jul 4, 2025
…h code before PR quic#478

Signed-off-by: Swati Allabadi <sallabad@qti.qualcomm.com>
@quic-meetkuma (Contributor) left a comment

LGTM, good work in identifying all the corner cases and fixing those, Swati! :)
