-
Notifications
You must be signed in to change notification settings - Fork 45
[QEff Finetune] Adding dataset padding changes #478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
adebe02
Adding dataset padding changes
69caeec
Addressed review comments
526df1d
Adding padding support fro BS > 1 as well
da775b6
Adding padding support fro BS > 1 as well
8d725ab
Addressed the review comments
75451a7
Merge branch 'main' into dataset_padding
quic-swatia 19386e2
Including val_batch_size in padding_dataset()
c5c8544
Changing CI numbers based on code changes and calculating avg-loss ac…
f84d345
Changing atol in tests due to randomness in current stack
2239119
Skipping tests as diff eval loss values are observed in diff runs wit…
6884d76
Merge branch 'main' into dataset_padding
quic-mamta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,7 +151,7 @@ def train( | |
|
||
# enable profile for qaic | ||
qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None | ||
|
||
num_dummy_samples = 0 | ||
for step, batch in enumerate(train_dataloader): | ||
# resume training from a particular checkpoint, assuming the dataset is not shuffled | ||
if train_config.use_peft and train_config.from_peft_checkpoint: | ||
|
@@ -192,6 +192,17 @@ def train( | |
) as verifier: | ||
model_outputs = model(**batch) | ||
loss = model_outputs.loss # Forward call | ||
if (batch["labels"] != -100).sum() == 0: | ||
loss = loss.nan_to_num(nan=0.0) | ||
quic-swatia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_dummy_samples += train_config.train_batch_size | ||
else: | ||
num_dummy_samples_per_batch = ( | ||
(torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item() | ||
) | ||
if num_dummy_samples_per_batch > 0: | ||
num_dummy_samples += num_dummy_samples_per_batch | ||
loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch | ||
|
||
if train_config.task_type == "seq_classification": | ||
logits = model_outputs.logits | ||
labels = batch["labels"][:, 0] | ||
|
@@ -201,15 +212,25 @@ def train( | |
else: | ||
model_outputs = model(**batch) | ||
loss = model_outputs.loss # Forward call | ||
if (batch["labels"] != -100).sum() == 0: | ||
loss = loss.nan_to_num(nan=0.0) | ||
num_dummy_samples += train_config.train_batch_size | ||
else: | ||
num_dummy_samples_per_batch = ( | ||
(torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item() | ||
) | ||
if num_dummy_samples_per_batch > 0: | ||
num_dummy_samples += num_dummy_samples_per_batch | ||
loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch | ||
|
||
if train_config.task_type == "seq_classification": | ||
logits = model_outputs.logits | ||
labels = batch["labels"][:, 0] | ||
preds = torch.nn.functional.softmax(logits, dim=-1) | ||
acc_helper.forward(preds, labels) | ||
|
||
total_loss += loss.detach().float() | ||
# Accumalate gradients | ||
loss = loss / train_config.gradient_accumulation_steps | ||
quic-swatia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if train_config.enable_ddp: | ||
if local_rank == 0: | ||
if loss <= train_config.convergence_loss: | ||
|
@@ -237,6 +258,17 @@ def train( | |
step_metric_val = float(torch.exp(loss.detach().float())) | ||
train_step_metric.append(step_metric_val) | ||
|
||
# Accumalate gradients | ||
complete_accum_steps = ( | ||
len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps | ||
) | ||
if step < complete_accum_steps: | ||
num_samples_in_cur_update = train_config.gradient_accumulation_steps | ||
else: | ||
num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps | ||
|
||
loss = loss / num_samples_in_cur_update | ||
|
||
if train_config.grad_scaler: | ||
scaler.scale(loss).backward() # backward pass | ||
else: | ||
|
@@ -296,15 +328,30 @@ def train( | |
|
||
if loss_0_counter.item() == train_config.convergence_counter: | ||
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch: | ||
train_epoch_loss = total_loss / (step - intermediate_step) | ||
train_epoch_loss = ( | ||
0.0 | ||
if total_loss == 0.0 | ||
else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size) | ||
quic-swatia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
else: | ||
train_epoch_loss = total_loss / step | ||
train_epoch_loss = ( | ||
0.0 | ||
if total_loss == 0.0 | ||
else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size) | ||
) | ||
else: | ||
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch: | ||
train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step) | ||
train_epoch_loss = ( | ||
0.0 | ||
if total_loss == 0.0 | ||
else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size)) | ||
) | ||
else: | ||
train_epoch_loss = total_loss / len(train_dataloader) | ||
|
||
train_epoch_loss = ( | ||
0.0 | ||
if total_loss == 0.0 | ||
else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size)) | ||
) | ||
if train_config.task_type == "seq_classification": | ||
metric_val = acc_helper.compute() | ||
acc_helper.reset() | ||
|
@@ -389,7 +436,6 @@ def train( | |
results["avg_checkpoint_time"] = avg_checkpoint_time | ||
if train_config.save_metrics: | ||
results["metrics_filename"] = metrics_filename | ||
|
||
return results | ||
|
||
|
||
|
@@ -421,6 +467,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device): | |
eval_loss = 0.0 # Initialize evaluation loss | ||
device_type = torch.device(device).type | ||
|
||
num_dummy_samples = 0 | ||
for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)): | ||
# stop when the maximum number of eval steps is reached | ||
if train_config.max_eval_step > 0 and step > train_config.max_eval_step: | ||
|
@@ -439,6 +486,17 @@ def evaluation_helper(model, train_config, eval_dataloader, device): | |
outputs = model(**batch) | ||
loss = outputs.loss | ||
|
||
if (batch["labels"] != -100).sum() == 0: | ||
loss = loss.nan_to_num(nan=0.0) | ||
quic-swatia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_dummy_samples += 1 | ||
else: | ||
num_dummy_samples_per_batch = ( | ||
(torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item() | ||
) | ||
if num_dummy_samples_per_batch > 0: | ||
num_dummy_samples += num_dummy_samples_per_batch | ||
loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we not generalize the logic in L492 to L502? I mean else case considers bs>1 use case with 'labels' shape of [batch, seq]. The same logic should be true for bs=1 use case. Why we need separate logic. This same can be done in train case. Try to avoid unnecessary if conditions. |
||
if train_config.task_type == "seq_classification": | ||
logits = outputs.logits | ||
labels = batch["labels"][:, 0] | ||
|
@@ -453,9 +511,10 @@ def evaluation_helper(model, train_config, eval_dataloader, device): | |
val_step_metric.append(metric_val) | ||
|
||
eval_loss += loss.detach().float() | ||
|
||
# Compute average loss and metric | ||
eval_epoch_loss = eval_loss / len(eval_dataloader) | ||
eval_epoch_loss = ( | ||
0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size) | ||
) | ||
if train_config.task_type == "seq_classification": | ||
eval_metric = acc_helper.compute() | ||
else: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.