This is the official repository for the paper Pre-training Small Base LMs with Fewer Tokens.
We study the effectiveness of a simple approach to develop a small base language model (LM) starting from an existing large base LM: first inherit a few transformer blocks from the larger LM, and then continually train this smaller model on a very small subset (0.1%) of the raw pre-training data of the larger model. We call our simple recipe Inheritune and first demonstrate it by building a small base LM with 1.5B parameters using 1B tokens (starting from a larger LM with 3B parameters); we do this on a single A6000 GPU in less than half a day. Across 9 diverse evaluation datasets as well as the MMLU benchmark, the resulting model compares favorably to publicly available base models of similar size, some of which were trained on 50-1000 times more tokens.
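The initialization step of the recipe above, inheriting the first few transformer blocks of a larger LM, can be sketched on a plain parameter dictionary. This is a minimal sketch, not the repository's code: it assumes nanoGPT-style parameter names (`transformer.h.<i>.` for block `i`), and it assumes that the non-block parameters (token/position embeddings, final layer norm, LM head) are copied unchanged. The helper name `inherit_first_k_blocks` is hypothetical.

```python
import re
from typing import Any, Dict

# Assumed nanoGPT-style naming: parameters of block i start with "transformer.h.<i>.".
BLOCK_KEY = re.compile(r"^transformer\.h\.(\d+)\.")


def inherit_first_k_blocks(state_dict: Dict[str, Any], k: int) -> Dict[str, Any]:
    """Hypothetical helper: keep blocks 0..k-1 of a larger model's parameters.

    Non-block parameters (embeddings, final norm, LM head) are copied as-is,
    which is an assumption of this sketch.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    small: Dict[str, Any] = {}
    for name, tensor in state_dict.items():
        match = BLOCK_KEY.match(name)
        # Keep every non-block parameter, and block parameters with index < k.
        if match is None or int(match.group(1)) < k:
            small[name] = tensor
    return small


# Toy "state dict": strings stand in for tensors.
big = {
    "transformer.wte.weight": "E",
    "transformer.h.0.attn.c_attn.weight": "A0",
    "transformer.h.1.attn.c_attn.weight": "A1",
    "transformer.h.2.attn.c_attn.weight": "A2",
    "transformer.ln_f.weight": "N",
    "lm_head.weight": "H",
}
small = inherit_first_k_blocks(big, k=2)
print(sorted(small))
```

Because the kept blocks retain indices 0 to k-1, the resulting dictionary lines up with a model of the same width configured with k layers (for example nanoGPT's `n_layer = k`); the continual-training stage on the small data subset would then start from these weights.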
We also investigate Inheritune in a slightly different setting, where we train small LMs using larger LMs and their full pre-training dataset. Here we show that smaller LMs that inherit some of the layers of GPT-2 Medium (355M) and GPT-2 Large (770M) can match the validation loss of their larger counterparts trained from scratch for the same number of training steps on the OpenWebText dataset with 9B tokens. We analyze Inheritune with extensive experiments and demonstrate its efficacy in diverse settings.
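To put the GPT-2 sizes above in perspective, the parameter count of a GPT-2-style model can be worked out from its depth and width: each transformer block contributes 12d² + 13d parameters for width d. The sketch below assumes the standard GPT-2 layout (learned position embeddings, a 4x MLP expansion, biases everywhere, a 50257-token vocabulary, a 1024-token context) and an LM head tied to the token embedding, as in nanoGPT; `gpt2_param_count` is a hypothetical helper.

```python
def gpt2_param_count(n_layer: int, n_embd: int,
                     vocab_size: int = 50257, block_size: int = 1024) -> int:
    """Count parameters of a GPT-2-style decoder with a tied LM head (assumed layout)."""
    d = n_embd
    embeddings = vocab_size * d + block_size * d  # token (wte) + position (wpe) embeddings
    per_block = (
        2 * (2 * d)                # ln_1 and ln_2: weight + bias each
        + (d * 3 * d + 3 * d)      # attention: fused q, k, v projection
        + (d * d + d)              # attention: output projection
        + (d * 4 * d + 4 * d)      # MLP: up-projection to 4d
        + (4 * d * d + d)          # MLP: down-projection back to d
    )                              # = 12*d**2 + 13*d
    final_norm = 2 * d             # ln_f weight + bias
    # The LM head shares its weight with wte, so it adds nothing here.
    return embeddings + n_layer * per_block + final_norm


for name, n_layer, n_embd in [
    ("GPT-2 Medium, 24 layers", 24, 1024),
    ("GPT-2 Medium, 16 layers kept", 16, 1024),
    ("GPT-2 Large, 36 layers", 36, 1280),
    ("GPT-2 Large, 18 layers kept", 18, 1280),
]:
    print(f"{name}: {gpt2_param_count(n_layer, n_embd) / 1e6:.0f}M")
```

Under these assumptions GPT-2 Medium comes to about 355M parameters and GPT-2 Large to about 774M, while keeping 16 and 18 of their layers gives roughly 254M and 420M respectively.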
Performance of our 1.5B base LM, derived with Inheritune using 1B tokens, averaged over 9 different datasets (left) and on the MMLU benchmark (right); together these evaluate commonsense reasoning, truthfulness, natural language inference, and language understanding. We compare our model with its reference model OpenLLaMA-3B (2x its size) and with other small base LMs of 1B-2B parameters: MPT-1.3B, OPT-1.3B, and Pythia-1.4B (pre-trained from scratch) and Sheared LLaMA-1.3B (pruned from and continually trained using an existing large base LM).
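The token budgets behind this comparison can be checked with simple arithmetic. The sketch below uses the training-token counts listed in the comparison table (1B for our model, 1T for OpenLLaMA-3B, 300B for OPT-1.3B and Pythia-1.4B, 200B for MPT-1.3B, 50B for Sheared LLaMA-1.3B), treating "1B" as exactly 10^9 tokens for simplicity.

```python
# Training-token counts as listed in the comparison table; "1B" is taken as
# exactly 10**9 tokens, which is a simplifying assumption.
OURS_TOKENS = 1_000_000_000
BASELINE_TOKENS = {
    "OpenLLaMA-3B": 1_000_000_000_000,
    "OPT-1.3B": 300_000_000_000,
    "Pythia-1.4B": 300_000_000_000,
    "MPT-1.3B": 200_000_000_000,
    "Sheared LLaMA-1.3B": 50_000_000_000,
}

# Share of the reference model's pre-training tokens used by our model, in percent.
subset_percent = 100 * OURS_TOKENS / BASELINE_TOKENS["OpenLLaMA-3B"]

# How many times more tokens each baseline was trained on.
token_multiples = {name: n // OURS_TOKENS for name, n in BASELINE_TOKENS.items()}

print(f"subset of reference data: {subset_percent}%")
for name, multiple in token_multiples.items():
    print(f"{name}: {multiple}x more tokens")
```

This reproduces both the 0.1% data subset and the 50x to 1000x range of token budgets mentioned at the top of this README.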
Below is a comparison of our target model with its reference model and with baseline models of similar size, both pre-trained from scratch and pre-trained with inherited weights and pruning. Although trained on far fewer tokens, our model achieves comparable performance. We have highlighted all scores where our model achieves at least 90% of its reference model's score or outperforms at least two of the baseline models. All tasks are evaluated 0-shot except MMLU, which is evaluated 5-shot. Models marked n/a in the Reference column are trained from scratch.
Commonsense Reasoning (0-shot)
Model (# train tokens) | Reference | Winograd | PIQA | BoolQ | WinoGrande | LogiQA |
---|---|---|---|---|---|---|
OpenLLaMA-3B (1T) | n/a | 63.46 | 74.97 | 67.18 | 62.27 | 28.4 |
OPT-1.3B (300B) | n/a | 38.46 | 71.82 | 57.83 | 59.51 | 27.04 |
Pythia-1.4B (300B) | n/a | 36.54 | 70.89 | 63.12 | 56.99 | 27.65 |
MPT-1.3B (200B) | n/a | 63.46 | 71.44 | 50.89 | 58.09 | 28.26 |
Sheared LLaMA-1.3B (50B) | LLaMA2-7B | 36.54 | 73.45 | 62.02 | 58.17 | 27.34 |
Ours-1.5B (1B) | OpenLLaMA-3B | 50.96 | 56.47 | 61.68 | 51.69 | 25.19 |
Language Understanding & Inference (MMLU, WNLI, QNLI, MNLI) and Factuality (TruthfulQA)
Model (# train tokens) | Reference | MMLU (5-shot) | WNLI | QNLI | MNLI | TruthfulQA |
---|---|---|---|---|---|---|
OpenLLaMA-3B (1T) | n/a | 27.21 | 50.7 | 51.3 | 37.3 | 35 |
OPT-1.3B (300B) | n/a | 24.96 | 42.25 | 51.29 | 35.82 | 38.67 |
Pythia-1.4B (300B) | n/a | 25.56 | 53.52 | 49.48 | 32.76 | 38.66 |
MPT-1.3B (200B) | n/a | 25.82 | 40.85 | 50.52 | 35.93 | 38.68 |
Sheared LLaMA-1.3B (50B) | LLaMA2-7B | 25.71 | 49.3 | 50.98 | 37.94 | 37.14 |
Ours-1.5B (1B) | OpenLLaMA-3B | 25.67 | 43.66 | 49.41 | 34.42 | 48.61 |
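The highlighting rule for these tables can be written as a small check. This is a sketch under our own reading of the rule: a score qualifies if it is at least 90% of the reference model's (OpenLLaMA-3B) score, or if it is strictly greater than the scores of at least two of the four baselines (OPT-1.3B, Pythia-1.4B, MPT-1.3B, Sheared LLaMA-1.3B). The function name `is_highlighted` is hypothetical; the scores are copied from the two tables.

```python
from typing import Dict, List

# Scores copied from the two tables, in column order.
TASKS = ["Winograd", "PIQA", "BoolQ", "WinoGrande", "LogiQA",
         "MMLU", "WNLI", "QNLI", "MNLI", "TruthfulQA"]
REFERENCE = [63.46, 74.97, 67.18, 62.27, 28.4, 27.21, 50.7, 51.3, 37.3, 35.0]  # OpenLLaMA-3B
BASELINES: Dict[str, List[float]] = {
    "OPT-1.3B": [38.46, 71.82, 57.83, 59.51, 27.04, 24.96, 42.25, 51.29, 35.82, 38.67],
    "Pythia-1.4B": [36.54, 70.89, 63.12, 56.99, 27.65, 25.56, 53.52, 49.48, 32.76, 38.66],
    "MPT-1.3B": [63.46, 71.44, 50.89, 58.09, 28.26, 25.82, 40.85, 50.52, 35.93, 38.68],
    "Sheared LLaMA-1.3B": [36.54, 73.45, 62.02, 58.17, 27.34, 25.71, 49.3, 50.98, 37.94, 37.14],
}
OURS = [50.96, 56.47, 61.68, 51.69, 25.19, 25.67, 43.66, 49.41, 34.42, 48.61]


def is_highlighted(ours: float, reference: float, baselines: List[float],
                   ratio: float = 0.9, min_beaten: int = 2) -> bool:
    """Apply the (assumed) highlighting rule to a single score."""
    if ours >= ratio * reference:
        return True
    beaten = sum(ours > b for b in baselines)  # strictly better than this many baselines
    return beaten >= min_beaten


highlighted = [
    task
    for i, task in enumerate(TASKS)
    if is_highlighted(OURS[i], REFERENCE[i], [scores[i] for scores in BASELINES.values()])
]
print(highlighted)
```

Under this reading, the check flags Winograd, BoolQ, MMLU, WNLI, QNLI, MNLI, and TruthfulQA, and leaves PIQA, WinoGrande, and LogiQA unflagged.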
Below are the pre-training and downstream results for GPT-2 Medium and GPT-2 Large models, evaluated by pre-training validation loss and on the Wikitext and Lambada (OpenAI) downstream tasks. Smaller models derived using our method perform comparably to their full-sized counterparts, and models initialized with our method outperform randomly initialized models of the same depth trained for the same number of steps.
Model | Layers | Initialization | Steps | Pre-train Val Loss (↓) | Wikitext, 0-shot (↓) | Lambada (OpenAI), 0-shot |
---|---|---|---|---|---|---|
GPT-2 Large | 36 | rand init | 100K | 2.85 | 34.84 | 34.14 |
GPT-2 Large | 18 | rand init | 100K | 2.97 | 37.63 | 30.97 |
GPT-2 Large | 18 | rand init | 200K | 2.84 | -- | -- |
GPT-2 Large | 18 | Ours | 100K | 2.80 | 35.38 | 34.64 |
GPT-2 Medium | 24 | rand init | 100K | 2.81 | 31.93 | 36.54 |
GPT-2 Medium | 16 | rand init | 100K | 2.86 | 33.67 | 34.60 |
GPT-2 Medium | 16 | rand init | 200K | 2.83 | -- | -- |
GPT-2 Medium | 12 | Ours | 100K | 2.87 | -- | -- |
GPT-2 Medium | 14 | Ours | 100K | 2.84 | -- | -- |
GPT-2 Medium (Final Model →) | 16 | Ours | 100K | 2.81 | 32.04 | 35.96 |
Note: Models marked 'rand init' are randomly initialized. The row labeled 'Final Model →' shows the result after 3 rounds of our method on GPT-2 Medium, at which point the 16-layer model matches the validation loss of the full 24-layer GPT-2 Medium (2.81).
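Reaching the target validation loss over several rounds can be sketched as a simple loop. The assumptions here are loud: the starting depth (12 layers) and the step of 2 layers are read off the 12-, 14-, and 16-layer GPT-2 Medium rows of the table rather than taken from the repository's code; `train_and_eval` is a hypothetical callable that initializes a model with the given number of inherited layers, trains it for a fixed step budget, and returns its validation loss; and the stopping rule (stop once the loss is at or below the reference model's) is our own simplification.

```python
from typing import Callable, List, Tuple


def inheritune_rounds(
    train_and_eval: Callable[[int], float],
    ref_val_loss: float,
    start_layers: int,
    max_layers: int,
    step: int = 2,
) -> List[Tuple[int, float]]:
    """Grow the number of inherited layers until the validation-loss target is met.

    Returns the (n_layers, val_loss) pair recorded in each round.
    """
    history: List[Tuple[int, float]] = []
    n_layers = start_layers
    while n_layers <= max_layers:
        val_loss = train_and_eval(n_layers)  # hypothetical: inherit, train, evaluate
        history.append((n_layers, val_loss))
        if val_loss <= ref_val_loss:
            break  # target reached; keep this depth
        n_layers += step
    return history


# Stand-in for real training: replay the GPT-2 Medium validation losses from the table.
TABLE_VAL_LOSS = {12: 2.87, 14: 2.84, 16: 2.81}
rounds = inheritune_rounds(TABLE_VAL_LOSS.__getitem__, ref_val_loss=2.81,
                           start_layers=12, max_layers=24)
print(rounds)  # [(12, 2.87), (14, 2.84), (16, 2.81)]
```

Replaying the table values this way takes three rounds and stops at 16 layers, the depth of the final model; whether each round re-inherits from the reference model or from the previous round's model is not specified here.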
[2024-04-22] We've released the first version of the codebase for Inheritune in both the low-data and full-data regimes.
[2024-04-22] We've enabled the Discussions tab at the top of the repository for community feedback. Feel free to suggest new experiments and post your results.
If you find this work helpful, please consider citing us:
@inproceedings{Sanyal2024pretraining,
title = {Pre-training Small Base LMs with Fewer Tokens},
author = {Sunny Sanyal and Sujay Sanghavi and Alexandros G. Dimakis},
year = {2024},
url = {https://arxiv.org/abs/2404.08634}
}
The training code for the 1B-2B small language models is mainly adapted from litgpt. The code for the GPT-2 experiments is mainly adapted from Sophia and nanoGPT.
The llama image was created using DALL-E.