Hello,
In this PR, I include everything necessary to perform SFT with the Llama model. The main features included are:

As detailed below, in this first beta, we will allow activating and deactivating Features 2 & 3. I have designed it this way to measure the effect of these parameters, although I propose getting rid of them in the final version.
In my first PR in the nanotron repo (huggingface#187), I used the implementation in axolotl as a reference. The problem was that it contained padding tokens to fill the sequence length. I finally opted for a padding-free implementation and used the new implementation from HuggingFace Transformers as a reference [1], [2]. I included the script `tools/check_sft.py` to compare the generations of both models (HF & nanotron) and ensure they are the same. I emphasize that it is the generations that are the same, not the logits. This is because, although we have the same parameters in both implementations, we do not perform exactly the same operations: in nanotron we have 1. a fused QKV matrix, 2. a fused MLP matrix, and 3. the FA LayerNorm, which produce slightly different logits (with `torch.testing.assert_close`, 99% of the logits are equal at `atol=rtol=1e-2`), but the important thing is that the generations are the same, especially the most probable first token.

Here & here you can observe the wandb runs of the 4 different configs toggling Features 2 & 3. As can be seen, using Feature 3 increases the TFLOPs, since `flash_attn_varlen_func` achieves better performance when attending to shorter sequences.
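For reference, a minimal sketch of the kind of check `tools/check_sft.py` performs; the loading code and the nanotron call signature below are placeholders, not the actual script:

```python
import torch

@torch.no_grad()
def compare_logits_and_greedy_token(hf_model, nanotron_model, input_ids):
    # Both models are assumed to return logits of shape (batch, seq_len, vocab_size).
    hf_logits = hf_model(input_ids).logits
    nt_logits = nanotron_model(input_ids)  # placeholder call; the real model takes different inputs

    # The logits only match approximately: the fused QKV matrix, fused MLP matrix and
    # FA LayerNorm change the order of floating-point operations. In practice ~99% of
    # the elements fall within atol=rtol=1e-2, so this reports a small mismatch.
    torch.testing.assert_close(hf_logits, nt_logits, atol=1e-2, rtol=1e-2)

    # The check that actually matters: the most probable next token is identical,
    # so greedy generations from both models agree.
    assert torch.equal(hf_logits[:, -1].argmax(dim=-1), nt_logits[:, -1].argmax(dim=-1))
```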
In this first "Beta," I introduce 1. A new Dataset & ChatTokenizer & Collator and 2. A new Llama model for SFT (
LlamaForSFT
).We will only need to specify in the config file a QA dataset from the HuggingFace Hub. Unlike Nanosets, no preprocessing step is required. In this case, we have an
IterableDataset
that will handle tokenization + sample packing on the fly. The obvious benefit of this is that we don't need to tokenize the data beforehand, but it has a major drawback: It is not trivial to recover the state of the DataLoader to resume training once interrupted. The only solution I know is through torchdata's StatefulDataloaders, which I am already working on for the final version. We can also activate and deactivate features 2 and 3 via the configurationstrain_on_completions_only
andremove_cross_attention
. Finally, remember that we only support the format of conversation datasets from Open-Orca/SlimOrca & Magpie-Align/Magpie-Pro-300K-Filtered, so if you want to use other QA datasets (like this dataset with "content" and "role" keys), you will need to change the dictionary keys.Finally, to apply the chat template and tokenize the data, I included
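To make the on-the-fly behaviour concrete, here is a rough sketch of what such an `IterableDataset` does; the class name, the tokenizer interface and the `-100` ignore index are illustrative assumptions, not the actual nanotron code:

```python
import torch
from torch.utils.data import IterableDataset

class PackedSFTDataset(IterableDataset):
    """Illustrative sketch: tokenize conversations lazily and pack them into
    fixed-length sequences, with per-sample position ids and completion-only labels."""

    def __init__(self, conversations, chat_tokenizer, sequence_length, train_on_completions_only=True):
        self.conversations = conversations     # raw chat samples streamed from the Hub
        self.chat_tokenizer = chat_tokenizer   # assumed to return (token ids, is_completion mask)
        self.sequence_length = sequence_length
        self.train_on_completions_only = train_on_completions_only

    def __iter__(self):
        input_ids, labels, position_ids = [], [], []
        for sample in self.conversations:
            tokens, is_completion = self.chat_tokenizer(sample)
            tokens, is_completion = tokens[: self.sequence_length], is_completion[: self.sequence_length]
            # Start a new pack when the current sample no longer fits.
            if input_ids and len(input_ids) + len(tokens) > self.sequence_length:
                yield self._to_tensors(input_ids, labels, position_ids)
                input_ids, labels, position_ids = [], [], []
            input_ids.extend(tokens)
            # Feature 2 (train_on_completions_only): mask every non-assistant token out of the loss.
            labels.extend(
                tok if (comp or not self.train_on_completions_only) else -100
                for tok, comp in zip(tokens, is_completion)
            )
            # Positions restart at 0 for each packed sample; Feature 3 uses them to build cu_seqlens.
            position_ids.extend(range(len(tokens)))
        if input_ids:
            yield self._to_tensors(input_ids, labels, position_ids)

    @staticmethod
    def _to_tensors(input_ids, labels, position_ids):
        return {
            "input_ids": torch.tensor(input_ids),
            "labels": torch.tensor(labels),
            "position_ids": torch.tensor(position_ids),
        }
```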
Finally, to apply the chat template and tokenize the data, I included `ChatTokenizer`, very similar to the one included in meta-llama/llama3, with the difference that we also register the role of each token, which is necessary for Feature 2.
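The idea behind registering roles can be sketched as follows; the template string and the SlimOrca-style `from`/`value`/`gpt` keys are assumptions used for illustration, not the real Llama 3 chat format:

```python
def tokenize_with_roles(messages, tokenizer):
    """Tokenize a conversation turn by turn and record, per token, whether it belongs
    to an assistant completion (the mask consumed by train_on_completions_only)."""
    input_ids, is_completion = [], []
    for message in messages:
        role, text = message["from"], message["value"]   # SlimOrca-style keys; other datasets differ
        turn = f"<|{role}|>\n{text}\n"                    # placeholder template
        ids = tokenizer.encode(turn, add_special_tokens=False)
        input_ids.extend(ids)
        is_completion.extend([role == "gpt"] * len(ids))  # SlimOrca uses "gpt" for assistant turns
    return input_ids, is_completion
```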
`LlamaForSFT` only supports SFT training. I have removed everything related to inference of the nanotron checkpoints with the script `run_generate.py`, since we have never tested it nor do we intend to. I included the RoPE embeddings from HF transformers, which, although their performance is not great compared to FlashAttention's RoPEs written in Triton, are the only ones I have seen that support position ids (necessary for Feature 3). In the future, we could try to write a kernel for this. Also, for Feature 3, it is necessary to use `flash_attn_varlen_func` instead of `flash_attn_func`.
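As an illustration of why the position ids matter, here is a sketch of how `flash_attn_varlen_func` can be driven from them, assuming one packed sequence whose positions restart at 0 for every sample (this is not the actual nanotron attention module):

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

def packed_causal_attention(q, k, v, position_ids):
    # q, k, v: (total_tokens, n_heads, head_dim) in fp16/bf16 on GPU; position_ids: (total_tokens,).
    starts = torch.nonzero(position_ids == 0).flatten()
    ends = torch.tensor([position_ids.numel()], device=starts.device)
    seqlens = torch.diff(starts, append=ends)
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).to(torch.int32)
    # Feature 3: each packed sample only attends to its own tokens, never across sample boundaries.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
        causal=True,
    )
```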
Keep in mind that, as we are already packing multiple samples, `tokens.micro_batch_size` will always be 1. Then, the maximum number of tokens we will have is `tokens.micro_batch_size * tokens.sequence_length`.

## TODOs
- `tools/check_sft.py`
- `tools/todi`
- `convert_hf_nanotron.ipynb`