From a5405c255d9abaa2c2b162e9b65cf8912039bdb0 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 3 Oct 2024 20:45:40 -0500 Subject: [PATCH] Cos-sin fix in Llama 2 bonus notebook (#381) --- .../converting-gpt-to-llama2.ipynb | 1164 ++++++++++++++++- 1 file changed, 1142 insertions(+), 22 deletions(-) diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb index 4454af0e..731ab98a 100644 --- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -76,7 +76,7 @@ "base_uri": "https://localhost:8080/" }, "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", - "outputId": "7ce8fe41-1c24-4f0b-a8d9-352b4af1b46b" + "outputId": "8118963b-3c72-43af-874b-439ffebdc94c" }, "outputs": [ { @@ -578,8 +578,8 @@ " values = values.transpose(1, 2)\n", "\n", " ################################### NEW ###################################\n", - " keys = compute_rope(keys, self.sin, self.cos)\n", - " queries = compute_rope(queries, self.sin, self.cos)\n", + " keys = compute_rope(keys, self.cos, self.sin)\n", + " queries = compute_rope(queries, self.cos, self.sin)\n", " ###########################################################################\n", "\n", " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", @@ -916,7 +916,7 @@ "base_uri": "https://localhost:8080/" }, "id": "6079f747-8f20-4c6b-8d38-7156f1101729", - "outputId": "1ca50091-a20c-4a44-b806-9985a5e64135" + "outputId": "0a0eb34b-1a21-4c11-804f-b40007bda5a3" }, "outputs": [ { @@ -952,7 +952,7 @@ "base_uri": "https://localhost:8080/" }, "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", - "outputId": "b157b5ac-d37c-4b71-f609-45a91f7ed93a" + "outputId": "11ced939-556d-4511-d5c0-10a94ed3df32" }, "outputs": [ { @@ -1085,7 +1085,7 @@ "base_uri": "https://localhost:8080/" }, "id": "3357a230-b678-4691-a238-257ee4e80185", - "outputId": "7d4adc4b-53cf-4099-a45f-2fb4fd25edc4" + "outputId": "768ed6af-ce14-40bc-ca18-117b4b448269" }, "outputs": [ { @@ -1126,10 +1126,24 @@ "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", "metadata": { "colab": { - "base_uri": "https://localhost:8080/" + "base_uri": "https://localhost:8080/", + "height": 153, + "referenced_widgets": [ + "e6c75a6aa7b942fe84160e286e3acb3d", + "08f0bf9459bd425498a5cb236f9d4a72", + "10251d6f724e43788c41d4b7879cbfd3", + "53a973c0853b44418698136bd04df039", + "bdb071e7145a4007ae01599333e72612", + "6b1821a7f4574e3aba09c1e410cc81e4", + "8c2873eaec3445888ad3d54ad7387950", + "0c8f7044966e4207b12352503c67dcbb", + "8b5951213c9e4798a258146d61d02d11", + "2c05df3f91e64df7b33905b1065a76f7", + "742ae5487f2648fcae7ca8e22c7f8db9" + ] }, "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", - "outputId": "aa18fccc-6533-4446-f57b-546068ad518c" + "outputId": "c230fec9-5c71-4a41-90ab-8a34d114ea01" }, "outputs": [ { @@ -1143,6 +1157,20 @@ "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6c75a6aa7b942fe84160e286e3acb3d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer.model: 0%| | 0.00/500k [00:00