Support packing for pretokenized datasets #1848

Closed
kmehant opened this issue Jul 17, 2024 · 7 comments · Fixed by #2011
Labels
🗃️ data Related to data ✨ enhancement New feature or request 🏋 SFT Related to SFT

Comments

@kmehant
Contributor

kmehant commented Jul 17, 2024

Currently, trl returns the dataset as-is if the provided dataset shows signs of already being tokenized.

if column_names and "input_ids" in column_names:

Additionally, I see that ConstantLengthDataset

class ConstantLengthDataset(IterableDataset):

was written only to support data that is not pretokenized. It should be possible to extend it to the pretokenized case as well.

Is there any interest in supporting packing for pretokenized datasets? If so, I would be interested in contributing.
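
To make the request concrete, here is a minimal sketch of what packing pretokenized data could look like. It is not trl's implementation: the function name pack_pretokenized, its parameters, and the choice to drop a trailing partial block are all assumptions made for illustration. It concatenates token-id lists, separated by an EOS id, and cuts the stream into fixed-length blocks.

```python
from typing import Iterable, Iterator, List


def pack_pretokenized(
    examples: Iterable[List[int]],
    seq_length: int,
    eos_token_id: int,
) -> Iterator[List[int]]:
    """Yield fixed-length blocks of token ids from already-tokenized examples.

    Hypothetical helper for illustration only; not part of trl.
    """
    buffer: List[int] = []
    for input_ids in examples:
        # Append an EOS separator so example boundaries stay visible after packing.
        buffer.extend(input_ids)
        buffer.append(eos_token_id)
        # Emit as many full blocks as the buffer currently holds.
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]
    # Assumption: any trailing partial block is dropped rather than padded.


blocks = list(
    pack_pretokenized([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_length=4, eos_token_id=0)
)
```

Note that a block can span two source examples, which is the point of packing: no padding tokens are needed to reach the target length.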


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@kmehant
Contributor Author

kmehant commented Aug 19, 2024

@lvwerra requesting your opinion.

@qgallouedec
Member

@kmehant thanks for sharing this feature request.
Can you briefly describe why you need this feature? Or why you can't do without this feature?
It's undoubtedly an interesting feature to have, but I'm worried about the implementation, which risks adding yet another level of complexity. Have you found a way of implementing it? What elements are affected by the changes?

@qgallouedec qgallouedec added the ✨ enhancement New feature or request label Aug 26, 2024
@kmehant
Contributor Author

kmehant commented Sep 2, 2024

@qgallouedec thanks for circling back.

In my opinion, supporting this is not complex. Here is a version that implements it: https://github.com/kmehant/trl/tree/pack-pretok

Changes compared with main: https://github.com/huggingface/trl/compare/main...kmehant:trl:pack-pretok?expand=1

Steps to try this version

Install trl from my fork

git clone -b pack-pretok https://github.com/kmehant/trl.git
cd trl 
pip install .

Sample training code

import datasets
from transformers import AutoTokenizer
from trl import SFTTrainer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Build a tiny pretokenized dataset: ten copies of one tokenized sentence.
tok = AutoTokenizer.from_pretrained(model_name)
t = tok.encode("We adopted exactly the same architecture and tokenizer as Llama 2.")
data = datasets.Dataset.from_dict({"input_ids": [t] * 10})

# Pack the pretokenized examples into sequences of 10 tokens and train.
trainer = SFTTrainer(
    model_name,
    train_dataset=data,
    max_seq_length=10,
    packing=True,
)
trainer.train()

Sample output looks like

{'train_runtime': 18.4487, 'train_samples_per_second': 2.927, 'train_steps_per_second': 0.163, 'train_loss': 2.972621281941732, 'epoch': 3.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:18<00:00,  6.15s/it]
TrainOutput(global_step=3, training_loss=2.972621281941732, metrics={'train_runtime': 18.4487, 'train_samples_per_second': 2.927, 'train_steps_per_second': 0.163, 'total_flos': 3351820124160.0, 'train_loss': 2.972621281941732, 'epoch': 3.0})

Thank you. I can raise a PR out of this and add tests as needed.

@qgallouedec
Member

Thanks! It's actually simpler than I expected.
Can you open a PR?
Would it be possible to directly infer if the dataset is tokenized in ConstantLengthDataset?

@kmehant
Contributor Author

kmehant commented Sep 3, 2024

@qgallouedec I have raised a PR here: #2011

Would it be possible to directly infer if the dataset is tokenized in ConstantLengthDataset?

Thanks, included that in the PR.
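
For readers wondering what such an inference might involve, here is a rough sketch. It is an assumption-laden illustration and not the code from the PR: the helper name is_pretokenized is hypothetical, and it assumes that iterating the dataset yields dict-like examples. It peeks at the first example and checks for an "input_ids" field holding integers.

```python
from typing import Any, Iterable, Mapping


def is_pretokenized(dataset: Iterable[Mapping[str, Any]]) -> bool:
    """Guess whether a dataset already holds token ids.

    Hypothetical helper for illustration only; not taken from trl or the PR.
    """
    # Peek at the first example; an empty dataset gives no evidence either way.
    first = next(iter(dataset), None)
    if first is None:
        return False
    ids = first.get("input_ids")
    # Treat the dataset as pretokenized only if input_ids is a sequence of ints.
    return isinstance(ids, (list, tuple)) and all(isinstance(t, int) for t in ids)


print(is_pretokenized([{"input_ids": [1, 2, 3]}]))
print(is_pretokenized([{"text": "raw, untokenized text"}]))
```

Checking a single example keeps the test cheap and works for streamed data, at the cost of trusting that the rest of the dataset has the same schema.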

@kmehant
Contributor Author

kmehant commented Sep 17, 2024

@qgallouedec Any update on this thread? Thanks.

@qgallouedec qgallouedec added 🏋 SFT Related to SFT 🗃️ data Related to data labels Oct 7, 2024