@@ -2718,15 +2718,15 @@ class PadLayer(_ConcatInputLayer):
27182718 """
27192719 layer_class = "pad"
27202720
2721- def __init__ (self , axes , padding , out_spatial_dims = None , value = 0 , mode = "constant" , ** kwargs ):
2721+ def __init__ (self , axes , padding , out_dims = None , value = 0 , mode = "constant" , ** kwargs ):
27222722 """
27232723 :param DimensionTag|str|list[DimensionTag|str] axes: e.g. "F" etc. see :func:`Data.get_axes_from_description`.
27242724 :param list[(int,int)]|(int,int)|int padding: how much to pad left/right in each axis
2725- :param DimensionTag|list[DimensionTag]|None out_spatial_dims :
2725+ :param DimensionTag|list[DimensionTag]|None out_dims :
27262726 :param int|float value: what constant value to pad, with mode=="constant"
27272727 :param str mode: "constant", "reflect", "symmetric" and "replication"
27282728 """
2729- out_spatial_dims # noqa # handled in get_out_data_from_opts
2729+ out_dims # noqa # handled in get_out_data_from_opts
27302730 super (PadLayer , self ).__init__ (** kwargs )
27312731 axes_ = self .input_data .get_axes_from_description (axes )
27322732 assert axes_ , "%s: invalid axes %r in input %s" % (self , axes , self .input_data )
@@ -2781,13 +2781,13 @@ def _transform_padding(cls, padding, axes):
27812781 return padding
27822782
27832783 @classmethod
2784- def get_out_data_from_opts (cls , name , sources , axes , padding , out_spatial_dims = None , ** kwargs ):
2784+ def get_out_data_from_opts (cls , name , sources , axes , padding , out_dims = None , ** kwargs ):
27852785 """
27862786 :param str name:
27872787 :param list[LayerBase] sources:
27882788 :param DimensionTag|str|list[DimensionTag|str] axes:
27892789 :param list[(int,int)]|(int,int)|int padding:
2790- :param DimensionTag|list[DimensionTag]|None out_spatial_dims :
2790+ :param DimensionTag|list[DimensionTag]|None out_dims :
27912791 :rtype: Data
27922792 """
27932793 from ..util .data import DimensionTag
@@ -2800,23 +2800,23 @@ def get_out_data_from_opts(cls, name, sources, axes, padding, out_spatial_dims=N
28002800 else :
28012801 axes = [data .get_axis_from_description (axes )]
28022802 padding = cls ._transform_padding (padding = padding , axes = axes )
2803- if out_spatial_dims :
2804- if isinstance (out_spatial_dims , (list , tuple )):
2805- assert len (out_spatial_dims ) == len (axes ) == len (padding )
2806- assert all (isinstance (d , DimensionTag ) for d in out_spatial_dims )
2803+ if out_dims :
2804+ if isinstance (out_dims , (list , tuple )):
2805+ assert len (out_dims ) == len (axes ) == len (padding )
2806+ assert all (isinstance (d , DimensionTag ) for d in out_dims )
28072807 else :
2808- assert isinstance (out_spatial_dims , DimensionTag )
2808+ assert isinstance (out_dims , DimensionTag )
28092809 assert len (axes ) == len (padding ) == 1
2810- out_spatial_dims = [out_spatial_dims ]
2810+ out_dims = [out_dims ]
28112811 dim_tags = list (data .dim_tags )
28122812 for i , a in enumerate (axes ):
28132813 tag = dim_tags [a ]
28142814 dim = None if tag .dimension is None else (tag .dimension + sum (padding [i ]))
2815- if out_spatial_dims :
2815+ if out_dims :
28162816 if sum (padding [i ]) == 0 :
2817- assert out_spatial_dims [i ] == tag
2817+ assert out_dims [i ] == tag
28182818 continue
2819- tag = out_spatial_dims [i ]
2819+ tag = out_dims [i ]
28202820 assert dim == tag .dimension
28212821 elif sum (padding [i ]) == 0 :
28222822 continue
0 commit comments