RelPosSelfAttention _rel_shift error, learned embedding #238
My first thought: This is probably related to some dim value, and then a slice/gather on it.
A bit more meta: With all our logic for dim tags, which should actually make it easier to avoid reshape problems or other shaping problems, why do we still frequently run into such things? The answer is probably too much unnecessary complexity, and thus bugs, in some parts. But which parts exactly? What can we remove? How can we improve this situation? Related: rwth-i6/returnn#975
This test case triggers the bug:

```python
def test_rel_pos_self_attention_learnable():
    class _Net(nn.Module):
        # noinspection PyShadowingNames
        def __init__(self, in_dim: nn.FeatureDim):
            super().__init__()
            self.self_att = nn.RelPosSelfAttention(
                in_dim=in_dim, proj_dim=nn.FeatureDim("out", 5),
                key_dim_total=nn.FeatureDim("key-dim-total", 21),
                value_dim_total=nn.FeatureDim("value-dim-total", 33),
                num_heads=3,
                # Shawn et al 2018 style, old RETURNN way.
                with_bias=False,
                with_linear_pos=False,
                with_pos_bias=False,
                learnable_pos_emb=True,
                learnable_pos_emb_clipping=3,
                separate_pos_emb_per_head=False,
            )

        def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
            """forward"""
            return self.self_att(x, axis=axis)

    in_dim = nn.FeatureDim("in", 12)
    config, net_dict, net = dummy_config_net_dict(lambda: _Net(in_dim), with_axis=True, in_dim=in_dim)
    pprint(net_dict)
    dummy_run_net(config, net=net, seq_len=3)  # ok
    dummy_run_net(config, net=net, seq_len=3)  # try again, to see that running again is ok
    dummy_run_net(config, net=net, seq_len=1)  # ok
    dummy_run_net(config, net=net, seq_len=4)  # problem currently...
```

Note that this test case also triggers some other unrelated bugs first, which are going to be fixed in rwth-i6/returnn#1199.
Fixed via 5e223b2.
To answer this question: The problem here was that we actually did manual dim math, in:

```python
out_spatial_dim = spatial_dim - 1 + spatial_dim
...
remaining_dim = spatial_dim - self.clipping
...
cond.true, out_spatial_dim_ = nn.concat(
    (left, remaining_dim),
    (self.pos_emb, self.clipped_spatial_dim),
    (right, remaining_dim))
out_spatial_dim_.declare_same_as(out_spatial_dim)
```

And:

```python
self.clipped_spatial_dim = nn.SpatialDim(
    f"{nn.NameCtx.current_ctx().get_abs_name()}:learned-rel-pos",
    dimension=2 * clipping + 1)
```

I.e.: out_spatial_dim is declared as spatial_dim - 1 + spatial_dim, i.e. 2*T - 1. But the concat actually produces (T - clipping) + (2*clipping + 1) + (T - clipping) = 2*T + 1, so declare_same_as equates two dims which are in fact not the same, and this only surfaces later at runtime in _rel_shift.

Maybe in this case, we could have detected this statically. But in the general case, there are always cases where we cannot detect this at compilation time, only at runtime. At some point, we planned to actually add such runtime checks.

Edit: I posted this here: rwth-i6/returnn#1200
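To make the mismatch concrete, here is a tiny standalone sketch of that dim arithmetic in plain Python (hypothetical helper name, not returnn_common code), using the values from the test case above (clipping=3):

```python
def dims_for(seq_len: int, clipping: int = 3):
    """Plain-integer version of the dim math quoted above; purely illustrative."""
    declared = seq_len - 1 + seq_len                # out_spatial_dim = spatial_dim - 1 + spatial_dim
    remaining = seq_len - clipping                  # remaining_dim = spatial_dim - self.clipping
    clipped = 2 * clipping + 1                      # clipped_spatial_dim of the learned embedding
    concatenated = remaining + clipped + remaining  # (left, remaining) + (pos_emb, clipped) + (right, remaining)
    return declared, concatenated


# The concat branch (cond.true) presumably only runs once the sequence is longer than the clipping,
# which matches the test: seq_len=1 and seq_len=3 pass, seq_len=4 fails.
print(dims_for(4))  # (7, 9): declare_same_as equates two dims that are actually different
```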
Corresponding code in `_rel_shift`:

This happens with search. It happens only later, after a lot of sequences have already been recognized. The sequences are ordered by length, from short to long, and this now seems to happen with quite long sequences.
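For context, since the error surfaces in `_rel_shift`: this is presumably the usual Transformer-XL-style relative-shift trick (pad, reshape, slice). Below is a minimal NumPy sketch for a single head, purely for illustration; it is not the actual RETURNN implementation, which works on batched tensors and dim tags.

```python
import numpy as np


def rel_shift(x: np.ndarray) -> np.ndarray:
    """Illustrative relative shift, single head.

    x: [time, 2*time - 1], where column k holds the score for relative position k - (time - 1).
    Returns [time, time], where entry [i, j] is the score for relative position j - i.
    """
    t, rel = x.shape
    assert rel == 2 * t - 1
    # Pad one zero column on the left, reshape so each row shifts by one step, then slice.
    x_padded = np.concatenate([np.zeros((t, 1), dtype=x.dtype), x], axis=1)  # [t, 2*t]
    x_shifted = x_padded.reshape(2 * t, t)[1:].reshape(t, 2 * t - 1)
    return x_shifted[:, :t]


# Tiny check with t=3: row i of the output should be columns (2 - i) .. (4 - i) of the input.
x = np.arange(15, dtype=float).reshape(3, 5)
print(rel_shift(x))
```

If the relative-position axis does not have the expected length 2*T - 1 (e.g. 2*T + 1, as in the dim mismatch above), the pad/reshape/slice no longer lines up, which is the kind of runtime failure this issue describes.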