[QEff Finetune] : Made fixes to training script #439

Draft · quic-mamta wants to merge 11 commits into main from jitender_fixes
Conversation

@quic-mamta (Contributor) commented on Jun 10, 2025:

  • Added padding of the dataset so its length is divisible by batch_size * num_devices.
  • Updated the dataloader code to stop dropping the last slice of the dataset.
  • Updated the samsum dataset name from "Samsung/samsum" to "knkarthick/samsum".
  • Modified the loss functions to support a loss_weight parameter that zeroes out the loss for padded samples (see the sketch after this list).
  • Refactored and corrected the logging of loss and PPL across devices.
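
A minimal sketch of the padding and loss-weighting items above, assuming a PyTorch setup; the helper names (`pad_dataset`, `weighted_ce_loss`) and the convention that real samples carry `loss_weight = 1.0` while padded copies carry `loss_weight = 0.0` are illustrative, not the exact code in this PR.

```python
import torch.nn.functional as F


def pad_dataset(samples, batch_size, num_devices, pad_sample):
    """Pad `samples` so its length is divisible by batch_size * num_devices.

    Real samples get loss_weight = 1.0; padded copies get loss_weight = 0.0
    so they contribute nothing to the loss.
    """
    world_batch = batch_size * num_devices
    num_pad = (-len(samples)) % world_batch
    samples = [dict(s, loss_weight=1.0) for s in samples]
    samples += [dict(pad_sample, loss_weight=0.0) for _ in range(num_pad)]
    return samples


def weighted_ce_loss(logits, labels, loss_weight):
    """Per-sample cross-entropy scaled by loss_weight (zero for padded samples).

    logits: (B, T, V), labels: (B, T), loss_weight: (B,)
    """
    per_token = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")  # (B, T)
    per_sample = per_token.mean(dim=1) * loss_weight  # padded samples contribute 0
    denom = loss_weight.sum().clamp(min=1.0)          # average over real samples only
    return per_sample.sum() / denom
```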

Below are the numbers with this PR:

Dataset: Samsum
Model: Llama-3.2-1B
Epoch: 1

| Sr. No. | # Devices | Grad. Accum. | BS | Global BS | Train Loss | Test Loss | Train PPL | Test PPL |
|---|---|---|---|---|---|---|---|---|
| 1 | 48 | 4 | 1 | 192 | 1.3206 | 1.2855 | 3.7471 | 3.6360 |

Dataset: Samsum
Model: Llama-3.1-8B
Epoch: 1

| Sr. No. | # Devices | Grad. Accum. | BS | Global BS | Train Loss | Test Loss | Train PPL | Test PPL |
|---|---|---|---|---|---|---|---|---|
| 1 | 4 | 4 | 1 | 16 | 1.0120 | 1.0725 | 2.7510 | 2.9234 |
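
As a sanity check, the reported perplexities track the exponential of the mean loss (small differences come from step-wise averaging), e.g. for the Llama-3.1-8B row:

```latex
\mathrm{PPL} = e^{\mathcal{L}}, \qquad
e^{1.0120} \approx 2.751 \ (\text{train}), \qquad
e^{1.0725} \approx 2.92 \ (\text{test}).
```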



```diff
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
-    dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
+    dataset = datasets.load_dataset("knkarthick/samsum", split=split, trust_remote_code=True)
```

quic-mamta (Contributor, Author) commented on this change:

Please check if this dataset can be used.

Contributor reply:

We are not distributing this dataset, hence it should not be a problem.

@quic-mamta force-pushed the jitender_fixes branch 2 times, most recently from 6d833cf to d269d0c, on Jun 10, 2025 at 10:57.
@quic-mamta changed the title from "Made fixes to training script based on recent findings." to "[QEff Finetune] : Made fixes to training script" on Jun 12, 2025.
```diff
@@ -235,11 +241,23 @@ def train(
             train_step_metric.append(step_metric_val)

             if train_config.grad_scaler:
-                scaler.scale(loss).backward()  # backward pass
+                if train_config.enable_ddp:
+                    with model.no_sync():
```

Contributor commented on this change:

This will result in no syncing of gradients at any step.

Reply:

Yes, correct. Removed this no_sync change from this PR. We will raise a separate PR for that.

```python
if train_config.enable_ddp:
    # FIXME: We can not stop transfer of gradient across devices every time.
    # In grad accumulation last step should transfer gradients across devices.
    with model.no_sync():
```

Contributor commented on this change:

This will result in no syncing of gradients at any step here as well.

Reply:

Yes, correct. Removed this no_sync change from this PR. We will raise a separate PR for that.
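
For context, a minimal sketch of what the deferred fix could look like, assuming PyTorch DDP and a `train_config.gradient_accumulation_steps` field (the field name and loop structure here are assumptions, not this repo's exact code): `no_sync()` suppresses the gradient all-reduce only on intermediate accumulation micro-steps, while the last micro-step of each window still synchronizes.

```python
import contextlib

# Assumes `model` (DDP-wrapped), `optimizer`, `scaler`, `train_dataloader`
# and `train_config` are already set up by the surrounding training code.
for step, batch in enumerate(train_dataloader):
    # Sync gradients only on the last micro-step of an accumulation window
    # (or on the final batch); otherwise skip the DDP all-reduce.
    last_accum_step = (
        (step + 1) % train_config.gradient_accumulation_steps == 0
        or (step + 1) == len(train_dataloader)
    )
    sync_ctx = (
        model.no_sync()
        if train_config.enable_ddp and not last_accum_step
        else contextlib.nullcontext()
    )

    with sync_ctx:
        loss = model(**batch).loss / train_config.gradient_accumulation_steps
        if train_config.grad_scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()

    if last_accum_step:
        if train_config.grad_scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
```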

Commits (11), signed off by Mamta Singh <mamtsing@qti.qualcomm.com> and Meet Patel <meetkuma@qti.qualcomm.com>; truncated commit-message fragments include:

  • …ght parameter to make the loss for padded samples as zero.
  • …well.
  • … and zero the loss for padded samples.
  • …e loss fn.