[QEff Finetune] Adding dataset padding changes #478
Changes from 5 commits
Changes to the dataset utilities (dataset_utils):
@@ -4,7 +4,9 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
import os

import datasets
import torch
import torch.distributed as dist
from transformers.data import DataCollatorForSeq2Seq
@@ -54,27 +56,58 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
            dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False
        )
        kwargs["batch_size"] = batch_size
-        kwargs["drop_last"] = True
+        kwargs["drop_last"] = False
    else:
        kwargs["batch_size"] = batch_size
-        kwargs["drop_last"] = True
+        kwargs["drop_last"] = False
    kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
    return kwargs


def get_num_ddp_devices():
    return int(os.getenv("WORLD_SIZE", 1))

Review comment (on get_num_ddp_devices): This is a generic function, not specific to dataset_utils; it should ideally live in helper.py.
Reply: Done.

def padding_dataset(train_config, dataset):
    dataset = dataset.map(lambda x: {"input_length": len(x["input_ids"])})
    if train_config.enable_sorting_for_ddp:
        dataset = dataset.sort("input_length")
    dataset = dataset.remove_columns("input_length")
    dummy_row = next(iter(dataset))
    dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
    padding_size = 0
    num_replicas = get_num_ddp_devices()
    remainder = len(dataset) % (num_replicas * train_config.train_batch_size)
    padding_size = (num_replicas * train_config.train_batch_size) - remainder

    dummy_data = [dummy_row.copy() for _ in range(padding_size)]
    dummy_dataset = datasets.Dataset.from_list(dummy_data)
    combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
    return combined_dataset

Review comment (on the remainder computation): We are padding based on the train batch size. What happens in case of the test dataset with val_batch_size?
Reply: Done.
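The reviewer's question above points at making the padding size split-aware. A minimal sketch of that idea (hypothetical helper, not part of this diff; it only assumes the train_config fields already used in the PR):

def get_padding_size(train_config, dataset_len, split, num_replicas):
    # Hypothetical: pick the batch size that matches the split being padded.
    batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
    remainder = dataset_len % (num_replicas * batch_size)
    # Pad only when the last global batch would otherwise be incomplete.
    return 0 if remainder == 0 else (num_replicas * batch_size) - remainder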

def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)

    if (
        train_config.enable_ddp
        or (split == "train" and train_config.train_batch_size > 1)
        or (split != "train" and train_config.val_batch_size > 1)
    ):
        dataset = padding_dataset(train_config, dataset)

    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

    # FIXME (Meet): Add custom data collator registration from the outside by the user.
    custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)

    if custom_data_collator:
        print("custom_data_collator is used")
        dl_kwargs["collate_fn"] = custom_data_collator

    print(f"length of dataset_{split}", len(dataset))

    # Create data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
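For intuition, a quick self-contained check of the padding arithmetic used in padding_dataset (illustrative numbers, not taken from the PR):

# Illustrative values (hypothetical): 103 samples, 2 DDP replicas, train batch size 4.
num_samples, num_replicas, batch_size = 103, 2, 4
remainder = num_samples % (num_replicas * batch_size)    # 103 % 8 = 7
padding_size = (num_replicas * batch_size) - remainder   # 8 - 7 = 1
assert (num_samples + padding_size) % (num_replicas * batch_size) == 0
# One dummy row with all labels set to -100 is appended, so every rank sees only full batches.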
Changes to the finetune train/eval loop:
@@ -151,7 +151,7 @@ def train(

        # enable profile for qaic
        qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None

        num_dummy_samples = 0
        for step, batch in enumerate(train_dataloader):
            # resume training from a particular checkpoint, assuming the dataset is not shuffled
            if train_config.use_peft and train_config.from_peft_checkpoint:
@@ -192,6 +192,17 @@ def train(
                    ) as verifier:
                        model_outputs = model(**batch)
                        loss = model_outputs.loss  # Forward call
                        if (batch["labels"] != -100).sum() == 0:
                            loss = loss.nan_to_num(nan=0.0)
                            num_dummy_samples += train_config.train_batch_size
                        else:
                            num_dummy_samples_per_batch = (
                                (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
                            )
                            if num_dummy_samples_per_batch > 0:
                                num_dummy_samples += num_dummy_samples_per_batch
                                loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch

                        if train_config.task_type == "seq_classification":
                            logits = model_outputs.logits
                            labels = batch["labels"][:, 0]
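The per-batch dummy count above flags rows whose labels are entirely -100 (the padded rows added by padding_dataset). A tiny standalone illustration with made-up tensors:

import torch

# Hypothetical labels for a batch of 3 sequences of length 4; only the last row is a dummy row.
labels = torch.tensor([[5, 6, -100, -100],
                       [7, -100, -100, -100],
                       [-100, -100, -100, -100]])
num_dummy = (torch.sum(labels == -100, dim=1) == labels.shape[1]).sum().item()
print(num_dummy)  # 1 -> only the fully masked row counts as a dummy sample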
@@ -201,15 +212,25 @@ def train(
                else:
                    model_outputs = model(**batch)
                    loss = model_outputs.loss  # Forward call
                    if (batch["labels"] != -100).sum() == 0:
                        loss = loss.nan_to_num(nan=0.0)
                        num_dummy_samples += train_config.train_batch_size
                    else:
                        num_dummy_samples_per_batch = (
                            (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
                        )
                        if num_dummy_samples_per_batch > 0:
                            num_dummy_samples += num_dummy_samples_per_batch
                            loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch

                    if train_config.task_type == "seq_classification":
                        logits = model_outputs.logits
                        labels = batch["labels"][:, 0]
                        preds = torch.nn.functional.softmax(logits, dim=-1)
                        acc_helper.forward(preds, labels)

                total_loss += loss.detach().float()
                # Accumulate gradients
                loss = loss / train_config.gradient_accumulation_steps

                if train_config.enable_ddp:
                    if local_rank == 0:
                        if loss <= train_config.convergence_loss:
@@ -237,6 +258,17 @@ def train(
                        step_metric_val = float(torch.exp(loss.detach().float()))
                    train_step_metric.append(step_metric_val)

                # Accumulate gradients
                complete_accum_steps = (
                    len(train_dataloader) - len(train_dataloader) % train_config.gradient_accumulation_steps
                )
                if step < complete_accum_steps:
                    num_samples_in_cur_update = train_config.gradient_accumulation_steps
                else:
                    num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps

                loss = loss / num_samples_in_cur_update

                if train_config.grad_scaler:
                    scaler.scale(loss).backward()  # backward pass
                else:
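For intuition about the per-update divisor above (illustrative numbers, not from the PR): the last, partial accumulation window divides by the leftover step count instead of gradient_accumulation_steps.

# Hypothetical: 10 steps per epoch, gradient accumulation over 4 steps.
steps_per_epoch, grad_accum = 10, 4
complete_accum_steps = steps_per_epoch - steps_per_epoch % grad_accum   # 8
divisors = [grad_accum if s < complete_accum_steps else steps_per_epoch % grad_accum
            for s in range(steps_per_epoch)]
print(divisors)  # [4, 4, 4, 4, 4, 4, 4, 4, 2, 2]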
@@ -296,14 +328,31 @@ def train(

        if loss_0_counter.item() == train_config.convergence_counter:
            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = total_loss / (step - intermediate_step)
+                train_epoch_loss = (
+                    0.0
+                    if total_loss == 0.0
+                    else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
+                )
            else:
-                train_epoch_loss = total_loss / step
+                train_epoch_loss = (
+                    0.0
+                    if total_loss == 0.0
+                    else total_loss / (step - num_dummy_samples / train_config.train_batch_size)
+                )
        else:
            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
+                train_epoch_loss = (
+                    0.0
+                    if total_loss == 0.0
+                    else total_loss
+                    / (len(train_dataloader) - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
+                )
            else:
-                train_epoch_loss = total_loss / len(train_dataloader)
+                train_epoch_loss = (
+                    0.0
+                    if total_loss == 0.0
+                    else total_loss / (len(train_dataloader) - (num_dummy_samples / train_config.train_batch_size))
+                )

        if train_config.task_type == "seq_classification":
            metric_val = acc_helper.compute()
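The adjusted denominators above subtract the dummy rows in units of optimizer steps. A quick sanity check with made-up numbers:

# Hypothetical: 50 steps per epoch, batch size 4, 8 dummy samples appended by padding_dataset.
steps_per_epoch, train_batch_size, num_dummy_samples = 50, 4, 8
effective_steps = steps_per_epoch - num_dummy_samples / train_batch_size   # 50 - 2 = 48.0
# train_epoch_loss = total_loss / effective_steps, i.e. averaged only over steps that carried real data.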
@@ -421,6 +470,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    eval_loss = 0.0  # Initialize evaluation loss
    device_type = torch.device(device).type

    num_dummy_samples = 0
    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
        # stop when the maximum number of eval steps is reached
        if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -439,6 +489,17 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
                outputs = model(**batch)
                loss = outputs.loss

                if (batch["labels"] != -100).sum() == 0:
                    loss = loss.nan_to_num(nan=0.0)
                    num_dummy_samples += 1
                else:
                    num_dummy_samples_per_batch = (
                        (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
                    )
                    if num_dummy_samples_per_batch > 0:
                        num_dummy_samples += num_dummy_samples_per_batch
                        loss = loss * train_config.val_batch_size / num_dummy_samples_per_batch

            if train_config.task_type == "seq_classification":
                logits = outputs.logits
                labels = batch["labels"][:, 0]

Review comment (on this block): Can we not generalize the logic in L492 to L502? The else case handles bs > 1 with a "labels" shape of [batch, seq], and the same logic should hold for bs = 1, so why do we need separate branches? The same can be done in the train case. Try to avoid unnecessary if conditions.
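Following the reviewer's suggestion, one way to collapse the two branches into a single path (a hypothetical refactor sketch, not code from this PR; it assumes labels always arrive as a [batch, seq] tensor):

import torch

def count_dummy_samples(labels: torch.Tensor) -> int:
    # A row is a dummy (padding) sample when every label in it is the ignore index -100.
    return (torch.sum(labels == -100, dim=1) == labels.shape[1]).sum().item()

# Inside the step, the same call covers both bs == 1 and bs > 1:
# num_dummy_samples_per_batch = count_dummy_samples(batch["labels"])
# if num_dummy_samples_per_batch == batch["labels"].shape[0]:
#     loss = loss.nan_to_num(nan=0.0)  # every row was padding, so the loss carries no real signal
# num_dummy_samples += num_dummy_samples_per_batch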
@@ -455,7 +516,11 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
            eval_loss += loss.detach().float()

    # Compute average loss and metric
-    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    eval_epoch_loss = (
+        0.0
+        if eval_loss == 0.0
+        else eval_loss / (len(eval_dataloader) - num_dummy_samples / train_config.val_batch_size)
+    )
    if train_config.task_type == "seq_classification":
        eval_metric = acc_helper.compute()
    else: