Skip to content

Commit

Permalink
GatingLayer, declare_same_as on out_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 3, 2022
1 parent f2fb3ea commit 87a5e8a
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2944,18 +2944,16 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None,
input_data = get_concat_sources_data_template(sources)
assert not input_data.sparse
assert input_data.dim % 2 == 0
dim = input_data.dim // 2
out_dim_ = input_data.dim_tags[input_data.feature_dim_axis] // 2
if out_dim:
assert out_dim.dimension == dim
else:
out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True)
out_dim_.declare_same_as(out_dim)
if n_out is not NotSpecified:
assert n_out == dim
assert n_out == input_data.dim // 2
return Data(
name="%s_output" % name,
dtype=input_data.dtype,
dim_tags=[
out_dim if i == input_data.feature_dim_axis else d
out_dim_ if i == input_data.feature_dim_axis else d
for (i, d) in enumerate(input_data.dim_tags)],
sparse=False,
time_dim_axis=input_data.time_dim_axis,
Expand Down

0 comments on commit 87a5e8a

Please sign in to comment.