minor fixes: Llama 3.2 standalone (rasbt#420)
* minor fixes

* reformat rope base as float

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
d-kleine and rasbt authored Oct 26, 2024
1 parent 1516de5 commit e8c2f96
Showing 2 changed files with 10 additions and 10 deletions.
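
The central change replaces the integer `rope_base` (`500_000`) with an explicit float (`500_000.0`). That base is the "theta" value that enters the RoPE inverse-frequency computation, so writing it as a float makes the intended numeric type unambiguous. Below is a minimal sketch of that computation, assuming the standard RoPE formula; `compute_inv_freq` and `head_dim` are illustrative names, not identifiers taken from the notebooks.

import torch

def compute_inv_freq(head_dim: int, rope_base: float = 500_000.0) -> torch.Tensor:
    # inv_freq_i = 1 / rope_base**(2i / head_dim) for i = 0 .. head_dim/2 - 1
    return 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Example: a head dimension of 64 yields 32 rotation frequencies
inv_freq = compute_inv_freq(head_dim=64)
print(inv_freq.shape)  # torch.Size([32])

A larger base (500_000 in Llama 3 versus 10_000 in Llama 2) yields lower rotation frequencies, which is what lets the model handle longer contexts.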
10 changes: 5 additions & 5 deletions ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
@@ -907,7 +907,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # NEW: Larger size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # NEW: Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
" \"rope_base\": 500_000.0, # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
" \"rope_freq\": None, # NEW: Additional configuration for adjusting the RoPE frequencies\n",
" \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
"}"
@@ -2060,7 +2060,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"rope_freq\": None, # Additional configuration for adjusting the RoPE frequencies\n",
" \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n",
"}\n",
@@ -2073,7 +2073,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
@@ -2447,7 +2447,7 @@
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
@@ -2466,7 +2466,7 @@
" \"n_layers\": 16, # NEW: Half the number of layers\n",
" \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
10 changes: 5 additions & 5 deletions ch05/07_gpt_to_llama/standalone-llama32.ipynb
@@ -437,7 +437,7 @@
" \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0,\n",
@@ -451,13 +451,13 @@
"\n",
"# LLAMA32_CONFIG = {\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 131_000, # Context length\n",
"# \"context_length\": 131_072, # Context length\n",
"# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n",
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
"# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"factor\": 32.0,\n",
@@ -697,7 +697,6 @@
" def __init__(self, model_path):\n",
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
" num_base_tokens = len(mergeable_ranks)\n",
"\n",
" self.special_tokens = {\n",
" \"<|begin_of_text|>\": 128000,\n",
@@ -1013,7 +1012,8 @@
"\n",
"\n",
"load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n",
"model.to(device);"
"model.to(device)\n",
"del combined_weights # free up memory"
]
},
{
