fix: address review comments
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
kmehant committed Nov 5, 2024
1 parent a9233fd commit 1ae3f93
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions trl/trainer/sft_trainer.py
@@ -366,12 +366,12 @@ def _prepare_dataset(
                warnings.warn(
                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
                )

-           formatting_func = lambda x: x["input_ids"]
+           def formatting_func(x):
+               return x["input_ids"]

            if not packing:
                return dataset
            warnings.warn(
                "Since packing is set to True, though the dataset is pretokenized, it will undergo constant length dataset preparation."
            )

# check if torch dataset / dataloader and do nothing
# see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
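For context, a minimal sketch of the pretokenized path this hunk touches; the example dataset and the `packing` flag below are illustrative assumptions, not part of the commit. When the dataset already carries an `input_ids` column, the formatting function simply passes the token ids through, and without packing the dataset is returned as-is.

from datasets import Dataset

# Hypothetical pretokenized dataset, used only for this sketch.
dataset = Dataset.from_dict({"input_ids": [[101, 7592, 102], [101, 2088, 102]]})
packing = False  # illustrative flag

if dataset.column_names is not None and "input_ids" in dataset.column_names:

    def formatting_func(x):
        # The dataset is already tokenized, so formatting just returns the token ids.
        return x["input_ids"]

    if not packing:
        prepared = dataset  # used as-is; no further tokenization is applied
    else:
        # with packing, these token ids would be wrapped into constant-length chunks
        print(formatting_func(dataset[0]))  # [101, 7592, 102]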
4 changes: 2 additions & 2 deletions trl/trainer/utils.py
@@ -632,9 +632,9 @@ def __init__(
        column_names = (
            dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
        )
-       if column_names and "input_ids" in column_names:
+       if column_names is not None and "input_ids" in column_names:
            self.pretokenized = True
-           # since its tokenized unit of buffer size should be tokens
+           # since the dataset is tokenized, the unit of buffer size should be tokens
            self.max_buffer_size = seq_length * num_of_sequences

def __len__(self):
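A small arithmetic sketch of why the buffer-size unit matters here: for a pretokenized dataset the buffer holds token ids directly, while for raw text ConstantLengthDataset sizes its buffer in characters via an estimated chars-per-token ratio. The numbers below are illustrative, chosen to mirror the constructor's defaults (an assumption, not taken from this diff).

# Illustrative values assumed to match ConstantLengthDataset's defaults.
seq_length = 1024
num_of_sequences = 1024
chars_per_token = 3.6

# Pretokenized dataset: the buffer is filled with token ids, so its size is counted in tokens.
max_buffer_size_tokens = seq_length * num_of_sequences                    # 1_048_576 tokens

# Raw-text dataset: the buffer is filled with characters, scaled by the chars-per-token estimate.
max_buffer_size_chars = seq_length * chars_per_token * num_of_sequences   # 3_774_873.6 characters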
