Skip to content

Commit

Permalink
Add llama2 unit tests (rasbt#372)
Browse files: browse the repository at this point in the history
* add llama2 unit tests

* update

* updates

* updates

* update file path

* update requirements file

* rmsnorm test

* update
  • Loading branch information
rasbt authored Sep 26, 2024
1 parent a6d8e93 commit b56d0b2
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 43 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ jobs:
python -m pip install --upgrade pip
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ jobs:
python -m pip install --upgrade pip
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-old-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ jobs:
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install torch==${{ matrix.pytorch-version }}
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ jobs:
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install matplotlib==3.9.0
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- name: Test Selected Python Scripts
shell: bash
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
Expand Down
84 changes: 41 additions & 43 deletions ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
"outputId": "d0fc89be-74a3-40d0-bc4d-7f6f1febf2cd"
"outputId": "7ce8fe41-1c24-4f0b-a8d9-352b4af1b46b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.24.7\n",
"sentencepiece version: 0.1.99\n",
"sentencepiece version: 0.2.0\n",
"torch version: 2.4.1+cu121\n"
]
}
Expand Down Expand Up @@ -421,41 +421,39 @@
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))\n",
" inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length)\n",
"\n",
" # Compute the angles using inverse frequencies and positions\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, emb_dim // 2)\n",
" # Compute the angles\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Precompute sine and cosine of the angles\n",
" sin = torch.sin(angles) # Shape: (context_length, emb_dim // 2)\n",
" cos = torch.cos(angles) # Shape: (context_length, emb_dim // 2)\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
"\n",
" return sin, cos\n",
" # Precompute sine and cosine\n",
" cos = torch.cos(angles)\n",
" sin = torch.sin(angles)\n",
"\n",
" return cos, sin\n",
"\n",
"def compute_rope(x, sin, cos):\n",
"def compute_rope(x, cos, sin):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Split x into even and odd parts\n",
" x1 = x[..., ::2] # Shape: (batch_size, num_heads, seq_len, head_dim // 2)\n",
" x2 = x[..., 1::2]\n",
" # Split x into first half and second half\n",
" x1 = x[..., : head_dim // 2] # First half\n",
" x2 = x[..., head_dim // 2 :] # Second half\n",
"\n",
" # Ensure sin and cos have correct shapes\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
" # Adjust sin and cos shapes\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
" # Apply the rotary transformation\n",
" x_rotated_0 = x1 * cos - x2 * sin\n",
" x_rotated_1 = x1 * sin + x2 * cos\n",
"\n",
" # Interleave x_rotated_0 and x_rotated_1\n",
" x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)\n",
" x_rotated = x_rotated.flatten(-2)\n",
" rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n",
"\n",
" return x_rotated.to(dtype=x.dtype)"
]
Expand Down Expand Up @@ -486,16 +484,16 @@
"head_dim = 16\n",
"\n",
"# Instantiate RoPE parameters\n",
"sin, cos = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
"cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
"\n",
"# Dummy query and key tensors\n",
"torch.manual_seed(123)\n",
"queries = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
"keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
"\n",
"# Apply rotary position embeddings\n",
"queries_rot = compute_rope(queries, sin, cos)\n",
"keys_rot = compute_rope(keys, sin, cos)"
"queries_rot = compute_rope(queries, cos, sin)\n",
"keys_rot = compute_rope(keys, cos, sin)"
]
},
{
Expand Down Expand Up @@ -554,9 +552,9 @@
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" ################################### NEW ###################################\n",
" sin, cos = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
" self.register_buffer(\"sin\", sin)\n",
" cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
" ###########################################################################\n",
"\n",
"\n",
Expand Down Expand Up @@ -736,7 +734,7 @@
"cell_type": "markdown",
"id": "ba5d991a-559b-47be-96f4-31b881ab2da8",
"metadata": {
"id": "aa79780d-74a8-4ee0-934a-9ad63205a02e"
"id": "ba5d991a-559b-47be-96f4-31b881ab2da8"
},
"source": [
"- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n",
Expand Down Expand Up @@ -918,7 +916,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "6079f747-8f20-4c6b-8d38-7156f1101729",
"outputId": "78ab929e-ac78-4b16-ddb1-704d45ee69a8"
"outputId": "1ca50091-a20c-4a44-b806-9985a5e64135"
},
"outputs": [
{
Expand Down Expand Up @@ -954,15 +952,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
"outputId": "c0cbdcc8-dc46-44f7-a800-fbe888a3f9e9"
"outputId": "b157b5ac-d37c-4b71-f609-45a91f7ed93a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 52.27 GB\n",
"bfloat16: 26.13 GB\n"
"float32 (PyTorch default): 52.33 GB\n",
"bfloat16: 26.17 GB\n"
]
}
],
Expand Down Expand Up @@ -1087,7 +1085,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "3357a230-b678-4691-a238-257ee4e80185",
"outputId": "d326d32c-fa8d-4f2b-84d5-a1b8f35dd387"
"outputId": "7d4adc4b-53cf-4099-a45f-2fb4fd25edc4"
},
"outputs": [
{
Expand Down Expand Up @@ -1131,7 +1129,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
"outputId": "82bc5037-c86c-46c2-b374-269f9d09599a"
"outputId": "aa18fccc-6533-4446-f57b-546068ad518c"
},
"outputs": [
{
Expand Down Expand Up @@ -1213,15 +1211,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
"outputId": "d733bc0a-5136-4c33-d70d-36056f1e8329"
"outputId": "cbc53f67-a77a-40c9-ed2d-c6f8be066cfb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort movesαllRadius deletingpretccappedRadius zas Parte Material Ку términчной herousztusllRadiusotto кра liberotto siguientesagnost#{ (@topicquez restored log\n"
" Every effort movesαfdmsdn coatELDâte eer tagsיśćinu Lundmysq eer napinu LundANCEHEAD ner}}}رible one}}}رible one puts Dan\n"
]
}
],
Expand Down Expand Up @@ -1322,7 +1320,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
"outputId": "01721809-ace1-4a7a-ab54-8fad2e8f54a6"
"outputId": "351029ce-b4c0-4d39-8e0e-7e7f44d25647"
},
"outputs": [
{
Expand Down Expand Up @@ -1457,15 +1455,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "240987e8-a023-462e-9376-9edfb27559ec",
"outputId": "59830005-42af-406b-c836-38a8f2d7b961"
"outputId": "3fa7a77a-6203-4d8a-bdaa-afce1f504adf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty\n"
" Every effort has been made to ensure that the information contained in this website is correct and up to date and accurate at the time of publication\n"
]
}
],
Expand All @@ -1475,7 +1473,7 @@
"token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
" max_new_tokens=30,\n",
" max_new_tokens=25,\n",
" context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
" top_k=1,\n",
" temperature=0.\n",
Expand All @@ -1496,14 +1494,14 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 31,
"id": "nbvAV7vaz6yc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nbvAV7vaz6yc",
"outputId": "faa930dc-0db2-4095-b395-f97baef08903"
"outputId": "bd4cae4d-5d5f-4f64-ea37-b979ef2c86bb"
},
"outputs": [
{
Expand All @@ -1512,7 +1510,7 @@
"text": [
"Output text:\n",
" What do llamas eat?\n",
"Llamas are herbivores, which means they eat plants. They eat grass, leaves, and hay.\n"
"Llamas are herbivores, which means they eat grass, leaves, grasses, and they eat grass\n"
]
}
],
Expand Down
1 change: 1 addition & 0 deletions ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
transformers>=4.44.2
Loading

0 comments on commit b56d0b2

Please sign in to comment.