
Commit d3601bb

quic-meetkuma authored and eplatero97 committed
[QEff Finetune]: Added support for gradient checkpointing in the finetuning script. (quic#338)
Added a new CLI flag, --gradient_checkpointing, to enable this feature. Currently it is enabled for all HF models that have the "supports_gradient_checkpointing" attribute set to True. --------- Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com> Signed-off-by: eplatero <quic_eplatero@quicinc.com>
1 parent dcd407d commit d3601bb
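
In a Hugging Face finetuning script, such a flag usually just toggles the model's built-in activation checkpointing. A minimal, hypothetical sketch of how a --gradient_checkpointing flag can be wired up (only the flag name and the supports_gradient_checkpointing attribute come from the commit message; the parser, model choice, and error handling are illustrative assumptions, not code from this commit):

# Hypothetical sketch; parser and model are illustrative, not from this repo.
import argparse

from transformers import AutoModelForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument(
    "--gradient_checkpointing",
    action="store_true",
    help="Enable gradient checkpointing to trade extra compute for lower memory.",
)
args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained("gpt2")

if args.gradient_checkpointing:
    if getattr(model, "supports_gradient_checkpointing", False):
        # Recompute activations during the backward pass instead of storing them all.
        model.gradient_checkpointing_enable()
    else:
        raise ValueError("This model does not support gradient checkpointing.")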

File tree

2 files changed: +28 -0 lines changed


QEfficient/finetune/dataset/dataset_config.py

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,9 @@
 from QEfficient.finetune.dataset.samsum_dataset import (
     get_preprocessed_samsum as get_samsum_dataset,
 )
+from QEfficient.finetune.dataset.samsum_dataset import (
+    get_samsum_collate_fn,
+)
 
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
@@ -36,3 +39,7 @@
 DATALOADER_COLLATE_FUNC = {
     "custom_dataset": get_data_collator,
 }
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator,
+    "samsum_dataset": get_samsum_collate_fn,
+}
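
For orientation, a hedged sketch of how registries like DATASET_PREPROC and DATALOADER_COLLATE_FUNC are typically consumed when building a DataLoader; the helper build_dataloader and its signature are assumptions for illustration, not code from this commit:

# Hypothetical usage of the registries above; build_dataloader is illustrative.
from torch.utils.data import DataLoader

from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC


def build_dataloader(dataset_name, dataset_config, tokenizer, split="train", batch_size=8):
    dataset = DATASET_PREPROC[dataset_name](dataset_config, tokenizer, split)
    # Datasets without a registered collate builder fall back to default collation.
    collate_builder = DATALOADER_COLLATE_FUNC.get(dataset_name)
    collate_fn = collate_builder(tokenizer, dataset_config) if collate_builder else None
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)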

QEfficient/finetune/dataset/samsum_dataset.py

Lines changed: 21 additions & 0 deletions
@@ -6,6 +6,8 @@
 # -----------------------------------------------------------------------------
 
 import datasets
+import torch
+from torch.nn.utils.rnn import pad_sequence
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
@@ -46,3 +48,22 @@ def tokenize_add_label(sample):
     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
     return dataset
+
+
+def collate_fn(batch):
+    eos_token = batch[0]["input_ids"][-1]
+
+    input_ids = pad_sequence(
+        [torch.tensor(b["input_ids"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    attn_mask = pad_sequence(
+        [torch.tensor(b["attention_mask"], dtype=torch.int32) for b in batch], batch_first=True, padding_value=0
+    )
+    labels = pad_sequence(
+        [torch.tensor(b["labels"], dtype=torch.long) for b in batch], batch_first=True, padding_value=eos_token
+    )
+    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
+
+
+def get_samsum_collate_fn(dataset_processer, dataset_config):
+    return collate_fn
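
To make the padding behaviour concrete, a small worked example with made-up token ids (assuming an EOS id of 2, taken from the last input_ids token of the first sample):

# Toy batch with made-up token ids; collate_fn pads to the longest sequence,
# using the EOS id (2) for input_ids/labels and 0 for the attention mask.
batch = [
    {"input_ids": [5, 6, 2], "attention_mask": [1, 1, 1], "labels": [5, 6, 2]},
    {"input_ids": [7, 2], "attention_mask": [1, 1], "labels": [7, 2]},
]
out = collate_fn(batch)
print(out["input_ids"])       # tensor([[5, 6, 2], [7, 2, 2]], dtype=torch.int32)
print(out["attention_mask"])  # tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.int32)
print(out["labels"])          # tensor([[5, 6, 2], [7, 2, 2]])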
