@@ -1461,10 +1461,14 @@ class ScatterNdLayer(_ConcatInputLayer):
14611461 The inverse of :class:`GatherNdLayer`.
14621462 Mostly a wrapper for ``tf.scatter_nd``.
14631463
1464+ Note that "nd" is maybe a bit misleading.
1465+ While we operate on N-D tensors, the indices (``position``)
1466+ are into a single new dimension.
1467+
14641468 The input to the layer are the ``updates``, the ``indices`` are via the ``position`` argument.
14651469 The indices are into the newly constructed output dimension.
14661470 The output shape is constructed via the common shape of the input, the position,
1467- and the the unique common axis (if not unique, we would need to introduce an option to specify it)
1471+ and the unique common axis (if not unique, we would need to introduce an option to specify it)
14681472 is replaced by the given output dimension (currently via ``output_dim_via_time_from``).
14691473
14701474 Examples::
@@ -1489,18 +1493,26 @@ class ScatterNdLayer(_ConcatInputLayer):
14891493 """
14901494 layer_class = "scatter_nd"
14911495
1492- def __init__ (self , position , position_axis , output_dim_via_time_from , filter_invalid_indices = False , ** kwargs ):
1496+ def __init__ (self , position , position_axis , output_dim_via_time_from = None , out_spatial_dim = None ,
1497+ filter_invalid_indices = False , ** kwargs ):
14931498 """
14941499 :param LayerBase position: indices into first axis (excluding batch) of the output
14951500 :param str|int position_axis: axis in `position` to replace by the output-dim
1496- :param LayerBase output_dim_via_time_from: use the time-dim from this layer as the output-dim
1501+ :param LayerBase|None output_dim_via_time_from: use the time-dim from this layer as the output-dim
1502+ :param DimensionTag|None out_spatial_dim:
14971503 :param bool filter_invalid_indices: allow for indices <0 or >= output_dim, which will be discarded in the output
14981504 """
14991505 super (ScatterNdLayer , self ).__init__ (** kwargs )
1506+ assert (out_spatial_dim or output_dim_via_time_from ) and not (out_spatial_dim and output_dim_via_time_from ), (
1507+ "%s: provide either out_spatial_dim or output_dim_via_time_from but not both" % self )
1508+ if not out_spatial_dim :
1509+ out_spatial_dim = output_dim_via_time_from .output .get_time_dim_tag ()
1510+ assert out_spatial_dim .is_dim_known (), (
1511+ "%s: out_spatial_dim %s must have a known (dynamic or static) dim" % (self , out_spatial_dim ))
15001512 self .position = position
15011513 common , output , replace_common_axis , input_extra_axes = self ._get_axes (
15021514 input_data = self .input_data , position = position .output , position_axis = position_axis ,
1503- output_dim_via_time_from = output_dim_via_time_from . output )
1515+ out_spatial_dim = out_spatial_dim )
15041516 pos_v = position .output .placeholder
15051517 pos_ndim = position .output .batch_ndim
15061518 assert 0 <= replace_common_axis < pos_ndim
@@ -1539,12 +1551,12 @@ def get_dep_layers(self):
15391551 return super (ScatterNdLayer , self ).get_dep_layers () + [self .position ]
15401552
15411553 @classmethod
1542- def _get_axes (cls , input_data , position , position_axis , output_dim_via_time_from ):
1554+ def _get_axes (cls , input_data , position , position_axis , out_spatial_dim ):
15431555 """
15441556 :param Data input_data: updates
15451557 :param Data position: indices
15461558 :param str|int position_axis: axis in `position` to replace by the output-dim
1547- :param Data output_dim_via_time_from :
1559+ :param DimensionTag out_spatial_dim :
15481560 :rtype: (Data, Data, int, list[int])
15491561 :return: common, output, axis, input_extra_axes
15501562 """
@@ -1570,26 +1582,26 @@ def _get_axes(cls, input_data, position, position_axis, output_dim_via_time_from
15701582 assert position_axis != position .batch_dim_axis
15711583 if common .time_dim_axis is None :
15721584 common .time_dim_axis = position_axis
1573- output_dim = output_dim_via_time_from .batch_shape [output_dim_via_time_from .time_dim_axis ]
1574- output_size = output_dim_via_time_from .size_placeholder .get (
1575- output_dim_via_time_from .time_dim_axis_excluding_batch , None )
1576- output = common .copy_template_replace_dim (axis = position_axis , new_dim = output_dim , new_size = output_size )
1585+ output = common .copy_template_replace_dim_tag (axis = position_axis , new_dim_tag = out_spatial_dim )
15771586 return common , output , position_axis , input_extra_axes
15781587
15791588 @classmethod
1580- def get_out_data_from_opts (cls , name , sources , position , position_axis , output_dim_via_time_from , ** kwargs ):
1589+ def get_out_data_from_opts (cls , name , sources , position , position_axis ,
1590+ output_dim_via_time_from = None , out_spatial_dim = None ,
1591+ ** kwargs ):
15811592 """
15821593 :param str name:
15831594 :param list[LayerBase] sources:
15841595 :param LayerBase position:
15851596 :param str|int position_axis: axis in `position` to replace by the output-dim
1586- :param LayerBase output_dim_via_time_from:
1597+ :param LayerBase|None output_dim_via_time_from: use the time-dim from this layer as the output-dim
1598+ :param DimensionTag|None out_spatial_dim:
15871599 :rtype: Data
15881600 """
15891601 input_data = get_concat_sources_data_template (sources )
15901602 common , output , replace_common_axis , input_extra_axes = cls ._get_axes (
15911603 input_data = input_data , position = position .output , position_axis = position_axis ,
1592- output_dim_via_time_from = output_dim_via_time_from .output )
1604+ out_spatial_dim = out_spatial_dim if out_spatial_dim else output_dim_via_time_from .output . get_time_dim_tag () )
15931605 return output .copy_template (name = "%s_output" % name )
15941606
15951607 @classmethod
@@ -1601,7 +1613,8 @@ def transform_config_dict(cls, d, network, get_layer):
16011613 """
16021614 super (ScatterNdLayer , cls ).transform_config_dict (d , network = network , get_layer = get_layer )
16031615 d ["position" ] = get_layer (d ["position" ])
1604- d ["output_dim_via_time_from" ] = get_layer (d ["output_dim_via_time_from" ])
1616+ if d .get ("output_dim_via_time_from" , None ):
1617+ d ["output_dim_via_time_from" ] = get_layer (d ["output_dim_via_time_from" ])
16051618
16061619
16071620class LinearLayer (_ConcatInputLayer ):
0 commit comments