Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2469,8 +2469,13 @@ class GatingLayer(_ConcatInputLayer):
"""
layer_class = "gating"

def __init__(self, activation, gate_activation="sigmoid", **kwargs):
super(GatingLayer, self).__init__(**kwargs)
def __init__(self, activation, gate_activation="sigmoid", out_dim=None, **kwargs):
"""
:param str activation:
:param str gate_activation:
:param Dim|None out_dim:
"""
super(GatingLayer, self).__init__(out_dim=out_dim, **kwargs)
from returnn.tf.util.basic import get_activation_function
act_func = get_activation_function(activation)
gate_act_func = get_activation_function(gate_activation)
Expand All @@ -2484,25 +2489,29 @@ def __init__(self, activation, gate_activation="sigmoid", **kwargs):
self.output.size_placeholder = self.input_data.size_placeholder.copy()

@classmethod
def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, **kwargs):
def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None, **kwargs):
"""
:param str name:
:param list[LayerBase] sources:
:param int|None|NotSpecified n_out:
:param Dim|None out_dim:
:rtype: Data
"""
input_data = get_concat_sources_data_template(sources)
assert not input_data.sparse
assert input_data.dim % 2 == 0
dim = input_data.dim // 2
new_dim_tag = FeatureDim("%s:gating" % name, dimension=dim)
if out_dim:
assert out_dim.dimension == dim
else:
out_dim = FeatureDim("%s:gating" % name, dimension=dim)
if n_out is not NotSpecified:
assert n_out == dim
return Data(
name="%s_output" % name,
dtype=input_data.dtype,
dim_tags=[
new_dim_tag if i == input_data.feature_dim_axis else d
out_dim if i == input_data.feature_dim_axis else d
for (i, d) in enumerate(input_data.dim_tags)],
sparse=False,
time_dim_axis=input_data.time_dim_axis,
Expand Down