add value residual learning
lucidrains committed Oct 28, 2024
1 parent 61e484b commit 44f20aa
Showing 4 changed files with 46 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -64,3 +64,12 @@ $ python train.py
    url = {https://api.semanticscholar.org/CorpusID:1505432}
}
```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
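The idea behind the new citation, as wired up in the diff below: attention layers after the first add the first layer's value projections to their own values before attending. A minimal sketch using plain dot-product attention rather than this repository's normalized attention; the function and tensor names are illustrative only:

```python
import torch

def attend(q, k, v, first_layer_values = None):
    # value residual learning: later layers fold the first layer's values into their own
    if first_layer_values is not None:
        v = v + first_layer_values

    scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    attn = scores.softmax(dim = -1)
    return attn @ v, v  # also return the values, so the first layer's can be cached

q = k = v = torch.randn(1, 8, 64)                    # toy (batch, seq, dim) tensors
out1, v1 = attend(q, k, v)                           # first layer: nothing to add, cache its values
out2, _  = attend(q, k, v, first_layer_values = v1)  # a later layer reuses them
```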
40 changes: 35 additions & 5 deletions nGPT_pytorch/nGPT.py
@@ -98,8 +98,18 @@ def __init__(
    def forward(self, x, **kwargs):
        residual = x

-       branch_out = l2norm(self.fn(x, **kwargs))
-       out = l2norm(residual.lerp(branch_out, self.branch_scale()))
+       out = self.fn(x, **kwargs)
+
+       tuple_output = isinstance(out, tuple)
+
+       if tuple_output:
+           out, *rest = out
+
+       out = l2norm(out)
+       out = l2norm(residual.lerp(out, self.branch_scale()))
+
+       if tuple_output:
+           out = (out, *rest)

        return out
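For reference, the update above is nGPT's hyperspherical residual: the branch output is l2-normalized, the residual stream is interpolated toward it with a learned per-layer scale, and the result is normalized again; the new tuple handling just lets the wrapped branch (now the attention, which also reports its values) pass extra outputs through untouched. A simplified standalone sketch with a fixed interpolation weight in place of the module's learned `branch_scale`:

```python
import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim)

def norm_residual(residual, branch_out, alpha = 0.1):
    # interpolate toward the normalized branch output, then project back onto the hypersphere
    return l2norm(residual.lerp(l2norm(branch_out), alpha))

x = l2norm(torch.randn(2, 16, 512))    # residual stream kept at unit norm
branch = torch.randn(2, 16, 512)       # e.g. an attention or feedforward output
x = norm_residual(x, branch)           # still unit norm afterwards
```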

@@ -216,7 +226,9 @@ def forward(
        self,
        x,
        mask = None,
-       rotary_embed: Module | None = None
+       rotary_embed: Module | None = None,
+       value_residual = None,
+       return_values = False
    ):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

@@ -245,6 +257,11 @@ def forward(
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')

+       # maybe value residual, from resformer paper
+
+       if exists(value_residual):
+           v = v + value_residual
+
        # scale is sqrt(dk)

        with self.sdpa_context_manager():
@@ -256,7 +273,12 @@
            )

        out = self.merge_heads(out)
-       return self.to_out(out)
+       out = self.to_out(out)
+
+       if not return_values:
+           return out
+
+       return out, v

# feedforward

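With the two new keyword arguments, a caller can pull the value tensor out of one attention call and hand it to another as a residual. A hedged usage sketch; the import path and constructor arguments shown are assumptions for illustration, not the module's exact signature:

```python
import torch
from nGPT_pytorch.nGPT import Attention  # assumes the class is importable from this module

# constructor arguments are illustrative; see the class __init__ for the real ones
attn = Attention(dim = 512, dim_head = 64, heads = 8)

x = torch.randn(2, 16, 512)

out, values = attn(x, return_values = True)   # first block: also return its values
out = attn(x, value_residual = values)        # a later block adds them to its own values
```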
@@ -315,6 +337,7 @@ def __init__(
        tied_embedding = False,
        num_hyperspheres = 1,
        causal = True,
+       add_value_residual = True,
        # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
        alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
        s_logit_init: float = 1.,
@@ -344,6 +367,8 @@ def __init__(
        self.causal = causal
        alpha_init = default(alpha_init, 1. / depth)

+       self.add_value_residual = add_value_residual # https://arxiv.org/abs/2410.17897v1
+
        self.token_embed = NormLinear_(dim, num_tokens)

        self.rotary_embed = RotaryEmbedding(dim_head)
@@ -448,8 +473,13 @@ def forward(

        tokens = token_embed[ids]

+       first_values = None
+
        for attn, ff in self.layers:
-           tokens = attn(tokens, mask = mask, rotary_embed = rotary_embed)
+           tokens, values = attn(tokens, mask = mask, rotary_embed = rotary_embed, return_values = True, value_residual = first_values if self.add_value_residual else None)
+
+           first_values = default(first_values, values)
+
            tokens = ff(tokens)

        if exists(self.to_logits):
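The `first_values = default(first_values, values)` line is what limits the residual to the first block: `first_values` starts out as `None`, is set by the first attention layer, and is never overwritten after that, so every later layer receives the first layer's values. A tiny standalone sketch of the same capture pattern, with a simplified `default` helper and string stand-ins for the per-layer value tensors:

```python
def default(val, fallback):
    # simplified version of the helper: keep `val` unless it is None
    return val if val is not None else fallback

first_values = None

for values in ['v_from_layer_1', 'v_from_layer_2', 'v_from_layer_3']:
    first_values = default(first_values, values)

print(first_values)  # 'v_from_layer_1' -- only the first layer's values are ever kept
```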
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "nGPT-pytorch"
-version = "0.1.18"
+version = "0.1.19"
description = "nGPT"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
1 change: 1 addition & 0 deletions train.py
@@ -99,6 +99,7 @@ def base_decoding(
    dim = 512,
    depth = 8,
    tied_embedding = True,
+   add_value_residual = True,
    manual_norm_weights = not USE_PARAMETRIZE
).to(device)

