feat: add support for packing tokenized datasets
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
kmehant committed Sep 3, 2024
1 parent d60a1f5 commit 9ddff15
Showing 2 changed files with 26 additions and 11 deletions.
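
In practice, this commit lets a dataset that already contains an `input_ids` column be packed into constant-length sequences by `SFTTrainer`. Below is a minimal usage sketch, assuming the `SFTTrainer`/`SFTConfig` API of this era of TRL; the model id, dataset contents, and hyperparameters are illustrative, not part of the commit:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A pretokenized dataset: every row already carries an `input_ids` column.
texts = ["TRL now supports packing datasets that are already tokenized."] * 64
train_dataset = Dataset.from_dict(
    {"input_ids": [tokenizer(t)["input_ids"] for t in texts]}
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    args=SFTConfig(
        output_dir="sft-packed",   # illustrative output directory
        packing=True,              # with this commit, pretokenized rows are packed too
        max_seq_length=32,
    ),
)
trainer.train()
```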
10 changes: 7 additions & 3 deletions trl/trainer/sft_trainer.py
@@ -509,8 +509,12 @@ def _prepare_dataset(
                 warnings.warn(
                     "You passed a dataset that is already processed (contains an `input_ids` field) together with a valid formatting function. Therefore `formatting_func` will be ignored."
                 )
-
-            return dataset
+            formatting_func = lambda x: x["input_ids"]
+            if not packing:
+                return dataset
+            warnings.warn(
+                "Since packing is set to True, though the dataset is pretokenized, it will undergo constant length dataset preparation."
+            )
 
         # check if torch dataset / dataloader and do nothing
         # see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
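
Paraphrasing this hunk: once a dataset is detected as pretokenized, the formatting function is replaced by one that simply returns the stored token ids; without packing the dataset is returned untouched, and with packing it continues on to constant-length preparation. A standalone sketch of that branching follows; the function and variable names are illustrative, not the trainer's actual internals:

```python
import warnings

def prepare_pretokenized(dataset, packing: bool):
    """Decide what to do with a dataset that already has an input_ids column."""
    formatting_func = lambda x: x["input_ids"]  # packing reads the stored token ids directly
    if not packing:
        # No packing requested: the tokenized dataset is used as-is.
        return dataset, None
    warnings.warn(
        "Since packing is set to True, though the dataset is pretokenized, "
        "it will undergo constant length dataset preparation."
    )
    # With packing, the dataset and this formatting_func feed the constant-length preparation.
    return dataset, formatting_func
```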
@@ -650,7 +654,7 @@ def data_generator(constant_length_iterator):
             return packed_dataset
         else:
             raise ValueError(
-                "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
+                "You need to pass a `dataset_text_field` or `formatting_func` argument or a pretokenized dataset to the SFTTrainer if you want to use the `ConstantLengthDataset`."
             )
 
     def _trl_activate_neftune(self, model):
27 changes: 19 additions & 8 deletions trl/trainer/utils.py
@@ -19,6 +19,7 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
+import datasets
 import numpy as np
 import pandas as pd
 import torch
@@ -510,17 +511,24 @@ def __init__(
         self.shuffle = shuffle
         self.append_concat_token = append_concat_token
         self.add_special_tokens = add_special_tokens
-        if formatting_func is None:
+        self.formatting_func = formatting_func
+        if self.formatting_func is None:
             self.formatting_func = lambda x: x[dataset_text_field]
-        else:
-            self.formatting_func = formatting_func
 
-        if formatting_func is not None:
-            if formatting_func.__code__.co_argcount > 1:
+        if self.formatting_func is not None:
+            if self.formatting_func.__code__.co_argcount > 1:
                 warnings.warn(
                     "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
                     " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
                 )
+        self.pretokenized = False
+        column_names = (
+            dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
+        )
+        if column_names and "input_ids" in column_names:
+            self.pretokenized = True
+            # since its tokenized unit of buffer size should be tokens
+            self.max_buffer_size = seq_length * num_of_sequences
 
     def __len__(self):
         return len(self.dataset)
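
The constructor now remembers whether the dataset is pretokenized by checking for an `input_ids` column, and in that case sizes its read-ahead buffer in tokens instead of characters. A small sketch of that decision in isolation; the numbers are illustrative, and `chars_per_token` mirrors the 3.6 default used by the existing text path:

```python
import datasets

seq_length, num_of_sequences, chars_per_token = 1024, 64, 3.6

ds = datasets.Dataset.from_dict({"input_ids": [[101, 2023, 2003, 102], [101, 2178, 102]]})

column_names = ds.column_names if isinstance(ds, (datasets.Dataset, datasets.IterableDataset)) else None
pretokenized = bool(column_names) and "input_ids" in column_names

if pretokenized:
    max_buffer_size = seq_length * num_of_sequences                    # buffer counted in tokens
else:
    max_buffer_size = seq_length * chars_per_token * num_of_sequences  # buffer counted in characters

print(pretokenized, max_buffer_size)  # True 65536
```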
@@ -543,9 +551,12 @@ def __iter__(self):
                     else:
                         more_examples = False
                         break
-            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
-                "input_ids"
-            ]
+            if self.pretokenized:
+                tokenized_inputs = buffer
+            else:
+                tokenized_inputs = self.tokenizer(
+                    buffer, add_special_tokens=self.add_special_tokens, truncation=False
+                )["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 if self.append_concat_token:
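
Past this point the rest of `__iter__` is unchanged: the buffer (now possibly already lists of token ids) is concatenated, optionally separated by a concat token, and cut into fixed-length chunks. A minimal sketch of that packing step with a pretokenized buffer; the token ids, `eos_token_id`, and `seq_length` values are illustrative:

```python
seq_length = 8
eos_token_id = 0
# With a pretokenized dataset, the buffer already holds lists of token ids.
buffer = [[5, 6, 7], [8, 9, 10, 11], [12, 13], [14, 15, 16, 17, 18]]

all_token_ids = []
for tokenized_input in buffer:
    # append_concat_token=True: separate concatenated sequences with the concat token
    all_token_ids.extend(tokenized_input + [eos_token_id])

examples = []
for i in range(0, len(all_token_ids), seq_length):
    input_ids = all_token_ids[i : i + seq_length]
    if len(input_ids) == seq_length:  # only full-length chunks are yielded
        examples.append(input_ids)

print(examples)
# [[5, 6, 7, 0, 8, 9, 10, 11], [0, 12, 13, 0, 14, 15, 16, 17]]
```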
