Skip to content

Commit 88156c8

Browse files
authored
StackLayer, out_spatial_dim option (#809)
#597
1 parent 1b69acf commit 88156c8

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

returnn/tf/layers/basic.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5536,12 +5536,14 @@ class StackLayer(LayerBase):
55365536
"""
55375537
layer_class = "stack"
55385538

5539-
def __init__(self, axis=None, **kwargs):
5539+
def __init__(self, axis=None, out_spatial_dim=None, **kwargs):
55405540
"""
55415541
:param int|None axis: new axis.
55425542
If not given, will use Data.get_default_new_axis_for_dim_tag(<spatial>),
55435543
i.e. some reasonable default for a new spatial axis.
5544+
:param DimensionTag|None out_spatial_dim:
55445545
"""
5546+
out_spatial_dim # noqa # handled in get_out_data_from_opts
55455547
super(StackLayer, self).__init__(**kwargs)
55465548
axis_, common_source = self._get_axis_and_common(self.sources)
55475549
if axis is None:
@@ -5558,24 +5560,28 @@ def _get_axis_and_common(cls, sources):
55585560
:param list[LayerBase] sources:
55595561
:rtype: (int,Data)
55605562
"""
5561-
from returnn.tf.util.basic import DimensionTag
55625563
common_source = Data.get_common_data([src.output for src in sources]).copy_template()
5563-
tag = DimensionTag(kind=DimensionTag.Types.Spatial, dimension=1)
5564-
return common_source.get_default_new_axis_for_dim_tag(tag), common_source
5564+
dummy_tag = DimensionTag(kind=DimensionTag.Types.Spatial, dimension=1)
5565+
return common_source.get_default_new_axis_for_dim_tag(dummy_tag), common_source
55655566

55665567
@classmethod
5567-
def get_out_data_from_opts(cls, name, sources, axis=None, **kwargs):
5568+
def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None, **kwargs):
55685569
"""
55695570
:param str name:
55705571
:param list[LayerBase] sources:
55715572
:param int|None axis:
5573+
:param DimensionTag|None out_spatial_dim:
55725574
:rtype: Data
55735575
"""
55745576
axis_, common_source = cls._get_axis_and_common(sources)
55755577
if axis is None:
55765578
axis = axis_
55775579
out = common_source.copy_template(name="%s_output" % name)
5578-
out = out.copy_add_spatial_dim(spatial_dim_axis=axis, dim=len(sources))
5580+
if not out_spatial_dim:
5581+
out_spatial_dim = DimensionTag(
5582+
kind=DimensionTag.Types.Spatial, description="%s:stack" % name, dimension=len(sources))
5583+
assert out_spatial_dim.dimension == len(sources)
5584+
out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True)
55795585
return out
55805586

55815587

0 commit comments

Comments
 (0)