Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion examples/models/llama/static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
self.use_qk_norm = config.use_qk_norm
self.use_conv2d = False

assert not self.use_qk_norm, "QK norm not supported in static attention yet"
self.wqs = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
Expand Down Expand Up @@ -241,6 +240,13 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
self.rope = _Rope(rope.params.use_hf_rope)

if self.use_qk_norm:
self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
else:
self.q_norm = torch.nn.Identity()
self.k_norm = torch.nn.Identity()

def forward(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -275,6 +281,10 @@ def from_conv2ds(ts):
new_ks = from_conv2ds(new_ks)
new_vs = from_conv2ds(new_vs)

if self.use_qk_norm:
new_qs = [self.q_norm(q) for q in new_qs]
new_ks = [self.k_norm(k) for k in new_ks]

new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
all_ks = []
Expand Down Expand Up @@ -325,6 +335,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):

self.wo.weight.data.copy_(other.wo.weight)

if other.use_qk_norm:
self.use_qk_norm = True
self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
self.k_norm.load_state_dict(other.k_norm_fn.state_dict())

def linear_to_conv2d(self):
def transfer_weight(linear, conv2d):
conv2d.weight.data.copy_(linear.weight[:, :, None, None])
Expand Down
9 changes: 6 additions & 3 deletions examples/models/llama/tests/test_static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ def setUp(self):
torch.manual_seed(42)

def test_without_cache(self):
def test(use_conv2d):
def test(use_qk_norm, use_conv2d):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
use_qk_norm=use_qk_norm,
)
layer_id = 0
rope = Rope(config)
Expand All @@ -47,8 +48,10 @@ def test(use_conv2d):
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

test(True)
test(False)
test(True, True)
test(True, False)
test(False, True)
test(False, False)

def test_hf_rope_without_cache(self):
config = ModelArgs(
Expand Down
Loading