From 0467c8289b3713a35e63d0fd1289ef2a3fbb2cf2 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Mon, 23 Sep 2024 05:34:06 -0700 Subject: [PATCH] GPT to Llama (#368) * GPT to Llama * fix urls --- .gitignore | 2 + README.md | 1 + ch05/07_gpt_to_llama/README.md | 7 + .../converting-gpt-to-llama2.ipynb | 1568 +++++++++++++++++ ch05/07_gpt_to_llama/previous_chapters.py | 63 + ch05/07_gpt_to_llama/requirements-extra.txt | 2 + ch05/README.md | 1 + 7 files changed, 1644 insertions(+) create mode 100644 ch05/07_gpt_to_llama/README.md create mode 100644 ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb create mode 100644 ch05/07_gpt_to_llama/previous_chapters.py create mode 100644 ch05/07_gpt_to_llama/requirements-extra.txt diff --git a/.gitignore b/.gitignore index 71873eca..e51333c4 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,8 @@ ch05/01_main-chapter-code/model.pth ch05/01_main-chapter-code/model_and_optimizer.pth ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints ch05/06_user_interface/gpt2 +ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b +ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b-chat ch06/01_main-chapter-code/gpt2 ch06/02_bonus_additional-experiments/gpt2 diff --git a/README.md b/README.md index f0b8eb3a..853b143d 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,7 @@ Several folders contain optional materials as a bonus for interested readers: - [Adding Bells and Whistles to the Training Loop](ch05/04_learning_rate_schedulers) - [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning) - [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface) + - [Converting GPT to Llama](ch05/07_gpt_to_llama) - **Chapter 6:** - [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments) - [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification) diff --git a/ch05/07_gpt_to_llama/README.md b/ch05/07_gpt_to_llama/README.md new file mode 100644 index 00000000..280d43e5 --- /dev/null +++ b/ch05/07_gpt_to_llama/README.md @@ -0,0 +1,7 @@ +# Converting GPT to Llama + + + +This folder contains code for converting the GPT implementation from chapter 4 and 5 to Meta AI's Llama architecture: + +- [converting-gpt-to-llama2.ipynb](converting-gpt-to-llama2.ipynb): contains code to convert GPT to Llama 2 7B step by step and loads pretrained weights from Meta AI \ No newline at end of file diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb new file mode 100644 index 00000000..ac442b36 --- /dev/null +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -0,0 +1,1568 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0_xya1nyDHfY", + "metadata": { + "id": "0_xya1nyDHfY" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "l62zIRRSBy_R", + "metadata": { + "id": "l62zIRRSBy_R" + }, + "source": [ + "# Converting a From-Scratch GPT Architecture to Llama 2" + ] + }, + { + "cell_type": "markdown", + "id": "aFmxTQbwCUMl", + "metadata": { + "id": "aFmxTQbwCUMl" + }, + "source": [ + "- In this notebook, we convert the original GPT and GPT-2 architecture into a Llama 2 model step by step\n", + "- Why not Llama 1 or Llama 3?\n", + " - The Llama 1 architecture is similar to Llama 2, except that Llama 2 has a larger context window (which is nice); the Llama 1 weights are not readily available and have more usage restrictions, so it makes more sense to focus on Llama 2\n", + " - Regarding Llama 3, I will share a separate notebook to convert Llama 2 to Llama 3 (there are only a few small additional changes)\n", + "- The explanations are purposefully kept minimal in this notebook not to bloat it unnecessarily and focus on the main code\n", + "- For more information, please see the Llama 2 paper: [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)" + ] + }, + { + "cell_type": "markdown", + "id": "ohhMKUWvGm9z", + "metadata": { + "id": "ohhMKUWvGm9z" + }, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "id": "JBpQwU89ETA1", + "metadata": { + "id": "JBpQwU89ETA1" + }, + "source": [ + "- Packages that are being used in this notebook:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", + "outputId": "d0fc89be-74a3-40d0-bc4d-7f6f1febf2cd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface_hub version: 0.24.7\n", + "sentencepiece version: 0.1.99\n", + "torch version: 2.4.1+cu121\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"sentencepiece\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "UJJneXpTEg4W", + "metadata": { + "id": "UJJneXpTEg4W" + }, + "source": [ + "## 1. 
Convert the GPT model implementation step by step"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "v1zpfX2GHBKa",
+   "metadata": {
+    "id": "v1zpfX2GHBKa"
+   },
+   "source": [
+    "- In this section, we go through the GPT model code from [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb) and modify it step by step to implement the Llama 2 architecture\n",
+    "- Later, we load the original Llama 2 weights shared by Meta AI"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f",
+   "metadata": {
+    "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f"
+   },
+   "source": [
+    "### 1.2 Replace LayerNorm with RMSNorm layer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f8b27fc8-23a1-4e0e-a1ea-792e0428e5e6",
+   "metadata": {
+    "id": "f8b27fc8-23a1-4e0e-a1ea-792e0428e5e6"
+   },
+   "source": [
+    "- First, we replace LayerNorm with Root Mean Square Layer Normalization (RMSNorm)\n",
+    "- LayerNorm normalizes inputs using mean and variance, while RMSNorm uses only the root mean square of the inputs, which improves computational efficiency\n",
+    "- The RMSNorm operation is as follows, where $x$ is the input, $\gamma$ is a trainable parameter (vector), and $\epsilon$ is a small constant to avoid zero-division errors:\n",
+    "\n",
+    "$$y = \frac{x}{\text{RMS}[x]} * \gamma, \quad \text{where} \quad \text{RMS}[x] = \sqrt{\text{mean}(x^2) + \epsilon}$$\n",
+    "\n",
+    "- For more details, please see the paper [Root Mean Square Layer Normalization (2019)](https://arxiv.org/abs/1910.07467)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d7094381-9499-4e9e-93f9-b79470da3771",
+   "metadata": {
+    "id": "d7094381-9499-4e9e-93f9-b79470da3771"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "\n",
+    "#####################################\n",
+    "# Chapter 4\n",
+    "#####################################\n",
+    "\n",
+    "# class LayerNorm(nn.Module):\n",
+    "#     def __init__(self, emb_dim):\n",
+    "#         super().__init__()\n",
+    "#         self.eps = 1e-5\n",
+    "#         self.scale = nn.Parameter(torch.ones(emb_dim))\n",
+    "#         self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
+    "\n",
+    "#     def forward(self, x):\n",
+    "#         mean = x.mean(dim=-1, keepdim=True)\n",
+    "#         var = x.var(dim=-1, keepdim=True, unbiased=False)\n",
+    "#         norm_x = (x - mean) / torch.sqrt(var + self.eps)\n",
+    "#         return self.scale * norm_x + self.shift\n",
+    "\n",
+    "\n",
+    "class RMSNorm(nn.Module):\n",
+    "    def __init__(self, emb_dim, eps=1e-6):\n",
+    "        super().__init__()\n",
+    "        self.eps = eps\n",
+    "        self.emb_dim = emb_dim\n",
+    "        self.weight = nn.Parameter(torch.ones(emb_dim)).float()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        means = x.pow(2).mean(dim=-1, keepdim=True)\n",
+    "        x_normed = x * torch.rsqrt(means + self.eps)\n",
+    "        return (x_normed * self.weight).to(dtype=x.dtype)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "mtWC8DOmIu0F",
+   "metadata": {
+    "id": "mtWC8DOmIu0F"
+   },
+   "source": [
+    "- The following code cell checks that this implementation works the same as PyTorch's built-in implementation:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "e41ade7a-bf06-48b1-8b7e-0e4037d5753f",
+   "metadata": {
+    "id": "e41ade7a-bf06-48b1-8b7e-0e4037d5753f"
+   },
+   "outputs": [],
+   "source": [
+    "torch.manual_seed(123)\n",
+    "\n",
+    "example_batch = torch.randn(2, 3, 4)\n",
+    "\n",
+    "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n",
+    "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)\n",
+    "\n",
+    "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))"
+   ]
+  },
+  {
+   "cell_type": 
"markdown", + "id": "5eb81f83-c38c-46a4-b763-aa630a32e357", + "metadata": { + "id": "5eb81f83-c38c-46a4-b763-aa630a32e357" + }, + "source": [ + "## Replace GELU with SiLU activation" + ] + }, + { + "cell_type": "markdown", + "id": "0b8aa702-f118-4ff6-9135-90725ec8756c", + "metadata": { + "id": "0b8aa702-f118-4ff6-9135-90725ec8756c" + }, + "source": [ + "- Llama uses the SiLU activation function (instead of GELU), which is also known as the Swish function:\n", + "\n", + "$$\n", + "\\text{silu}(x) = x \\cdot \\sigma(x), \\quad \\text{where} \\quad \\sigma(x) \\text{ is the logistic sigmoid.}\n", + "$$\n", + "\n", + "- For more information, see the SiLU paper: [Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning (2017)](https://arxiv.org/abs/1702.03118)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a74f3757-c634-4a3a-a8f3-6334cde454fe", + "metadata": { + "id": "a74f3757-c634-4a3a-a8f3-6334cde454fe" + }, + "outputs": [], + "source": [ + "#####################################\n", + "# Chapter 4\n", + "#####################################\n", + "\n", + "# class GELU(nn.Module):\n", + "# def __init__(self):\n", + "# super().__init__()\n", + "\n", + "# def forward(self, x):\n", + "# return 0.5 * x * (1 + torch.tanh(\n", + "# torch.sqrt(torch.tensor(2.0 / torch.pi)) *\n", + "# (x + 0.044715 * torch.pow(x, 3))\n", + "# ))\n", + "\n", + "\n", + "class SiLU(nn.Module):\n", + " def __init__(self):\n", + " super(SiLU, self).__init__()\n", + "\n", + " def forward(self, x):\n", + " return x * torch.sigmoid(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "72ecbe2e-b6b7-4319-972b-1a7fefa3368c", + "metadata": { + "id": "72ecbe2e-b6b7-4319-972b-1a7fefa3368c" + }, + "outputs": [], + "source": [ + "silu = SiLU()\n", + "\n", + "assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))" + ] + }, + { + "cell_type": "markdown", + "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9", + "metadata": { + "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9" + }, + "source": [ + "## Update the FeedForward module" + ] + }, + { + "cell_type": "markdown", + "id": "3a381e7a-b807-472e-91c9-3e4e3fc5ad91", + "metadata": { + "id": "3a381e7a-b807-472e-91c9-3e4e3fc5ad91" + }, + "source": [ + "- In fact, Llama uses a \"Gates Linear Unit\" (GLU) variant of SiLU called SwiGLU, which essentially results in a slightly differently structured `FeedForward` module\n", + "- SwiGLU uses a gating mechanism in the feedforward layer, with the formula:\n", + "\n", + "$$\\text{SwiGLU}(x) = (\\text{Linear}_1(x) * \\text{SiLU}(\\text{Linear}_2(x)))$$\n", + "\n", + "- Here, $\\text{Linear}_1$ and $\\text{Linear}_2$ are two linear layers, and $*$ denotes element-wise multiplication\n", + "- The third linear layer, $\\text{Linear}_3$, is applied after this gated activation\n", + "\n", + "- For more information, see SwiGLU paper: [GLU Variants Improve Transformer (2020)](https://arxiv.org/abs/2002.05202)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d25fbe3d-b7c9-4772-ad67-bc0527e4e20a", + "metadata": { + "id": "d25fbe3d-b7c9-4772-ad67-bc0527e4e20a" + }, + "outputs": [], + "source": [ + "#####################################\n", + "# Chapter 4\n", + "#####################################\n", + "# class FeedForward(nn.Module):\n", + "# def __init__(self, cfg):\n", + "# super().__init__()\n", + "# self.layers = nn.Sequential(\n", + "# nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n", + "# GELU(),\n", + "# nn.Linear(4 * 
cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n", + "# )\n", + "\n", + "# def forward(self, x):\n", + "# return self.layers(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "477568cb-03cd-4510-b663-a42ce3ec64a2", + "metadata": { + "id": "477568cb-03cd-4510-b663-a42ce3ec64a2" + }, + "outputs": [], + "source": [ + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.silu = SiLU()\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = self.silu(x_fc1) * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "markdown", + "id": "qcD8LSHNhBRW", + "metadata": { + "id": "qcD8LSHNhBRW" + }, + "source": [ + "- Note that we also added a `dtype=cfg[\"dtype\"]` setting above, which will allow us to load the model directly in lower precision formats later to save memory (versus instantiating it in the original 32-bit precision format and then converting it)\n", + "- We also set `bias=False` since Llama doesn't use any bias units" + ] + }, + { + "cell_type": "markdown", + "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949", + "metadata": { + "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949" + }, + "source": [ + "## Implement RoPE" + ] + }, + { + "cell_type": "markdown", + "id": "d3487a6f-0373-49d8-b2eb-f8ee05d42884", + "metadata": { + "id": "d3487a6f-0373-49d8-b2eb-f8ee05d42884" + }, + "source": [ + "- In the GPT model, the positional embeddings are implemented as follows:\n", + "\n", + "```python\n", + "self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n", + "```\n", + "\n", + "- Instead of these absolute positional embeddings, Llama uses relative positional embeddings, called rotary position embeddings (RoPE for short)\n", + "- The reference paper for RoPE is [RoFormer: Enhanced Transformer with Rotary Position Embedding (2021)](https://arxiv.org/abs/2104.09864)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a34180fb-448f-44e9-a244-0c736051687b", + "metadata": { + "id": "a34180fb-448f-44e9-a244-0c736051687b" + }, + "outputs": [], + "source": [ + "def precompute_rope_params(head_dim, context_length=4096):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", + "\n", + " # Compute the inverse frequencies\n", + " inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))\n", + "\n", + " # Generate position indices\n", + " positions = torch.arange(context_length)\n", + "\n", + " # Compute the angles using inverse frequencies and positions\n", + " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, emb_dim // 2)\n", + "\n", + " # Precompute sine and cosine of the angles\n", + " sin = torch.sin(angles) # Shape: (context_length, emb_dim // 2)\n", + " cos = torch.cos(angles) # Shape: (context_length, emb_dim // 2)\n", + "\n", + " return sin, cos\n", + "\n", + "\n", + "def compute_rope(x, sin, cos):\n", + " # x: (batch_size, num_heads, seq_len, head_dim)\n", + " batch_size, num_heads, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " # Split x into even and odd parts\n", + " x1 = x[..., ::2] # Shape: (batch_size, num_heads, seq_len, head_dim // 
2)\n", + " x2 = x[..., 1::2]\n", + "\n", + " # Ensure sin and cos have correct shapes\n", + " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim // 2)\n", + " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " # Apply the rotary transformation\n", + " x_rotated_0 = x1 * cos - x2 * sin\n", + " x_rotated_1 = x1 * sin + x2 * cos\n", + "\n", + " # Interleave x_rotated_0 and x_rotated_1\n", + " x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)\n", + " x_rotated = x_rotated.flatten(-2)\n", + "\n", + " return x_rotated.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "markdown", + "id": "8e841b8e-75aa-49db-b1a7-d5c2116dc299", + "metadata": { + "id": "8e841b8e-75aa-49db-b1a7-d5c2116dc299" + }, + "source": [ + "- The following is an example of applying RoPE to the `q` and `k` tensors:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8c89f022-7167-4001-8c21-8e012878733f", + "metadata": { + "id": "8c89f022-7167-4001-8c21-8e012878733f" + }, + "outputs": [], + "source": [ + "# Settings\n", + "batch_size = 2\n", + "context_len = 5\n", + "num_heads = 4\n", + "head_dim = 16\n", + "\n", + "# Instantiate RoPE parameters\n", + "sin, cos = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n", + "\n", + "# Dummy query and key tensors\n", + "torch.manual_seed(123)\n", + "queries = torch.randn(batch_size, context_len, num_heads, head_dim)\n", + "keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n", + "\n", + "# Apply rotary position embeddings\n", + "queries_rot = compute_rope(queries, sin, cos)\n", + "keys_rot = compute_rope(keys, sin, cos)" + ] + }, + { + "cell_type": "markdown", + "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297", + "metadata": { + "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297" + }, + "source": [ + "## Add RoPE to MultiHeadAttention module" + ] + }, + { + "cell_type": "markdown", + "id": "RnmSHROLhhR3", + "metadata": { + "id": "RnmSHROLhhR3" + }, + "source": [ + "- It's important to note that GPT applies the positional embeddings to the inputs, whereas Llama applies rotations to the query and key vectors in the self-attention mechanism itself\n", + "- Here, we modify the `MultiHeadAttention` class with the appropriate RoPE code\n", + "- In addition, we remove the `qkv_bias` option and hardcode the `bias=False` setting\n", + "- Also, we add a dtype setting to be able to instantiate the model with a lower precision later\n", + " - Tip: since the `TransformerBlock's (in the next section) are repeated exactly, we could simplify the code and only initialize the buffers once instead for each `MultiHeadAttention` module; however, we add the precomputed RoPE parameters to the `MultiHeadAttention` class so that it can function as a standalone module" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d81a441e-0b79-4a8b-8291-ea7f55d58c84", + "metadata": { + "id": "d81a441e-0b79-4a8b-8291-ea7f55d58c84" + }, + "outputs": [], + "source": [ + "#####################################\n", + "# Chapter 3\n", + "#####################################\n", + "class MultiHeadAttention(nn.Module):\n", + " def __init__(self, d_in, d_out, context_length, num_heads, dtype=None): # ,dropout, num_heads, qkv_bias=False):\n", + " super().__init__()\n", + " assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n", + "\n", + " self.d_out = d_out\n", + " self.num_heads = num_heads\n", + " self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n", + "\n", + 
" ################################### NEW ###################################\n", + " # Set bias=False and dtype=dtype for all linear layers below\n", + " ###########################################################################\n", + " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) # Linear layer to combine head outputs\n", + " # self.dropout = nn.Dropout(dropout)\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", + "\n", + " ################################### NEW ###################################\n", + " sin, cos = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n", + " self.register_buffer(\"sin\", sin)\n", + " self.register_buffer(\"cos\", cos)\n", + " ###########################################################################\n", + "\n", + "\n", + " def forward(self, x):\n", + "\n", + " b, num_tokens, d_in = x.shape\n", + "\n", + " keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n", + " queries = self.W_query(x)\n", + " values = self.W_value(x)\n", + "\n", + " # We implicitly split the matrix by adding a `num_heads` dimension\n", + " # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n", + " keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n", + " values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n", + " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", + "\n", + " # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n", + " keys = keys.transpose(1, 2)\n", + " queries = queries.transpose(1, 2)\n", + " values = values.transpose(1, 2)\n", + "\n", + " ################################### NEW ###################################\n", + " keys = compute_rope(keys, self.sin, self.cos)\n", + " queries = compute_rope(queries, self.sin, self.cos)\n", + " ###########################################################################\n", + "\n", + " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", + " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", + "\n", + " # Original mask truncated to the number of tokens and converted to boolean\n", + " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", + "\n", + " # Use the mask to fill attention scores\n", + " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", + "\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", + " # attn_weights = self.dropout(attn_weights)\n", + "\n", + " # Shape: (b, num_tokens, num_heads, head_dim)\n", + " context_vec = (attn_weights @ values).transpose(1, 2)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n", + " context_vec = self.out_proj(context_vec) # optional projection\n", + "\n", + " return context_vec" + ] + }, + { + "cell_type": "markdown", + "id": "-lt9SfnVioB3", + "metadata": { + "id": "-lt9SfnVioB3" + }, + "source": [ + "- Below is an example using the `MultiHeadAttention` module on an example input:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "03f15755-0083-483f-963b-99b599651638", + "metadata": { + "id": "03f15755-0083-483f-963b-99b599651638" + }, + 
"outputs": [], + "source": [ + "# Settings\n", + "batch_size = 1\n", + "context_len = 100\n", + "max_context_len = 4096\n", + "embed_dim = 128\n", + "num_heads = 4\n", + "\n", + "\n", + "example_batch = torch.randn((batch_size, context_len, embed_dim))\n", + "\n", + "mha = MultiHeadAttention(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " context_length=max_context_len,\n", + " num_heads=num_heads\n", + ")\n", + "\n", + "mha(example_batch)\n", + "\n", + "del mha # delete to safe memory" + ] + }, + { + "cell_type": "markdown", + "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f", + "metadata": { + "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f" + }, + "source": [ + "## Update the TransformerBlock module" + ] + }, + { + "cell_type": "markdown", + "id": "255f70ac-9c2e-4328-8af7-1c298b8d4a18", + "metadata": { + "id": "255f70ac-9c2e-4328-8af7-1c298b8d4a18" + }, + "source": [ + "- At this stage, most of the hard work is already done; we can now update the `TransformerBlock` to use the code we implemented above\n", + "- This means we\n", + " - replace LayerNorm with RMSNorm\n", + " - remove dropout\n", + " - remove the `qkv_bias` setting\n", + " - add the `dtype` setting" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "2e110721-bf2b-42b3-989a-1635b1658af0", + "metadata": { + "id": "2e110721-bf2b-42b3-989a-1635b1658af0" + }, + "outputs": [], + "source": [ + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.att = MultiHeadAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " d_out=cfg[\"emb_dim\"],\n", + " context_length=cfg[\"context_length\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " dtype=cfg[\"dtype\"] # NEW\n", + " # dropout=cfg[\"drop_rate\"],\n", + " # qkv_bias=cfg[\"qkv_bias\"]\n", + " )\n", + " self.ff = FeedForward(cfg)\n", + "\n", + " ################################### NEW ###################################\n", + " # self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n", + " # self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n", + " self.norm1 = RMSNorm(cfg[\"emb_dim\"])\n", + " self.norm2 = RMSNorm(cfg[\"emb_dim\"])\n", + " ###########################################################################\n", + "\n", + " # self.drop_shortcut = nn.Dropout(cfg[\"drop_rate\"])\n", + "\n", + " def forward(self, x):\n", + " # Shortcut connection for attention block\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + " x = self.att(x) # Shape [batch_size, num_tokens, emb_size]\n", + " # x = self.drop_shortcut(x)\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " # Shortcut connection for feed-forward block\n", + " shortcut = x\n", + " x = self.norm2(x)\n", + " x = self.ff(x)\n", + " # x = self.drop_shortcut(x)\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f", + "metadata": { + "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f" + }, + "source": [ + "## Update the model class" + ] + }, + { + "cell_type": "raw", + "id": "aa79780d-74a8-4ee0-934a-9ad63205a02e", + "metadata": { + "id": "aa79780d-74a8-4ee0-934a-9ad63205a02e" + }, + "source": [ + "- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n", + "- Our Llama model is almost complete; we just have to update the model code surrounding the `TransformerBlock`\n", + "- This means we\n", + " - remove absolute positional embeddings since we have RoPE embeddings now\n", + " - 
replace LayerNorm with RMSNorm\n", + " - remove dropout\n", + " - add the dtype setting" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "cf8240fe-5d7f-4e7e-b1ac-e0755aab5e79", + "metadata": { + "id": "cf8240fe-5d7f-4e7e-b1ac-e0755aab5e79" + }, + "outputs": [], + "source": [ + "# class GPTModel(nn.Module):\n", + "class Llama2Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + " # self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n", + " # self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n", + "\n", + " self.trf_blocks = nn.Sequential(\n", + " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", + "\n", + " ################################### NEW ###################################\n", + " # self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n", + " ###########################################################################\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + "\n", + " def forward(self, in_idx):\n", + " batch_size, seq_len = in_idx.shape\n", + " tok_embeds = self.tok_emb(in_idx)\n", + " # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n", + " x = tok_embeds # + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n", + " # x = self.drop_emb(x)\n", + " x = self.trf_blocks(x)\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x)\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60", + "metadata": { + "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60" + }, + "source": [ + "## Initialize model" + ] + }, + { + "cell_type": "markdown", + "id": "bG--zY-Ljj1f", + "metadata": { + "id": "bG--zY-Ljj1f" + }, + "source": [ + "- The model code is now complete, and we are ready to initialize it\n", + "- In [chapter 5](../01_main-chapter-code/ch05.ipynb), we used the following config file to specify the 124M-parameter GPT model:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4b7428df-3d02-4ccd-97b5-a629bdabbe8f", + "metadata": { + "id": "4b7428df-3d02-4ccd-97b5-a629bdabbe8f" + }, + "outputs": [], + "source": [ + "GPT_CONFIG_124M = {\n", + " \"vocab_size\": 50257, # Vocabulary size\n", + " \"context_length\": 1024, # Context length\n", + " \"emb_dim\": 768, # Embedding dimension\n", + " \"n_heads\": 12, # Number of attention heads\n", + " \"n_layers\": 12, # Number of layers\n", + " \"drop_rate\": 0.1, # Dropout rate\n", + " \"qkv_bias\": False # Query-Key-Value bias\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "8bVi8uiBjw2T", + "metadata": { + "id": "8bVi8uiBjw2T" + }, + "source": [ + "- For reference, the 1.5B parameter GPT model config is shown below as well:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "tAOojV_mkEnd", + "metadata": { + "id": "tAOojV_mkEnd" + }, + "outputs": [], + "source": [ + "GPT_CONFIG_1558M = {\n", + " \"vocab_size\": 50257, # Vocabulary size\n", + " \"context_length\": 1024, # Context length\n", + " \"emb_dim\": 1600, # Embedding dimension\n", + " \"n_heads\": 25, # Number of attention heads\n", + " \"n_layers\": 48, # Number of layers\n", + " \"drop_rate\": 0.1, # Dropout rate\n", + " \"qkv_bias\": False # Query-Key-Value bias\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "HoGGRAGykQTE", + "metadata": { + "id": "HoGGRAGykQTE" 
+ }, + "source": [ + "- Similarly, we can define a Llama 2 config file for the 7B model (we ignore the other larger models for simplicity here):" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18", + "metadata": { + "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18" + }, + "outputs": [], + "source": [ + "LLAMA2_CONFIG_7B = {\n", + " \"vocab_size\": 32000, # Vocabulary size\n", + " \"context_length\": 4096, # Context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 11008, # NEW: Size of the intermediate dimension in FeedForward\n", + " \"dtype\": torch.bfloat16 # NEW: Lower-precision dtype to save memory\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "FAP7fiBzkaBz", + "metadata": { + "id": "FAP7fiBzkaBz" + }, + "source": [ + "- Using these settings, we can now initialize a Llama 2 7B model (note that this requires ~26 GB of memory)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7004d785-ac9a-4df5-8760-6807fc604686", + "metadata": { + "id": "7004d785-ac9a-4df5-8760-6807fc604686" + }, + "outputs": [], + "source": [ + "model = Llama2Model(LLAMA2_CONFIG_7B)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "6079f747-8f20-4c6b-8d38-7156f1101729", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6079f747-8f20-4c6b-8d38-7156f1101729", + "outputId": "78ab929e-ac78-4b16-ddb1-704d45ee69a8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters: 6,738,415,616\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total number of parameters: {total_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "Bx14NtzWk2wj", + "metadata": { + "id": "Bx14NtzWk2wj" + }, + "source": [ + "- As shown above, the model contains 6.7 billion parameters (commonly rounded and referred to as a 7B model)\n", + "- Additionally, we can calculate the memory requirements for this model using the code below:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", + "outputId": "c0cbdcc8-dc46-44f7-a800-fbe888a3f9e9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float32 (PyTorch default): 52.27 GB\n", + "bfloat16: 26.13 GB\n" + ] + } + ], + "source": [ + "def model_memory_size(model, input_dtype=torch.float32):\n", + " total_params = 0\n", + " total_grads = 0\n", + " for param in model.parameters():\n", + " # Calculate total number of elements per parameter\n", + " param_size = param.numel()\n", + " total_params += param_size\n", + " # Check if gradients are stored for this parameter\n", + " if param.requires_grad:\n", + " total_grads += param_size\n", + "\n", + " # Calculate buffer size (non-parameters that require memory)\n", + " total_buffers = sum(buf.numel() for buf in model.buffers())\n", + "\n", + " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", + " # We assume parameters and gradients are stored in the same type as input dtype\n", + " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", + " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", + "\n", + " 
# Convert bytes to gigabytes\n", + " total_memory_gb = total_memory_bytes / (1024**3)\n", + "\n", + " return total_memory_gb\n", + "\n", + "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", + "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" + ] + }, + { + "cell_type": "markdown", + "id": "zudd-5PulKFL", + "metadata": { + "id": "zudd-5PulKFL" + }, + "source": [ + "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d", + "metadata": { + "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34", + "metadata": { + "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34" + }, + "source": [ + "## Load tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005", + "metadata": { + "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005" + }, + "source": [ + "- In this section, we are going to load the tokenizer for the model\n", + "- Llama 2 uses Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's [Tiktoken](https://github.com/openai/tiktoken) (but Llama 3 uses Tiktoken)\n", + "- Meta AI shared the original Llama 2 model weights and tokenizer vocabulary on the Hugging Face Hub\n", + "- We will download the tokenizer vocabulary from the Hub and load it into SentencePiece\n", + "- Uncomment and run the following code to install the required libraries:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "768989ea-dc60-4dc8-ae84-cbb3fd224422", + "metadata": { + "id": "768989ea-dc60-4dc8-ae84-cbb3fd224422" + }, + "outputs": [], + "source": [ + "# !pip install huggingface_hub sentencepiece" + ] + }, + { + "cell_type": "markdown", + "id": "KbnlzsbYmJU6", + "metadata": { + "id": "KbnlzsbYmJU6" + }, + "source": [ + "- Please note that Meta AI requires that you accept the Llama 2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) repository to accept the terms\n", + "- Next, you will need to create an access token; to generate an access token, click on the profile picture in the upper right and click on \"Settings\"\n", + "\n", + "\n", + "\n", + "\n", + "- Then, create and copy the access token so you can copy & paste it into the next code cell\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3357a230-b678-4691-a238-257ee4e80185", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3357a230-b678-4691-a238-257ee4e80185", + "outputId": "d326d32c-fa8d-4f2b-84d5-a1b8f35dd387" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The token has not been saved to the git credentials helper. 
Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
+      "Token is valid (permission: read).\n",
+      "Your token has been saved to /root/.cache/huggingface/token\n",
+      "Login successful\n"
+     ]
+    }
+   ],
+   "source": [
+    "from huggingface_hub import login\n",
+    "\n",
+    "login(token=\"hf_...\")  # Insert your token here"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "IxGh6ZYQo0VN",
+   "metadata": {
+    "id": "IxGh6ZYQo0VN"
+   },
+   "source": [
+    "- After logging in with the access token, which is necessary to verify that we accepted the Llama 2 licensing terms, we can download the tokenizer vocabulary:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
+    "outputId": "82bc5037-c86c-46c2-b374-269f9d09599a"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
+      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+      "You will be able to reuse this secret in all of your notebooks.\n",
+      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "from huggingface_hub import hf_hub_download\n",
+    "\n",
+    "tokenizer_file = hf_hub_download(\n",
+    "    repo_id=\"meta-llama/Llama-2-7b\",\n",
+    "    filename=\"tokenizer.model\",\n",
+    "    cache_dir=\".\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "gp7iQ8cXAJLv",
+   "metadata": {
+    "id": "gp7iQ8cXAJLv"
+   },
+   "source": [
+    "- To provide a more familiar interface for the tokenizer, we define a small `LlamaTokenizer` wrapper class:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "Ef4WxhjOBOOc",
+   "metadata": {
+    "id": "Ef4WxhjOBOOc"
+   },
+   "outputs": [],
+   "source": [
+    "import sentencepiece as spm\n",
+    "\n",
+    "\n",
+    "class LlamaTokenizer:\n",
+    "    def __init__(self, filepath):\n",
+    "        sp = spm.SentencePieceProcessor()\n",
+    "        sp.load(filepath)\n",
+    "        self.tokenizer = sp\n",
+    "\n",
+    "    def encode(self, text):\n",
+    "        return self.tokenizer.encode_as_ids(text)\n",
+    "\n",
+    "    def decode(self, ids):\n",
+    "        return self.tokenizer.decode_pieces(ids)\n",
+    "\n",
+    "\n",
+    "tokenizer = LlamaTokenizer(tokenizer_file)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "NVhmFeX3pT_M",
+   "metadata": {
+    "id": "NVhmFeX3pT_M"
+   },
+   "source": [
+    "- We can now use the `generate` function to have the Llama 2 model generate new text:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
+    "outputId": "d733bc0a-5136-4c33-d70d-36056f1e8329"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Output text:\n",
+      " Every effort movesαllRadius deletingpretccappedRadius zas Parte Material Ку términчной herousztusllRadiusotto кра liberotto siguientesagnost#{ (@topicquez restored log\n"
+     ]
+    }
+   ],
+   "source": [
+    "from previous_chapters 
import generate, text_to_token_ids, token_ids_to_text\n", + "\n", + "\n", + "torch.manual_seed(123)\n", + "\n", + "token_ids = generate(\n", + " model=model,\n", + " idx=text_to_token_ids(\"Every effort moves\", tokenizer).to(device),\n", + " max_new_tokens=30,\n", + " context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n", + " top_k=1,\n", + " temperature=1.0\n", + ")\n", + "\n", + "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" + ] + }, + { + "cell_type": "markdown", + "id": "93WTtAA5paYV", + "metadata": { + "id": "93WTtAA5paYV" + }, + "source": [ + "- Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 2 model yet\n", + "- In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI" + ] + }, + { + "cell_type": "markdown", + "id": "f63cc248-1d27-4eb6-aa50-173b436652f8", + "metadata": { + "id": "f63cc248-1d27-4eb6-aa50-173b436652f8" + }, + "source": [ + "## Load pretrained weights" + ] + }, + { + "cell_type": "markdown", + "id": "aKeN7rUfqZMI", + "metadata": { + "id": "aKeN7rUfqZMI" + }, + "source": [ + "- We are loading the [\"meta-llama/Llama-2-7b\"](https://huggingface.co/meta-llama/Llama-2-7b) base model below, which is a simple text completion model before finetuning\n", + "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Llama-2-7b-chat\"](https://huggingface.co/meta-llama/Llama-2-7b-chat) model by modifying the string in the next code cell accordingly" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", + "metadata": { + "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4" + }, + "outputs": [], + "source": [ + "weights_file = hf_hub_download(\n", + " repo_id=\"meta-llama/Llama-2-7b\",\n", + " filename=\"consolidated.00.pth\",\n", + " cache_dir=\".\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "e67cca5c-ba4b-4be5-85c7-fdceae8a5701", + "metadata": { + "id": "e67cca5c-ba4b-4be5-85c7-fdceae8a5701" + }, + "outputs": [], + "source": [ + "weights = torch.load(weights_file, weights_only=True)" + ] + }, + { + "cell_type": "markdown", + "id": "-15SJ7btq2zE", + "metadata": { + "id": "-15SJ7btq2zE" + }, + "source": [ + "- The `weights` contains the following tensors (only the first 15 are shown for simplicity):" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49", + "outputId": "01721809-ace1-4a7a-ab54-8fad2e8f54a6" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['tok_embeddings.weight',\n", + " 'norm.weight',\n", + " 'output.weight',\n", + " 'layers.0.attention.wq.weight',\n", + " 'layers.0.attention.wk.weight',\n", + " 'layers.0.attention.wv.weight',\n", + " 'layers.0.attention.wo.weight',\n", + " 'layers.0.feed_forward.w1.weight',\n", + " 'layers.0.feed_forward.w2.weight',\n", + " 'layers.0.feed_forward.w3.weight',\n", + " 'layers.0.attention_norm.weight',\n", + " 'layers.0.ffn_norm.weight',\n", + " 'layers.1.attention.wq.weight',\n", + " 'layers.1.attention.wk.weight',\n", + " 'layers.1.attention.wv.weight']" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(weights.keys())[:15]" + ] + }, + { + "cell_type": "markdown", + "id": "UeeSpnunrDFB", + "metadata": { + "id": 
"UeeSpnunrDFB" + }, + "source": [ + "- The following function, modeled after the `load_weights_into_gpt` function in [chapter 5](../01_main-chapter-code/ch05.ipynb), loads the pretrained weights into our Llama 2 model:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65", + "metadata": { + "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65" + }, + "outputs": [], + "source": [ + "def assign(left, right):\n", + " if left.shape != right.shape:\n", + " raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n", + "\n", + " if isinstance(right, torch.Tensor):\n", + " return torch.nn.Parameter(right.clone().detach())\n", + " else:\n", + " return torch.nn.Parameter(torch.tensor(right))\n", + "\n", + "\n", + "def load_weights_into_llama(model, param_config, params):\n", + " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n", + "\n", + " for l in range(param_config[\"n_layers\"]):\n", + "\n", + " # Load attention weights\n", + " model.trf_blocks[l].att.W_query.weight = assign(\n", + " model.trf_blocks[l].att.W_query.weight,\n", + " params[f\"layers.{l}.attention.wq.weight\"]\n", + " )\n", + " model.trf_blocks[l].att.W_key.weight = assign(\n", + " model.trf_blocks[l].att.W_key.weight,\n", + " params[f\"layers.{l}.attention.wk.weight\"]\n", + " )\n", + " model.trf_blocks[l].att.W_value.weight = assign(\n", + " model.trf_blocks[l].att.W_value.weight,\n", + " params[f\"layers.{l}.attention.wv.weight\"]\n", + " )\n", + " model.trf_blocks[l].att.out_proj.weight = assign(\n", + " model.trf_blocks[l].att.out_proj.weight,\n", + " params[f\"layers.{l}.attention.wo.weight\"]\n", + " )\n", + " model.trf_blocks[l].norm1.weight = assign(\n", + " model.trf_blocks[l].norm1.weight,\n", + " params[f\"layers.{l}.attention_norm.weight\"]\n", + " )\n", + "\n", + " # Load FeedForward weights\n", + " model.trf_blocks[l].ff.fc1.weight = assign(\n", + " model.trf_blocks[l].ff.fc1.weight,\n", + " params[f\"layers.{l}.feed_forward.w1.weight\"]\n", + " )\n", + " # For some reason w2 and w3 are provided in the wrong order in the weights file\n", + " model.trf_blocks[l].ff.fc2.weight = assign(\n", + " model.trf_blocks[l].ff.fc2.weight,\n", + " params[f\"layers.{l}.feed_forward.w3.weight\"]\n", + " )\n", + " model.trf_blocks[l].ff.fc3.weight = assign(\n", + " model.trf_blocks[l].ff.fc3.weight,\n", + " params[f\"layers.{l}.feed_forward.w2.weight\"]\n", + " )\n", + " model.trf_blocks[l].norm2.weight = assign(\n", + " model.trf_blocks[l].norm2.weight,\n", + " params[f\"layers.{l}.ffn_norm.weight\"]\n", + " )\n", + "\n", + " # Load output layer weights\n", + " model.final_norm.weight = assign(model.final_norm.weight, params[\"norm.weight\"])\n", + " model.out_head.weight = assign(model.out_head.weight, params[\"output.weight\"])\n", + "\n", + "\n", + "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "TDuv_Us2rNvk", + "metadata": { + "id": "TDuv_Us2rNvk" + }, + "source": [ + "- Next, we are ready to use the model for text generation" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "240987e8-a023-462e-9376-9edfb27559ec", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "240987e8-a023-462e-9376-9edfb27559ec", + "outputId": "59830005-42af-406b-c836-38a8f2d7b961" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output text:\n", + " Every effort has been 
made to ensure that the information contained in this website is accurate and up to date. However, the information is provided without any warranty\n"
+     ]
+    }
+   ],
+   "source": [
+    "torch.manual_seed(123)\n",
+    "\n",
+    "token_ids = generate(\n",
+    "    model=model,\n",
+    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
+    "    max_new_tokens=30,\n",
+    "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
+    "    top_k=1,\n",
+    "    temperature=0.\n",
+    ")\n",
+    "\n",
+    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "akyo7WNyF_YL",
+   "metadata": {
+    "id": "akyo7WNyF_YL"
+   },
+   "source": [
+    "- Tip: as mentioned earlier, this is the pretrained base model; if you want to use a model capable of following instructions, use the `\"meta-llama/Llama-2-7b-chat\"` model instead"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "id": "nbvAV7vaz6yc",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "nbvAV7vaz6yc",
+    "outputId": "faa930dc-0db2-4095-b395-f97baef08903"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Output text:\n",
+      " What do llamas eat?\n",
+      "Llamas are herbivores, which means they eat plants. They eat grass, leaves, and hay.\n"
+     ]
+    }
+   ],
+   "source": [
+    "del model  # to free up memory\n",
+    "\n",
+    "weights_file = hf_hub_download(\n",
+    "    repo_id=\"meta-llama/Llama-2-7b-chat\",\n",
+    "    filename=\"consolidated.00.pth\",\n",
+    "    cache_dir=\".\"\n",
+    ")\n",
+    "\n",
+    "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
+    "load_weights_into_llama(model, LLAMA2_CONFIG_7B, torch.load(weights_file, weights_only=True))\n",
+    "model.to(device);\n",
+    "\n",
+    "torch.manual_seed(123)\n",
+    "\n",
+    "token_ids = generate(\n",
+    "    model=model,\n",
+    "    idx=text_to_token_ids(\"What do llamas eat?\", tokenizer).to(device),\n",
+    "    max_new_tokens=25,\n",
+    "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
+    "    top_k=1,\n",
+    "    temperature=0.\n",
+    ")\n",
+    "\n",
+    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "A100",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ch05/07_gpt_to_llama/previous_chapters.py b/ch05/07_gpt_to_llama/previous_chapters.py
new file mode 100644
index 00000000..1ca678c7
--- /dev/null
+++ b/ch05/07_gpt_to_llama/previous_chapters.py
@@ -0,0 +1,63 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+#
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 2-5.
+# This file can be run as a standalone script.
+ +import torch + + +##################################### +# Chapter 5 +##################################### +def text_to_token_ids(text, tokenizer): + encoded = tokenizer.encode(text) + encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension + return encoded_tensor + + +def token_ids_to_text(token_ids, tokenizer): + flat = token_ids.squeeze(0) # remove batch dimension + return tokenizer.decode(flat.tolist()) + + +def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None): + + # For-loop is the same as before: Get logits, and only focus on last time step + for _ in range(max_new_tokens): + idx_cond = idx[:, -context_size:] + with torch.no_grad(): + logits = model(idx_cond) + logits = logits[:, -1, :] + + # New: Filter logits with top_k sampling + if top_k is not None: + # Keep only top_k values + top_logits, _ = torch.topk(logits, top_k) + min_val = top_logits[:, -1] + logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits) + + # New: Apply temperature scaling + if temperature > 0.0: + logits = logits / temperature + + # Apply softmax to get probabilities + probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) + + # Sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) + + # Otherwise same as before: get idx of the vocab entry with the highest logits value + else: + idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1) + + if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified + break + + # Same as before: append sampled index to the running sequence + idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) + + return idx diff --git a/ch05/07_gpt_to_llama/requirements-extra.txt b/ch05/07_gpt_to_llama/requirements-extra.txt new file mode 100644 index 00000000..8646692b --- /dev/null +++ b/ch05/07_gpt_to_llama/requirements-extra.txt @@ -0,0 +1,2 @@ +huggingface_hub>=0.24.7 +sentencepiece>=0.1.99 \ No newline at end of file diff --git a/ch05/README.md b/ch05/README.md index 428cd326..3a725194 100644 --- a/ch05/README.md +++ b/ch05/README.md @@ -11,3 +11,4 @@ - [04_learning_rate_schedulers](04_learning_rate_schedulers) contains code implementing a more sophisticated training function including learning rate schedulers and gradient clipping - [05_bonus_hparam_tuning](05_bonus_hparam_tuning) contains an optional hyperparameter tuning script - [06_user_interface](06_user_interface) implements an interactive user interface to interact with the pretrained LLM +- [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama and loads pretrained weights from Meta AI