Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama2 unit tests #372

Merged
merged 8 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ jobs:
python -m pip install --upgrade pip
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py

- name: Validate Selected Jupyter Notebooks
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ jobs:
python -m pip install --upgrade pip
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py

- name: Validate Selected Jupyter Notebooks
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-old-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ jobs:
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install torch==${{ matrix.pytorch-version }}
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

- name: Test Selected Python Scripts
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py

- name: Validate Selected Jupyter Notebooks
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ jobs:
pip install pytest nbval
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install matplotlib==3.9.0
pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt

- name: Test Selected Python Scripts
shell: bash
run: |
pytest setup/02_installing-python-libraries/tests.py
pytest ch04/01_main-chapter-code/tests.py
pytest ch05/01_main-chapter-code/tests.py
pytest ch05/07_gpt_to_llama/tests/tests.py
pytest ch06/01_main-chapter-code/tests.py

- name: Validate Selected Jupyter Notebooks
Expand Down
84 changes: 41 additions & 43 deletions ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
"outputId": "d0fc89be-74a3-40d0-bc4d-7f6f1febf2cd"
"outputId": "7ce8fe41-1c24-4f0b-a8d9-352b4af1b46b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.24.7\n",
"sentencepiece version: 0.1.99\n",
"sentencepiece version: 0.2.0\n",
"torch version: 2.4.1+cu121\n"
]
}
Expand Down Expand Up @@ -421,41 +421,39 @@
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))\n",
" inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length)\n",
"\n",
" # Compute the angles using inverse frequencies and positions\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, emb_dim // 2)\n",
" # Compute the angles\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Precompute sine and cosine of the angles\n",
" sin = torch.sin(angles) # Shape: (context_length, emb_dim // 2)\n",
" cos = torch.cos(angles) # Shape: (context_length, emb_dim // 2)\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",
"\n",
" return sin, cos\n",
" # Precompute sine and cosine\n",
" cos = torch.cos(angles)\n",
" sin = torch.sin(angles)\n",
"\n",
" return cos, sin\n",
"\n",
"def compute_rope(x, sin, cos):\n",
"def compute_rope(x, cos, sin):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Split x into even and odd parts\n",
" x1 = x[..., ::2] # Shape: (batch_size, num_heads, seq_len, head_dim // 2)\n",
" x2 = x[..., 1::2]\n",
" # Split x into first half and second half\n",
" x1 = x[..., : head_dim // 2] # First half\n",
" x2 = x[..., head_dim // 2 :] # Second half\n",
"\n",
" # Ensure sin and cos have correct shapes\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
" # Adjust sin and cos shapes\n",
" cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
" sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
" # Apply the rotary transformation\n",
" x_rotated_0 = x1 * cos - x2 * sin\n",
" x_rotated_1 = x1 * sin + x2 * cos\n",
"\n",
" # Interleave x_rotated_0 and x_rotated_1\n",
" x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)\n",
" x_rotated = x_rotated.flatten(-2)\n",
" rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n",
"\n",
" return x_rotated.to(dtype=x.dtype)"
]
Expand Down Expand Up @@ -486,16 +484,16 @@
"head_dim = 16\n",
"\n",
"# Instantiate RoPE parameters\n",
"sin, cos = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
"cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
"\n",
"# Dummy query and key tensors\n",
"torch.manual_seed(123)\n",
"queries = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
"keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
"\n",
"# Apply rotary position embeddings\n",
"queries_rot = compute_rope(queries, sin, cos)\n",
"keys_rot = compute_rope(keys, sin, cos)"
"queries_rot = compute_rope(queries, cos, sin)\n",
"keys_rot = compute_rope(keys, cos, sin)"
]
},
{
Expand Down Expand Up @@ -554,9 +552,9 @@
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" ################################### NEW ###################################\n",
" sin, cos = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
" self.register_buffer(\"sin\", sin)\n",
" cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
" ###########################################################################\n",
"\n",
"\n",
Expand Down Expand Up @@ -736,7 +734,7 @@
"cell_type": "markdown",
"id": "ba5d991a-559b-47be-96f4-31b881ab2da8",
"metadata": {
"id": "aa79780d-74a8-4ee0-934a-9ad63205a02e"
"id": "ba5d991a-559b-47be-96f4-31b881ab2da8"
},
"source": [
"- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n",
Expand Down Expand Up @@ -918,7 +916,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "6079f747-8f20-4c6b-8d38-7156f1101729",
"outputId": "78ab929e-ac78-4b16-ddb1-704d45ee69a8"
"outputId": "1ca50091-a20c-4a44-b806-9985a5e64135"
},
"outputs": [
{
Expand Down Expand Up @@ -954,15 +952,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
"outputId": "c0cbdcc8-dc46-44f7-a800-fbe888a3f9e9"
"outputId": "b157b5ac-d37c-4b71-f609-45a91f7ed93a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 52.27 GB\n",
"bfloat16: 26.13 GB\n"
"float32 (PyTorch default): 52.33 GB\n",
"bfloat16: 26.17 GB\n"
]
}
],
Expand Down Expand Up @@ -1087,7 +1085,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "3357a230-b678-4691-a238-257ee4e80185",
"outputId": "d326d32c-fa8d-4f2b-84d5-a1b8f35dd387"
"outputId": "7d4adc4b-53cf-4099-a45f-2fb4fd25edc4"
},
"outputs": [
{
Expand Down Expand Up @@ -1131,7 +1129,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
"outputId": "82bc5037-c86c-46c2-b374-269f9d09599a"
"outputId": "aa18fccc-6533-4446-f57b-546068ad518c"
},
"outputs": [
{
Expand Down Expand Up @@ -1213,15 +1211,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
"outputId": "d733bc0a-5136-4c33-d70d-36056f1e8329"
"outputId": "cbc53f67-a77a-40c9-ed2d-c6f8be066cfb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort movesαllRadius deletingpretccappedRadius zas Parte Material Ку términчной herousztusllRadiusotto кра liberotto siguientesagnost#{ (@topicquez restored log\n"
" Every effort movesαfdmsdn coatELDâte eer tagsיśćinu Lundmysq eer napinu LundANCEHEAD ner}}}رible one}}}رible one puts Dan\n"
]
}
],
Expand Down Expand Up @@ -1322,7 +1320,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
"outputId": "01721809-ace1-4a7a-ab54-8fad2e8f54a6"
"outputId": "351029ce-b4c0-4d39-8e0e-7e7f44d25647"
},
"outputs": [
{
Expand Down Expand Up @@ -1457,15 +1455,15 @@
"base_uri": "https://localhost:8080/"
},
"id": "240987e8-a023-462e-9376-9edfb27559ec",
"outputId": "59830005-42af-406b-c836-38a8f2d7b961"
"outputId": "3fa7a77a-6203-4d8a-bdaa-afce1f504adf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty\n"
" Every effort has been made to ensure that the information contained in this website is correct and up to date and accurate at the time of publication\n"
]
}
],
Expand All @@ -1475,7 +1473,7 @@
"token_ids = generate(\n",
" model=model,\n",
" idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
" max_new_tokens=30,\n",
" max_new_tokens=25,\n",
" context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
" top_k=1,\n",
" temperature=0.\n",
Expand All @@ -1496,14 +1494,14 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 31,
"id": "nbvAV7vaz6yc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nbvAV7vaz6yc",
"outputId": "faa930dc-0db2-4095-b395-f97baef08903"
"outputId": "bd4cae4d-5d5f-4f64-ea37-b979ef2c86bb"
},
"outputs": [
{
Expand All @@ -1512,7 +1510,7 @@
"text": [
"Output text:\n",
" What do llamas eat?\n",
"Llamas are herbivores, which means they eat plants. They eat grass, leaves, and hay.\n"
"Llamas are herbivores, which means they eat grass, leaves, grasses, and they eat grass\n"
]
}
],
Expand Down
1 change: 1 addition & 0 deletions ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
transformers>=4.44.2
Loading