
Commit 547d924

Add shuffle_dataset option to SFTTrainer (#4564)
1 parent b01f8ca commit 547d924

2 files changed: 14 additions, 0 deletions


trl/trainer/sft_config.py

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,8 @@ class SFTConfig(TrainingArguments):
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
             If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
+        shuffle_dataset (`bool`, *optional*, defaults to `False`):
+            Whether to shuffle the dataset.
         packing (`bool`, *optional*, defaults to `False`):
             Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
             padding. Uses `max_length` to define sequence length.
@@ -197,6 +199,10 @@ class SFTConfig(TrainingArguments):
             "sequence length."
         },
     )
+    shuffle_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether to shuffle the dataset."},
+    )
     packing: bool = field(
         default=False,
         metadata={
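
To enable the new option, pass shuffle_dataset=True to SFTConfig. A minimal sketch, assuming a trl build that includes this commit; the model checkpoint and dataset below are placeholders, not part of the commit:

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Placeholder dataset and checkpoint, purely for illustration.
dataset = load_dataset("trl-lib/Capybara", split="train")

config = SFTConfig(
    output_dir="sft-output",
    max_length=1024,
    packing=True,           # group sequences into fixed-length blocks
    shuffle_dataset=True,   # new option added by this commit
    seed=42,                # the shuffle uses args.seed, so the order is reproducible
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # placeholder checkpoint
    args=config,
    train_dataset=dataset,
)
trainer.train()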

trl/trainer/sft_trainer.py

Lines changed: 8 additions & 0 deletions
@@ -1071,6 +1071,11 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
 
                 dataset = dataset.select_columns(columns)
 
+                # Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before
+                # packing as well to avoid correlations between sequences packed together.
+                if args.shuffle_dataset:
+                    dataset = dataset.shuffle(seed=args.seed)
+
                 # Packing adds new column "seq_lengths" needed for document aware FlashAttention
                 dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
             elif args.max_length is not None:
@@ -1083,6 +1088,9 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
             column_names = get_dataset_column_names(dataset)
             dataset = dataset.select_columns(collator_expected_keys.intersection(column_names))
 
+            if args.shuffle_dataset:
+                dataset = dataset.shuffle(seed=args.seed)
+
         return dataset
 
     def _set_signature_columns_if_needed(self):
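
For intuition on why the shuffle happens before packing: datasets.Dataset.shuffle reorders rows, so rows that packing would place in the same block no longer come from contiguous, potentially correlated, parts of the source data. A toy illustration with made-up data, not taken from the commit:

from datasets import Dataset

# Toy, ordered data standing in for a tokenized dataset.
dataset = Dataset.from_dict({"input_ids": [[i] * 4 for i in range(8)]})

# Same call the trainer makes: dataset.shuffle(seed=args.seed).
shuffled = dataset.shuffle(seed=42)

print([row["input_ids"][0] for row in dataset])   # original order: 0, 1, 2, ..., 7
print([row["input_ids"][0] for row in shuffled])  # reordered, deterministic for a fixed seed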
