Commit

Fix attention masking error; Disable last layernorm
Alexander Werning committed Sep 25, 2024
1 parent a33a81c commit 66309ab
Showing 1 changed file with 30 additions and 29 deletions.
padertorch/contrib/aw/transformer/transformer.py (59 changes: 30 additions & 29 deletions)
@@ -11,23 +11,23 @@
def generate_future_mask(sz: int) -> torch.Tensor:
"""Generates an upper-triangular matrix of -inf, with zeros on diag.
Used for forward model, which cannot see future tokens."""
- return torch.tril(torch.ones(sz, sz) * float('-inf'), diagonal=-1)
+ return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

class TransformerDecoder(nn.Module):

- def __init__(self, *,
- embed_dim=768,
- depth=12,
- num_heads=12,
- output_dim=None,
- input_dim=None,
- mlp_ratio=4.0,
- block_factory=AttentionBlockFactory(),
- norm_layer=nn.LayerNorm,
- dropout=0,
- attn_dropout=0,
- layer_dropout=0,
- use_cls_token=False,
+ def __init__(self, *,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ output_dim=None,
+ input_dim=None,
+ mlp_ratio=4.0,
+ block_factory=AttentionBlockFactory(),
+ norm_layer=nn.LayerNorm,
+ dropout=0,
+ attn_dropout=0,
+ layer_dropout=0,
+ use_cls_token=False,
rel_pos_bias_factory: Optional[RelativePositionalBiasFactory] = None,
init_mode="default",
return_weights=False,
@@ -41,14 +41,14 @@ def __init__(self, *,
self.blocks = nn.ModuleList(
[
block_factory(
- embed_dim,
- num_heads,
- mlp_ratio,
- norm_layer=norm_layer,
- dropout=dropout,
- attn_dropout=attn_dropout,
+ embed_dim,
+ num_heads,
+ mlp_ratio,
+ norm_layer=norm_layer,
+ dropout=dropout,
+ attn_dropout=attn_dropout,
rel_pos_bias_factory=rel_pos_bias_factory
- )
+ )
for _ in range(depth)
]
)
@@ -73,20 +73,20 @@ def __init__(self, *,
assert self.blocks[0].style == "pre-ln"
self.return_weights = return_weights
self.reset_parameters(init_mode=init_mode)

def forward(self, x):
seq_len = x.shape[-2]

src_mask = generate_future_mask(seq_len).to(x.device)

if self.use_cls_token:
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)

if self.in_proj is not None:
x = self.in_proj(x)
- if self.apply_layernorm_before:
- x = self.layer_norm(x)

+ # if self.apply_layernorm_before:
+ # x = self.layer_norm(x)

position_bias = None
attn_weights_list = []
@@ -100,8 +100,9 @@ def forward(self, x):
x, attn_weights = blk(x, attn_mask=src_mask, return_weights=self.return_weights)
attn_weights_list.append(attn_weights)

- if not self.apply_layernorm_before:
- x = self.layer_norm(x)
+ # disabled, causes errors with gradients
+ # if not self.apply_layernorm_before:
+ # x = self.layer_norm(x)

if self.out_proj is not None:
x = self.out_proj(x)
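For context, not part of the commit: torch.tril(..., diagonal=-1) keeps the -inf entries strictly below the diagonal, which hides past positions and leaves future positions visible, the opposite of a causal mask. The corrected torch.triu(..., diagonal=1) places -inf strictly above the diagonal, so each position can attend only to itself and to earlier positions. A minimal standalone sketch of the fixed behavior (not taken from the repository):

import torch

def generate_future_mask(sz: int) -> torch.Tensor:
    # -inf above the diagonal blocks attention to future tokens;
    # zeros on and below the diagonal allow attention to self and the past.
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

mask = generate_future_mask(4)
print(mask)
# roughly:
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# The old tril(..., diagonal=-1) call would instead place the -inf entries
# below the diagonal, masking the past rather than the future.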
