From b56d0b29425d4dd1e2a20138a6a4659e6daea2b4 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Wed, 25 Sep 2024 19:40:36 -0500 Subject: [PATCH] Add llama2 unit tests (#372) * add llama2 unit tests * update * updates * updates * update file path * update requirements file * rmsnorm test * update --- .github/workflows/basic-tests-linux.yml | 2 + .github/workflows/basic-tests-macos.yml | 2 + .github/workflows/basic-tests-old-pytorch.yml | 2 + .github/workflows/basic-tests-windows.yml | 2 + .../converting-gpt-to-llama2.ipynb | 84 +++++++------ .../tests/test-requirements-extra.txt | 1 + ch05/07_gpt_to_llama/tests/tests.py | 114 ++++++++++++++++++ 7 files changed, 164 insertions(+), 43 deletions(-) create mode 100644 ch05/07_gpt_to_llama/tests/test-requirements-extra.txt create mode 100644 ch05/07_gpt_to_llama/tests/tests.py diff --git a/.github/workflows/basic-tests-linux.yml b/.github/workflows/basic-tests-linux.yml index 8d70be75..319dd9c6 100644 --- a/.github/workflows/basic-tests-linux.yml +++ b/.github/workflows/basic-tests-linux.yml @@ -35,12 +35,14 @@ jobs: python -m pip install --upgrade pip pip install pytest nbval if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt - name: Test Selected Python Scripts run: | pytest setup/02_installing-python-libraries/tests.py pytest ch04/01_main-chapter-code/tests.py pytest ch05/01_main-chapter-code/tests.py + pytest ch05/07_gpt_to_llama/tests/tests.py pytest ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-macos.yml b/.github/workflows/basic-tests-macos.yml index 25f7d6e3..00578e58 100644 --- a/.github/workflows/basic-tests-macos.yml +++ b/.github/workflows/basic-tests-macos.yml @@ -35,12 +35,14 @@ jobs: python -m pip install --upgrade pip pip install pytest nbval if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt - name: Test Selected Python Scripts run: | pytest setup/02_installing-python-libraries/tests.py pytest ch04/01_main-chapter-code/tests.py pytest ch05/01_main-chapter-code/tests.py + pytest ch05/07_gpt_to_llama/tests/tests.py pytest ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-old-pytorch.yml b/.github/workflows/basic-tests-old-pytorch.yml index 43195899..3a5a618f 100644 --- a/.github/workflows/basic-tests-old-pytorch.yml +++ b/.github/workflows/basic-tests-old-pytorch.yml @@ -39,12 +39,14 @@ jobs: pip install pytest nbval if [ -f requirements.txt ]; then pip install -r requirements.txt; fi pip install torch==${{ matrix.pytorch-version }} + pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt - name: Test Selected Python Scripts run: | pytest setup/02_installing-python-libraries/tests.py pytest ch04/01_main-chapter-code/tests.py pytest ch05/01_main-chapter-code/tests.py + pytest ch05/07_gpt_to_llama/tests/tests.py pytest ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-windows.yml b/.github/workflows/basic-tests-windows.yml index a09588db..f286156e 100644 --- a/.github/workflows/basic-tests-windows.yml +++ b/.github/workflows/basic-tests-windows.yml @@ -38,6 +38,7 @@ jobs: pip install pytest nbval if [ -f requirements.txt ]; then pip install -r requirements.txt; fi pip install matplotlib==3.9.0 + pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt - name: Test Selected Python Scripts shell: bash @@ -45,6 +46,7 @@ jobs: pytest setup/02_installing-python-libraries/tests.py pytest ch04/01_main-chapter-code/tests.py pytest ch05/01_main-chapter-code/tests.py + pytest ch05/07_gpt_to_llama/tests/tests.py pytest ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb index 5cb493d9..2b537eea 100644 --- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -76,7 +76,7 @@ "base_uri": "https://localhost:8080/" }, "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", - "outputId": "d0fc89be-74a3-40d0-bc4d-7f6f1febf2cd" + "outputId": "7ce8fe41-1c24-4f0b-a8d9-352b4af1b46b" }, "outputs": [ { @@ -84,7 +84,7 @@ "output_type": "stream", "text": [ "huggingface_hub version: 0.24.7\n", - "sentencepiece version: 0.1.99\n", + "sentencepiece version: 0.2.0\n", "torch version: 2.4.1+cu121\n" ] } @@ -421,41 +421,39 @@ " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))\n", + " inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", "\n", " # Generate position indices\n", " positions = torch.arange(context_length)\n", "\n", - " # Compute the angles using inverse frequencies and positions\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, emb_dim // 2)\n", + " # Compute the angles\n", + " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", "\n", - " # Precompute sine and cosine of the angles\n", - " sin = torch.sin(angles) # Shape: (context_length, emb_dim // 2)\n", - " cos = torch.cos(angles) # Shape: (context_length, emb_dim // 2)\n", + " # Expand angles to match the head_dim\n", + " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", "\n", - " return sin, cos\n", + " # Precompute sine and cosine\n", + " cos = torch.cos(angles)\n", + " sin = torch.sin(angles)\n", "\n", + " return cos, sin\n", "\n", - "def compute_rope(x, sin, cos):\n", + "def compute_rope(x, cos, sin):\n", " # x: (batch_size, num_heads, seq_len, head_dim)\n", " batch_size, num_heads, seq_len, head_dim = x.shape\n", " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", "\n", - " # Split x into even and odd parts\n", - " x1 = x[..., ::2] # Shape: (batch_size, num_heads, seq_len, head_dim // 2)\n", - " x2 = x[..., 1::2]\n", + " # Split x into first half and second half\n", + " x1 = x[..., : head_dim // 2] # First half\n", + " x2 = x[..., head_dim // 2 :] # Second half\n", "\n", - " # Ensure sin and cos have correct shapes\n", - " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)\n", - " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", + " # Adjust sin and cos shapes\n", + " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", + " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", "\n", " # Apply the rotary transformation\n", - " x_rotated_0 = x1 * cos - x2 * sin\n", - " x_rotated_1 = x1 * sin + x2 * cos\n", - "\n", - " # Interleave x_rotated_0 and x_rotated_1\n", - " x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)\n", - " x_rotated = x_rotated.flatten(-2)\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x * cos) + (rotated * sin)\n", "\n", " return x_rotated.to(dtype=x.dtype)" ] @@ -486,7 +484,7 @@ "head_dim = 16\n", "\n", "# Instantiate RoPE parameters\n", - "sin, cos = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n", + "cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n", "\n", "# Dummy query and key tensors\n", "torch.manual_seed(123)\n", @@ -494,8 +492,8 @@ "keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n", "\n", "# Apply rotary position embeddings\n", - "queries_rot = compute_rope(queries, sin, cos)\n", - "keys_rot = compute_rope(keys, sin, cos)" + "queries_rot = compute_rope(queries, cos, sin)\n", + "keys_rot = compute_rope(keys, cos, sin)" ] }, { @@ -554,9 +552,9 @@ " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", "\n", " ################################### NEW ###################################\n", - " sin, cos = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n", - " self.register_buffer(\"sin\", sin)\n", + " cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n", " self.register_buffer(\"cos\", cos)\n", + " self.register_buffer(\"sin\", sin)\n", " ###########################################################################\n", "\n", "\n", @@ -736,7 +734,7 @@ "cell_type": "markdown", "id": "ba5d991a-559b-47be-96f4-31b881ab2da8", "metadata": { - "id": "aa79780d-74a8-4ee0-934a-9ad63205a02e" + "id": "ba5d991a-559b-47be-96f4-31b881ab2da8" }, "source": [ "- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n", @@ -918,7 +916,7 @@ "base_uri": "https://localhost:8080/" }, "id": "6079f747-8f20-4c6b-8d38-7156f1101729", - "outputId": "78ab929e-ac78-4b16-ddb1-704d45ee69a8" + "outputId": "1ca50091-a20c-4a44-b806-9985a5e64135" }, "outputs": [ { @@ -954,15 +952,15 @@ "base_uri": "https://localhost:8080/" }, "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", - "outputId": "c0cbdcc8-dc46-44f7-a800-fbe888a3f9e9" + "outputId": "b157b5ac-d37c-4b71-f609-45a91f7ed93a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "float32 (PyTorch default): 52.27 GB\n", - "bfloat16: 26.13 GB\n" + "float32 (PyTorch default): 52.33 GB\n", + "bfloat16: 26.17 GB\n" ] } ], @@ -1087,7 +1085,7 @@ "base_uri": "https://localhost:8080/" }, "id": "3357a230-b678-4691-a238-257ee4e80185", - "outputId": "d326d32c-fa8d-4f2b-84d5-a1b8f35dd387" + "outputId": "7d4adc4b-53cf-4099-a45f-2fb4fd25edc4" }, "outputs": [ { @@ -1131,7 +1129,7 @@ "base_uri": "https://localhost:8080/" }, "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", - "outputId": "82bc5037-c86c-46c2-b374-269f9d09599a" + "outputId": "aa18fccc-6533-4446-f57b-546068ad518c" }, "outputs": [ { @@ -1213,7 +1211,7 @@ "base_uri": "https://localhost:8080/" }, "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e", - "outputId": "d733bc0a-5136-4c33-d70d-36056f1e8329" + "outputId": "cbc53f67-a77a-40c9-ed2d-c6f8be066cfb" }, "outputs": [ { @@ -1221,7 +1219,7 @@ "output_type": "stream", "text": [ "Output text:\n", - " Every effort movesαllRadius deletingpretccappedRadius zas Parte Material Ку términчной herousztusllRadiusotto кра liberotto siguientesagnost#{ (@topicquez restored log\n" + " Every effort movesαfdmsdn coatELDâte eer tagsיśćinu Lundmysq eer napinu LundANCEHEAD ner}}}رible one}}}رible one puts Dan\n" ] } ], @@ -1322,7 +1320,7 @@ "base_uri": "https://localhost:8080/" }, "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49", - "outputId": "01721809-ace1-4a7a-ab54-8fad2e8f54a6" + "outputId": "351029ce-b4c0-4d39-8e0e-7e7f44d25647" }, "outputs": [ { @@ -1457,7 +1455,7 @@ "base_uri": "https://localhost:8080/" }, "id": "240987e8-a023-462e-9376-9edfb27559ec", - "outputId": "59830005-42af-406b-c836-38a8f2d7b961" + "outputId": "3fa7a77a-6203-4d8a-bdaa-afce1f504adf" }, "outputs": [ { @@ -1465,7 +1463,7 @@ "output_type": "stream", "text": [ "Output text:\n", - " Every effort has been made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty\n" + " Every effort has been made to ensure that the information contained in this website is correct and up to date and accurate at the time of publication\n" ] } ], @@ -1475,7 +1473,7 @@ "token_ids = generate(\n", " model=model,\n", " idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n", - " max_new_tokens=30,\n", + " max_new_tokens=25,\n", " context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n", " top_k=1,\n", " temperature=0.\n", @@ -1496,14 +1494,14 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 31, "id": "nbvAV7vaz6yc", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nbvAV7vaz6yc", - "outputId": "faa930dc-0db2-4095-b395-f97baef08903" + "outputId": "bd4cae4d-5d5f-4f64-ea37-b979ef2c86bb" }, "outputs": [ { @@ -1512,7 +1510,7 @@ "text": [ "Output text:\n", " What do llamas eat?\n", - "Llamas are herbivores, which means they eat plants. They eat grass, leaves, and hay.\n" + "Llamas are herbivores, which means they eat grass, leaves, grasses, and they eat grass\n" ] } ], diff --git a/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt b/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt new file mode 100644 index 00000000..8828ccea --- /dev/null +++ b/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt @@ -0,0 +1 @@ +transformers>=4.44.2 \ No newline at end of file diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests.py new file mode 100644 index 00000000..99d7b3fe --- /dev/null +++ b/ch05/07_gpt_to_llama/tests/tests.py @@ -0,0 +1,114 @@ +import io +import os +import sys +import types +import nbformat +import torch +import pytest +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + + +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# File for internal use (unit tests) + + +@pytest.fixture(scope="module") +def notebook(): + def import_definitions_from_notebook(fullname, names): + # Get the directory of the current test file + current_dir = os.path.dirname(__file__) + path = os.path.join(current_dir, "..", fullname + ".ipynb") + path = os.path.normpath(path) + + # Load the notebook + if not os.path.exists(path): + raise FileNotFoundError(f"Notebook file not found at: {path}") + + with io.open(path, "r", encoding="utf-8") as f: + nb = nbformat.read(f, as_version=4) + + # Create a module to store the imported functions and classes + mod = types.ModuleType(fullname) + sys.modules[fullname] = mod + + # Go through the notebook cells and only execute function or class definitions + for cell in nb.cells: + if cell.cell_type == "code": + cell_code = cell.source + for name in names: + # Check for function or class definitions + if f"def {name}" in cell_code or f"class {name}" in cell_code: + exec(cell_code, mod.__dict__) + return mod + + # Specify the notebook name and functions/classes to import + fullname = "converting-gpt-to-llama2" + names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"] + + # Import the required functions and classes from the notebook + return import_definitions_from_notebook(fullname, names) + + +@pytest.fixture(autouse=True) +def set_seed(): + torch.manual_seed(123) + + +def test_rope(notebook): + # Settings + batch_size = 1 + context_len = 5 + num_heads = 4 + head_dim = 16 + + # Instantiate RoPE parameters + cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len) + + # Dummy query and key tensors + queries = torch.randn(batch_size, num_heads, context_len, head_dim) + keys = torch.randn(batch_size, num_heads, context_len, head_dim) + + # Apply rotary position embeddings + queries_rot = notebook.compute_rope(queries, cos, sin) + keys_rot = notebook.compute_rope(keys, cos, sin) + + class RoPEConfig: + rope_type = "default" + rope_scaling = None + factor = 1.0 + dim: int = head_dim + rope_theta = 10000 + max_position_embeddings: int = 4096 + hidden_size = head_dim * num_heads + num_attention_heads = num_heads + + config = RoPEConfig() + + rot_emb = LlamaRotaryEmbedding(config=config) + position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) + ref_cos, ref_sin = rot_emb(queries, position_ids) + ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin) + + torch.testing.assert_close(sin, ref_sin.squeeze(0)) + torch.testing.assert_close(cos, ref_cos.squeeze(0)) + torch.testing.assert_close(keys_rot, ref_keys_rot) + torch.testing.assert_close(queries_rot, ref_queries_rot) + + +def test_silu(notebook): + example_batch = torch.randn(2, 3, 4) + silu = notebook.SiLU() + assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch)) + + +@pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer") +def test_rmsnorm(notebook): + example_batch = torch.randn(2, 3, 4) + rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1]) + rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6) + + assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))