add value residual learning, proposed in iclr 2025
lucidrains committed Nov 7, 2024
1 parent 37f82de commit dd77186
Showing 3 changed files with 45 additions and 5 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -581,3 +581,12 @@ $ accelerate launch train.py
url = {https://api.semanticscholar.org/CorpusID:273229218}
}
```

```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
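
For context, value residual learning (per the citation above) has every attention layer after the first average its value projections with the first layer's values. A minimal sketch of that mixing, using a hypothetical `mix_value_residual` helper rather than anything in audiolm-pytorch:

```python
# Hypothetical helper (not part of audiolm-pytorch) sketching the value residual mix:
# later layers average their freshly projected values with the first layer's values.

from torch import Tensor

def mix_value_residual(v: Tensor, first_layer_v: Tensor | None) -> Tensor:
    if first_layer_v is None:
        # first layer: nothing to mix with yet
        return v
    # subsequent layers: simple average with the first layer's values
    return 0.5 * (v + first_layer_v)
```

In the diff below, the same averaging appears inline as `v = 0.5 * (v + value_residual)`.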
39 changes: 35 additions & 4 deletions audiolm_pytorch/audiolm_pytorch.py
@@ -311,6 +311,8 @@ def forward(
prefix_context = None,
prefix_context_mask = None,
return_kv_cache = False,
return_values = False,
value_residual: Tensor | None = None,
kv_cache = None
):
b, n, _, device = *x.shape, x.device
@@ -346,6 +348,13 @@ def forward(

q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

# for value residual learning

orig_v = v

if exists(value_residual):
v = 0.5 * (v + value_residual)

# kv cache

if exists(kv_cache):
@@ -383,10 +392,16 @@ def forward(
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)

if not return_kv_cache:
if not return_kv_cache and not return_values:
return out

return out, kv_cache
if return_kv_cache and not return_values:
return out, kv_cache

if return_values and not return_kv_cache:
return out, orig_v

return out, (kv_cache, orig_v)

# transformer

@@ -405,6 +420,7 @@ def __init__(
cond_as_self_attn_prefix = False,
rel_pos_bias = True,
flash_attn = False,
add_value_residual = True,
**kwargs
):
super().__init__()
@@ -431,6 +447,8 @@ def __init__(

self.norm = LayerNorm(dim)

self.add_value_residual = add_value_residual

def forward(
self,
x,
@@ -487,21 +505,34 @@ def forward(
prefix_context_mask = context_mask
)

# value residuals

self_attn_value_residual = None
cross_attn_value_residual = None

# transformer layers

for attn, cross_attn, ff in self.layers:

residual = x

x, layer_kv_cache = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask, kv_cache = next(kv_cache, None), return_kv_cache = True, **self_attn_kwargs)
x, (layer_kv_cache, values) = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask, kv_cache = next(kv_cache, None), return_kv_cache = True, return_values = True, value_residual = self_attn_value_residual, **self_attn_kwargs)

if self.add_value_residual:
self_attn_value_residual = default(self_attn_value_residual, values)

new_kv_cache.append(layer_kv_cache)

x = x + residual

if exists(cross_attn):
assert exists(context)

x = cross_attn(x, context = context, mask = context_mask) + x
cross_attend_out, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
x = cross_attend_out + x

if self.add_value_residual:
cross_attn_value_residual = default(cross_attn_value_residual, values)

x = ff(x) + x

2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
__version__ = '2.1.4'
__version__ = '2.2.0'
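
Putting the audiolm_pytorch.py changes above together: a self-contained toy sketch (a hypothetical `ToyAttention`, not the library's Attention or Transformer classes) of how the first layer's values are captured once and mixed into every subsequent layer:

```python
# Toy illustration (not the library's Attention/Transformer) of value residual learning:
# the first layer's values are kept and averaged into every later layer's values.

import torch
from torch import nn, Tensor

class ToyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x: Tensor, value_residual: Tensor | None = None):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        orig_v = v  # values before mixing, returned for use by later layers

        if value_residual is not None:
            v = 0.5 * (v + value_residual)

        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim = -1)
        return self.to_out(attn @ v), orig_v

layers = nn.ModuleList([ToyAttention(64) for _ in range(4)])

x = torch.randn(2, 16, 64)
value_residual = None

for attn in layers:
    out, values = attn(x, value_residual = value_residual)
    x = out + x

    # keep only the first layer's values as the shared residual
    value_residual = value_residual if value_residual is not None else values
```

The `default(...)` call in the commit plays the same role as the `if value_residual is not None` check here: only the very first layer's values are ever stored as the residual.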
