Skip to content

Commit 5c6e440

Browse files
authored
PadLayer, rename out_spatial_dims to out_dims (#779)
#597
1 parent 95844ea commit 5c6e440

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

returnn/tf/layers/basic.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2718,15 +2718,15 @@ class PadLayer(_ConcatInputLayer):
27182718
"""
27192719
layer_class = "pad"
27202720

2721-
def __init__(self, axes, padding, out_spatial_dims=None, value=0, mode="constant", **kwargs):
2721+
def __init__(self, axes, padding, out_dims=None, value=0, mode="constant", **kwargs):
27222722
"""
27232723
:param DimensionTag|str|list[DimensionTag|str] axes: e.g. "F" etc. see :func:`Data.get_axes_from_description`.
27242724
:param list[(int,int)]|(int,int)|int padding: how much to pad left/right in each axis
2725-
:param DimensionTag|list[DimensionTag]|None out_spatial_dims:
2725+
:param DimensionTag|list[DimensionTag]|None out_dims:
27262726
:param int|float value: what constant value to pad, with mode=="constant"
27272727
:param str mode: "constant", "reflect", "symmetric" and "replication"
27282728
"""
2729-
out_spatial_dims # noqa # handled in get_out_data_from_opts
2729+
out_dims # noqa # handled in get_out_data_from_opts
27302730
super(PadLayer, self).__init__(**kwargs)
27312731
axes_ = self.input_data.get_axes_from_description(axes)
27322732
assert axes_, "%s: invalid axes %r in input %s" % (self, axes, self.input_data)
@@ -2781,13 +2781,13 @@ def _transform_padding(cls, padding, axes):
27812781
return padding
27822782

27832783
@classmethod
2784-
def get_out_data_from_opts(cls, name, sources, axes, padding, out_spatial_dims=None, **kwargs):
2784+
def get_out_data_from_opts(cls, name, sources, axes, padding, out_dims=None, **kwargs):
27852785
"""
27862786
:param str name:
27872787
:param list[LayerBase] sources:
27882788
:param DimensionTag|str|list[DimensionTag|str] axes:
27892789
:param list[(int,int)]|(int,int)|int padding:
2790-
:param DimensionTag|list[DimensionTag]|None out_spatial_dims:
2790+
:param DimensionTag|list[DimensionTag]|None out_dims:
27912791
:rtype: Data
27922792
"""
27932793
from ..util.data import DimensionTag
@@ -2800,23 +2800,23 @@ def get_out_data_from_opts(cls, name, sources, axes, padding, out_spatial_dims=N
28002800
else:
28012801
axes = [data.get_axis_from_description(axes)]
28022802
padding = cls._transform_padding(padding=padding, axes=axes)
2803-
if out_spatial_dims:
2804-
if isinstance(out_spatial_dims, (list, tuple)):
2805-
assert len(out_spatial_dims) == len(axes) == len(padding)
2806-
assert all(isinstance(d, DimensionTag) for d in out_spatial_dims)
2803+
if out_dims:
2804+
if isinstance(out_dims, (list, tuple)):
2805+
assert len(out_dims) == len(axes) == len(padding)
2806+
assert all(isinstance(d, DimensionTag) for d in out_dims)
28072807
else:
2808-
assert isinstance(out_spatial_dims, DimensionTag)
2808+
assert isinstance(out_dims, DimensionTag)
28092809
assert len(axes) == len(padding) == 1
2810-
out_spatial_dims = [out_spatial_dims]
2810+
out_dims = [out_dims]
28112811
dim_tags = list(data.dim_tags)
28122812
for i, a in enumerate(axes):
28132813
tag = dim_tags[a]
28142814
dim = None if tag.dimension is None else (tag.dimension + sum(padding[i]))
2815-
if out_spatial_dims:
2815+
if out_dims:
28162816
if sum(padding[i]) == 0:
2817-
assert out_spatial_dims[i] == tag
2817+
assert out_dims[i] == tag
28182818
continue
2819-
tag = out_spatial_dims[i]
2819+
tag = out_dims[i]
28202820
assert dim == tag.dimension
28212821
elif sum(padding[i]) == 0:
28222822
continue

0 commit comments

Comments
 (0)