@@ -5536,12 +5536,14 @@ class StackLayer(LayerBase):
55365536 """
55375537 layer_class = "stack"
55385538
5539- def __init__ (self , axis = None , ** kwargs ):
5539+ def __init__ (self , axis = None , out_spatial_dim = None , ** kwargs ):
55405540 """
55415541 :param int|None axis: new axis.
55425542 If not given, will use Data.get_default_new_axis_for_dim_tag(<spatial>),
55435543 i.e. some reasonable default for a new spatial axis.
5544+ :param DimensionTag|None out_spatial_dim:
55445545 """
5546+ out_spatial_dim # noqa # handled in get_out_data_from_opts
55455547 super (StackLayer , self ).__init__ (** kwargs )
55465548 axis_ , common_source = self ._get_axis_and_common (self .sources )
55475549 if axis is None :
@@ -5558,24 +5560,28 @@ def _get_axis_and_common(cls, sources):
55585560 :param list[LayerBase] sources:
55595561 :rtype: (int,Data)
55605562 """
5561- from returnn .tf .util .basic import DimensionTag
55625563 common_source = Data .get_common_data ([src .output for src in sources ]).copy_template ()
5563- tag = DimensionTag (kind = DimensionTag .Types .Spatial , dimension = 1 )
5564- return common_source .get_default_new_axis_for_dim_tag (tag ), common_source
5564+ dummy_tag = DimensionTag (kind = DimensionTag .Types .Spatial , dimension = 1 )
5565+ return common_source .get_default_new_axis_for_dim_tag (dummy_tag ), common_source
55655566
55665567 @classmethod
5567- def get_out_data_from_opts (cls , name , sources , axis = None , ** kwargs ):
5568+ def get_out_data_from_opts (cls , name , sources , axis = None , out_spatial_dim = None , ** kwargs ):
55685569 """
55695570 :param str name:
55705571 :param list[LayerBase] sources:
55715572 :param int|None axis:
5573+ :param DimensionTag|None out_spatial_dim:
55725574 :rtype: Data
55735575 """
55745576 axis_ , common_source = cls ._get_axis_and_common (sources )
55755577 if axis is None :
55765578 axis = axis_
55775579 out = common_source .copy_template (name = "%s_output" % name )
5578- out = out .copy_add_spatial_dim (spatial_dim_axis = axis , dim = len (sources ))
5580+ if not out_spatial_dim :
5581+ out_spatial_dim = DimensionTag (
5582+ kind = DimensionTag .Types .Spatial , description = "%s:stack" % name , dimension = len (sources ))
5583+ assert out_spatial_dim .dimension == len (sources )
5584+ out = out .copy_add_dim_by_tag (axis = axis , dim_tag = out_spatial_dim , unbroadcast = True )
55795585 return out
55805586
55815587
0 commit comments