Skip to content

Linear to Conv2d transform for static attention #9025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion examples/models/llama/static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
self.attention_qkv_bias = config.attention_qkv_bias
self.use_qk_norm = config.use_qk_norm
self.use_conv2d = False

assert not self.use_qk_norm, "QK norm not supported in static attention yet"
self.wqs = nn.ModuleList(
Expand Down Expand Up @@ -255,9 +256,25 @@ def forward(
in_cache_state = kwargs.get("in_cache_state")
out_cache_state = kwargs.get("out_cache_state")

bsz, seq_len, dim = x.shape
if self.use_conv2d:
x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)

new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]

if self.use_conv2d:

def from_conv2ds(ts):
return [
t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
]

new_qs = from_conv2ds(new_qs)
new_ks = from_conv2ds(new_ks)
new_vs = from_conv2ds(new_vs)

new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
all_ks = []
Expand All @@ -282,7 +299,14 @@ def forward(
heads.append(attn @ all_vs[kv_idx])

y = torch.cat(heads, dim=-1)
y = self.wo(y)
if self.use_conv2d:
y = (
self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
.transpose(1, 3)
.reshape(bsz, seq_len, -1)
)
else:
y = self.wo(y)
return y, {"out_cache_state": out_cache_state}

def load_weights_from_attention_mha(self, other: AttentionMHA):
Expand All @@ -300,3 +324,44 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
)

self.wo.weight.data.copy_(other.wo.weight)

def linear_to_conv2d(self):
def transfer_weight(linear, conv2d):
conv2d.weight.data.copy_(linear.weight[:, :, None, None])
return conv2d

self.wqs = nn.ModuleList(
[
transfer_weight(
linear,
nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
)
for linear in self.wqs
]
)
self.wks = nn.ModuleList(
[
transfer_weight(
linear,
nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
)
for linear in self.wks
]
)
self.wvs = nn.ModuleList(
[
transfer_weight(
linear,
nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
)
for linear in self.wvs
]
)
self.wo = transfer_weight(
self.wo,
nn.Conv2d(
self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
),
)

self.use_conv2d = True
56 changes: 31 additions & 25 deletions examples/models/llama/tests/test_static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,38 @@ def setUp(self):
torch.manual_seed(42)

def test_without_cache(self):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
def test(use_conv2d):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
if use_conv2d:
static_attn.linear_to_conv2d()

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)
mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
)
y, _ = static_attn(
x,
freqs_cos,
freqs_sin,
mask=mask,
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)
mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
)
y, _ = static_attn(
x,
freqs_cos,
freqs_sin,
mask=mask,
)
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
test(True)
test(False)

def test_hf_rope_without_cache(self):
config = ModelArgs(
Expand Down
Loading