Skip to content

Commit 5ffe78a

Browse files
authored
ScatterNdLayer, out_spatial_dim option (#770)
#597
1 parent 5767590 commit 5ffe78a

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

returnn/tf/layers/basic.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,10 +1461,14 @@ class ScatterNdLayer(_ConcatInputLayer):
14611461
The inverse of :class:`GatherNdLayer`.
14621462
Mostly a wrapper for ``tf.scatter_nd``.
14631463
1464+
Note that "nd" is maybe a bit misleading.
1465+
While we operate on N-D tensors, the indices (``position``)
1466+
are into a single new dimension.
1467+
14641468
The input to the layer are the ``updates``, the ``indices`` are via the ``position`` argument.
14651469
The indices are into the newly constructed output dimension.
14661470
The output shape is constructed via the common shape of the input, the position,
1467-
and the the unique common axis (if not unique, we would need to introduce an option to specify it)
1471+
and the unique common axis (if not unique, we would need to introduce an option to specify it)
14681472
is replaced by the given output dimension (via ``out_spatial_dim`` or ``output_dim_via_time_from``).
14691473
14701474
Examples::
@@ -1489,18 +1493,26 @@ class ScatterNdLayer(_ConcatInputLayer):
14891493
"""
14901494
layer_class = "scatter_nd"
14911495

1492-
def __init__(self, position, position_axis, output_dim_via_time_from, filter_invalid_indices=False, **kwargs):
1496+
def __init__(self, position, position_axis, output_dim_via_time_from=None, out_spatial_dim=None,
1497+
filter_invalid_indices=False, **kwargs):
14931498
"""
14941499
:param LayerBase position: indices into first axis (excluding batch) of the output
14951500
:param str|int position_axis: axis in `position` to replace by the output-dim
1496-
:param LayerBase output_dim_via_time_from: use the time-dim from this layer as the output-dim
1501+
:param LayerBase|None output_dim_via_time_from: use the time-dim from this layer as the output-dim
1502+
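:param DimensionTag|None out_spatial_dim: dim tag to use as the output-dim. Exactly one of `out_spatial_dim` and `output_dim_via_time_from` must be given.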
14971503
:param bool filter_invalid_indices: allow for indices <0 or >= output_dim, which will be discarded in the output
14981504
"""
14991505
super(ScatterNdLayer, self).__init__(**kwargs)
1506+
assert (out_spatial_dim or output_dim_via_time_from) and not (out_spatial_dim and output_dim_via_time_from), (
1507+
"%s: provide either out_spatial_dim or output_dim_via_time_from but not both" % self)
1508+
if not out_spatial_dim:
1509+
out_spatial_dim = output_dim_via_time_from.output.get_time_dim_tag()
1510+
assert out_spatial_dim.is_dim_known(), (
1511+
"%s: out_spatial_dim %s must have a known (dynamic or static) dim" % (self, out_spatial_dim))
15001512
self.position = position
15011513
common, output, replace_common_axis, input_extra_axes = self._get_axes(
15021514
input_data=self.input_data, position=position.output, position_axis=position_axis,
1503-
output_dim_via_time_from=output_dim_via_time_from.output)
1515+
out_spatial_dim=out_spatial_dim)
15041516
pos_v = position.output.placeholder
15051517
pos_ndim = position.output.batch_ndim
15061518
assert 0 <= replace_common_axis < pos_ndim
@@ -1539,12 +1551,12 @@ def get_dep_layers(self):
15391551
return super(ScatterNdLayer, self).get_dep_layers() + [self.position]
15401552

15411553
@classmethod
1542-
def _get_axes(cls, input_data, position, position_axis, output_dim_via_time_from):
1554+
def _get_axes(cls, input_data, position, position_axis, out_spatial_dim):
15431555
"""
15441556
:param Data input_data: updates
15451557
:param Data position: indices
15461558
:param str|int position_axis: axis in `position` to replace by the output-dim
1547-
:param Data output_dim_via_time_from:
1559+
:param DimensionTag out_spatial_dim: dim tag which replaces the `position_axis` dim in the output
15481560
:rtype: (Data, Data, int, list[int])
15491561
:return: common, output, axis, input_extra_axes
15501562
"""
@@ -1570,26 +1582,26 @@ def _get_axes(cls, input_data, position, position_axis, output_dim_via_time_from
15701582
assert position_axis != position.batch_dim_axis
15711583
if common.time_dim_axis is None:
15721584
common.time_dim_axis = position_axis
1573-
output_dim = output_dim_via_time_from.batch_shape[output_dim_via_time_from.time_dim_axis]
1574-
output_size = output_dim_via_time_from.size_placeholder.get(
1575-
output_dim_via_time_from.time_dim_axis_excluding_batch, None)
1576-
output = common.copy_template_replace_dim(axis=position_axis, new_dim=output_dim, new_size=output_size)
1585+
output = common.copy_template_replace_dim_tag(axis=position_axis, new_dim_tag=out_spatial_dim)
15771586
return common, output, position_axis, input_extra_axes
15781587

15791588
@classmethod
1580-
def get_out_data_from_opts(cls, name, sources, position, position_axis, output_dim_via_time_from, **kwargs):
1589+
def get_out_data_from_opts(cls, name, sources, position, position_axis,
1590+
output_dim_via_time_from=None, out_spatial_dim=None,
1591+
**kwargs):
15811592
"""
15821593
:param str name:
15831594
:param list[LayerBase] sources:
15841595
:param LayerBase position:
15851596
:param str|int position_axis: axis in `position` to replace by the output-dim
1586-
:param LayerBase output_dim_via_time_from:
1597+
:param LayerBase|None output_dim_via_time_from: use the time-dim from this layer as the output-dim
1598+
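:param DimensionTag|None out_spatial_dim: dim tag to use as the output-dim. If not given, the time-dim tag of `output_dim_via_time_from` is used.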
15871599
:rtype: Data
15881600
"""
15891601
input_data = get_concat_sources_data_template(sources)
15901602
common, output, replace_common_axis, input_extra_axes = cls._get_axes(
15911603
input_data=input_data, position=position.output, position_axis=position_axis,
1592-
output_dim_via_time_from=output_dim_via_time_from.output)
1604+
out_spatial_dim=out_spatial_dim if out_spatial_dim else output_dim_via_time_from.output.get_time_dim_tag())
15931605
return output.copy_template(name="%s_output" % name)
15941606

15951607
@classmethod
@@ -1601,7 +1613,8 @@ def transform_config_dict(cls, d, network, get_layer):
16011613
"""
16021614
super(ScatterNdLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
16031615
d["position"] = get_layer(d["position"])
1604-
d["output_dim_via_time_from"] = get_layer(d["output_dim_via_time_from"])
1616+
if d.get("output_dim_via_time_from", None):
1617+
d["output_dim_via_time_from"] = get_layer(d["output_dim_via_time_from"])
16051618

16061619

16071620
class LinearLayer(_ConcatInputLayer):

0 commit comments

Comments
 (0)