Skip to content

Commit 40c08dc

Browse files
added query-key norm to accomodate OLMo2 (#1894)
Co-authored-by: Andrei-Aksionov <aksionau.andrei@gmail.com> Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>
1 parent 17a58df commit 40c08dc

File tree

3 files changed

+14
-0
lines changed

3 files changed

+14
-0
lines changed

litgpt/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Config:
2525
# Transformer block (structure, normalizations)
2626
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
2727
norm_eps: float = 1e-5
28+
norm_qk: bool = False
2829
post_attention_norm: bool = False
2930
post_mlp_norm: bool = False
3031
parallel_residual: bool = True

litgpt/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,13 @@ def __init__(self, config: Config, block_idx: int) -> None:
341341
config.sliding_window_size is not None and
342342
block_idx % config.sliding_window_layer_stride == 0
343343
)
344+
345+
if config.norm_qk:
346+
self.norm_q = config.norm_class(config.head_size * config.n_head, eps=config.norm_eps)
347+
self.norm_k = config.norm_class(config.head_size * config.n_query_groups, eps=config.norm_eps)
348+
else:
349+
self.norm_q = self.norm_k = None
350+
344351
self.config = config
345352
self.block_idx = block_idx
346353

@@ -377,6 +384,10 @@ def forward(
377384
# Split qkv into query, key and value matrices.
378385
q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*)
379386

387+
if self.config.norm_qk:
388+
q = self.norm_q(q)
389+
k = self.norm_k(k)
390+
380391
# To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
381392
# embedding size (C) into num_heads (nh) and head_size (hs).
382393
q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs)

tests/test_readme.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def run_command(command):
3232

3333

3434
@pytest.mark.dependency()
35+
@pytest.mark.flaky(reruns=5, reruns_delay=2)
3536
def test_download_model():
3637
repo_id = str(REPO_ID).replace("\\", "/") # fix for Windows CI
3738
command = ["litgpt", "download", str(repo_id)]
@@ -48,6 +49,7 @@ def test_download_model():
4849

4950

5051
@pytest.mark.dependency()
52+
@pytest.mark.flaky(reruns=5, reruns_delay=2)
5153
def test_download_books():
5254
CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True)
5355

0 commit comments

Comments
 (0)