Adding support for training chat models #187
Caution
🚨 This is a draft, still in development, and further testing needs to be done. Feel free to leave any comments!
This PR includes everything necessary to train chat models with:
This image from @sz128 is very useful to understand 1. & 2.:
I am developing this feature with axolotl's implementation as a reference. The current status is as follows:

Dataset
IterableDatasets
This time, I have opted for an IterableDataset instead of a map-style one. The obvious benefits are that we tokenize on the fly, which allows us to easily experiment with different models/tokenizers/chat templates and saves disk space by not storing the tokens. However, the drawbacks are:

1. To split the dataset across the DP groups we use the `split_dataset_by_node` function, which divides the dataset's `n_shards` among the number of DP groups. If not evenly divisible, each DP group instead keeps 1 sample of the dataset, skipping the other examples, which is obviously not very optimal.
2. Remember also that since we tokenize the data on the fly, even if we divide by `n_shards`, each DP group will produce a different amount of tokens. Thus, after X steps, one DP group will have consumed its dataset 1.5 times, while another with longer samples will have done so 0.8 times.
3. To recover the state of the DataLoader when resuming training we use the `.skip()` method. This method is not optimal as it consumes all the skipped samples from the Dataset, but it seems they are working on a better solution. This is a problem if you have to skip many samples from a XXXL dataset.
4. It is not trivial to use `num_workers > 1`, as the dataset would need to be divided at the worker level, and then it would not be trivial to recover the state if this value changes.

Of all these inconveniences, the one that worries me the most is the third one, but I trust that they will develop an optimal solution soon. We can easily develop solutions for the first and second issues, and the fourth one does not seem too problematic, although we could also address it.
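To make this concrete, here is a minimal sketch of the streaming setup with the Hugging Face `datasets` API; the dataset name, DP rank/size and number of skipped samples are placeholders, not values from this PR:

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Placeholders: the dataset comes from the `hf_dataset` config field, and
# dp_rank / dp_size would come from the parallel context.
dp_rank, dp_size = 0, 4

# Stream the dataset so nothing is tokenized or stored ahead of time.
dataset = load_dataset("Open-Orca/SlimOrca", split="train", streaming=True)

# Each DP group gets a disjoint part of the stream. If dataset.n_shards is not
# divisible by dp_size, this falls back to keeping 1 example out of every
# dp_size examples and skipping the rest.
dataset = split_dataset_by_node(dataset, rank=dp_rank, world_size=dp_size)

# Resuming: .skip() fast-forwards the stream, but it still iterates over every
# skipped sample, which is costly for very large datasets.
consumed_samples = 1_000
dataset = dataset.skip(consumed_samples)
```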
How Samples Are Produced
In short, we extract samples from the dataset and apply the chat template until we can no longer fit a full Question-Answer pair into the sequence length of the sample we are constructing. We save this last Question-Answer pair for the next sample and pad the sample we are constructing (in the case of the Llama3 tokenizer, since we don't have a pad token, we use the `<|eot_id|>` token). We do this so that each sample contains only complete Question-Answer pairs. This packing is greedy, although there are more complex strategies to minimize the number of pad tokens.
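A minimal sketch of this greedy packing, assuming the Q-A pairs are already tokenized with the chat template applied; the function name and interface are illustrative, not the PR's actual code:

```python
from typing import Iterator, List


def greedy_pack(
    qa_pairs: Iterator[List[int]],  # token ids of full Question-Answer pairs
    sequence_length: int,
    pad_token_id: int,              # for Llama3 we reuse <|eot_id|>, as there is no pad token
) -> Iterator[List[int]]:
    """Greedy packing sketch. Assumes every Q-A pair fits into sequence_length on its own."""
    sample: List[int] = []
    for pair in qa_pairs:
        if len(sample) + len(pair) > sequence_length:
            # The pair no longer fits: pad the current sample and keep the pair
            # as the beginning of the next one.
            yield sample + [pad_token_id] * (sequence_length - len(sample))
            sample = []
        sample += pair
    if sample:  # flush the last, partially filled sample
        yield sample + [pad_token_id] * (sequence_length - len(sample))
```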
The important thing here is that we have developed the `ChatTokenizer` class to apply the chat template manually, instead of using solutions like the `apply_chat_template` method of tokenizers. We do this so that we know, at the token level, whether each token belongs to the assistant's responses or not, for the feature of training only on the assistant's tokens. I have added an assert to verify that the result of applying the chat template manually is exactly the same as that of the `apply_chat_template` method of tokenizers.
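A sketch of that consistency check, assuming a hypothetical `chat_tokenizer(conversation)` callable that returns the token ids together with a per-token assistant mask (the real `ChatTokenizer` interface may differ):

```python
def check_chat_template(chat_tokenizer, tokenizer, conversation):
    """Verify that applying the chat template manually reproduces apply_chat_template,
    while also returning the per-token assistant mask that the manual path gives us."""
    token_ids, is_assistant = chat_tokenizer(conversation)

    # Reference: the tokenizer's own templating.
    reference_ids = tokenizer.apply_chat_template(conversation, tokenize=True)
    assert token_ids == reference_ids, "Manual chat template diverges from apply_chat_template"

    # The extra information the manual path provides: one flag per token telling us
    # whether it belongs to an assistant turn (used to train only on assistant tokens).
    assert len(is_assistant) == len(token_ids)
    return token_ids, is_assistant
```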
Dataset Samples
I have developed this notebook so you can check the batches produced by the DataLoader. In summary, the most relevant features are:

Note
The label id token '-' is actually -100. We switch it because `tokenizer.convert_ids_to_tokens` can't convert the '-100' token.

We also include the `position_ids` of each token. This will be relevant later for specifying FA2 to not attend to other samples.
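As a toy illustration of these fields (values are made up, and it is assumed here that `position_ids` restart at 0 for every packed conversation):

```python
import torch

# One packed sample of sequence_length = 10 holding two short conversations plus padding.
# Token ids are arbitrary; the point is the shape of labels and position_ids.
input_ids    = torch.tensor([11, 12, 13, 14, 15, 21, 22, 23, 24, 0])
position_ids = torch.tensor([ 0,  1,  2,  3,  4,  0,  1,  2,  3, 0])   # restart per conversation
labels       = torch.tensor([-100, -100, 13, 14, 15, -100, -100, 23, 24, -100])  # -100 = no loss
# -100 marks the positions the loss ignores: non-assistant tokens and padding.
```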
Other Considerations
Collator
The collator is very similar to `DataCollatorForCLM`, except we now add the `position_ids`. I have also removed several assertions.
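Not the PR's actual collator, but a simplified sketch of the idea, assuming each example provides `input_ids`, `label_ids` and `position_ids` of length sequence_length + 1 and the usual shift-by-one convention for next-token prediction:

```python
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class ChatDataCollatorSketch:
    """Simplified sketch: stack the packed samples and shift them for next-token
    prediction, now also carrying position_ids through to the model."""

    sequence_length: int

    def __call__(self, examples: list[dict[str, np.ndarray]]) -> dict[str, torch.Tensor]:
        batch = {
            key: torch.from_numpy(np.vstack([ex[key] for ex in examples]))
            for key in ("input_ids", "label_ids", "position_ids")
        }
        assert batch["input_ids"].shape[1] == self.sequence_length + 1

        return {
            "input_ids": batch["input_ids"][:, :-1],        # tokens the model sees
            "position_ids": batch["position_ids"][:, :-1],  # per-conversation positions for FA2
            "label_ids": batch["label_ids"][:, 1:],         # next-token targets, -100 = ignored
        }
```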
DataLoader
The DataLoader is pretty simple. As I mentioned, it is not trivial to work with `num_workers`.

Config
I have added a new configuration called `ChatDatasetsArgs`, which includes:

- `hf_dataset`: The name or path of the dataset we will stream.
- `hf_dataset_split` ("train"): The split of the dataset.
- `conversation_column_name`: The name of the column from which to extract the conversations.

For debugging purposes, to be deleted in the final release:

- `train_on_completions_only`: Whether to train only on completions or not.
- `remove_cross_attention`: Whether to attend only to the tokens from the same sample or to all of them (vanilla mechanism).

As I already mentioned, the final two configurations are for evaluating the effect of these two functionalities. I would remove them for the final release, since I do not see the benefit of not activating them.
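For reference, a hypothetical sketch of what this config could look like as a dataclass (field names taken from the list above; types and defaults are assumptions, not the PR's actual definition):

```python
from dataclasses import dataclass


@dataclass
class ChatDatasetsArgs:
    hf_dataset: str                  # name or path of the dataset to stream
    conversation_column_name: str    # column holding the conversations
    hf_dataset_split: str = "train"  # split of the dataset

    # Debug-only flags, to be removed in the final release:
    train_on_completions_only: bool = True  # train only on the assistant's tokens
    remove_cross_attention: bool = True     # attend only to tokens of the same conversation
```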
What Is Still Missing:
TODOs: