Commit 897c77c

ConstantLayer, shape option to define dim tags
#597
1 parent f35b0a3 commit 897c77c

2 files changed (+31, −10 lines)

returnn/tf/layers/basic.py

Lines changed: 6 additions & 3 deletions
@@ -2330,10 +2330,11 @@ class ConstantLayer(LayerBase):
   layer_class = "constant"
 
   # noinspection PyUnusedLocal
-  def __init__(self, sources, value=0., dtype=None, with_batch_dim=False, **kwargs):
+  def __init__(self, sources, value=0., shape=None, dtype=None, with_batch_dim=False, **kwargs):
     """
     :param list[LayerBase] sources:
     :param int|float|bool value:
+    :param tuple[DimensionTag|int]|list[DimensionTag|int] shape: for verification, and defining dim tags
     :param str|None dtype:
     :param bool with_batch_dim:
     """
@@ -2357,15 +2358,17 @@ def transform_config_dict(cls, d, network, get_layer):
     super(ConstantLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
 
   @classmethod
-  def get_out_data_from_opts(cls, name, value=0., dtype=None, with_batch_dim=False, **kwargs):
+  def get_out_data_from_opts(cls, name, value=0., shape=None, dtype=None, with_batch_dim=False, **kwargs):
     """
     :param str name:
     :param int|float|bool value:
+    :param tuple[DimensionTag|int]|list[DimensionTag|int] shape: for verification, and defining dim tags
     :param str|None dtype:
     :param bool with_batch_dim:
     :rtype: Data
     """
-    return Data.template_from_constant(value, name="%s_const" % name, dtype=dtype, with_batch_dim=with_batch_dim)
+    return Data.template_from_constant(
+      value, name="%s_const" % name, shape=shape, dtype=dtype, with_batch_dim=with_batch_dim)
 
 
 class GatingLayer(_ConcatInputLayer):
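With this change, a "constant" layer can be given a shape made up of DimensionTag objects and/or plain ints, which is checked against the value and used to define the dim tags of the output. A minimal config sketch of how this could look; the layer name "my_const", the dimension 5, and the use of a numpy array as value are assumptions for illustration (the docstring above only lists scalar values):

import numpy
from returnn.tf.util.data import DimensionTag

# A dim tag that downstream layers can then refer to or be matched against.
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="const-feat", dimension=5)

network = {
  "my_const": {
    "class": "constant",
    "value": numpy.zeros((5,), dtype="float32"),  # assumption: ndarray value; a scalar value needs no shape
    "shape": (feat_dim,),  # verified against value.shape; a plain int like (5,) would also work
    "with_batch_dim": False,
  },
}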

returnn/tf/util/data.py

Lines changed: 25 additions & 7 deletions
@@ -1731,11 +1731,12 @@ def from_tensor(cls, x):
     return Data(name=str(x.op.name), shape=(), batch_dim_axis=None, dtype=x.dtype.name, placeholder=x)
 
   @classmethod
-  def template_from_constant(cls, x, name, dtype=None, with_batch_dim=False):
+  def template_from_constant(cls, x, name, dtype=None, shape=None, with_batch_dim=False):
     """
     :param int|float|bool|numpy.ndarray x:
     :param str name:
     :param str|None dtype:
+    :param list[DimensionTag|int]|tuple[DimensionTag|int]|None shape: for verification, and defining dim tags
     :param bool with_batch_dim:
     :rtype: Data
     """
@@ -1750,12 +1751,29 @@ def template_from_constant(cls, x, name, dtype=None, with_batch_dim=False):
     elif isinstance(x, numpy.ndarray):
       dtype = str(x.dtype)
     else:
-      raise TypeError("cannot handle value %r of type %r" % (x, type(x)))
-    shape = x.shape if isinstance(x, numpy.ndarray) else ()
-    return Data(
-      name=name,
-      shape=shape, batch_dim_axis=0 if with_batch_dim else None, time_dim_axis=None,
-      dtype=dtype)
+      raise TypeError("%r: cannot handle value %r of type %r" % (name, x, type(x)))
+    shape_ = x.shape if isinstance(x, numpy.ndarray) else ()
+    if shape is not None:
+      assert len(shape) == len(shape_), "%r: shape does not match in ndim, %r vs %r" % (name, shape, shape_)
+    else:
+      shape = shape_
+    dim_tags = []
+    for i, (d, d_) in enumerate(zip(shape, shape_)):
+      assert isinstance(d_, int)
+      if isinstance(d, DimensionTag):
+        assert d.dimension == d_
+      elif isinstance(d, int):
+        assert d == d_
+        d = DimensionTag(
+          kind=DimensionTag.Types.Spatial if i < len(shape) - 1 else DimensionTag.Types.Feature,
+          description="%s:static:%i" % (name, i),
+          dimension=d)
+      else:
+        raise TypeError("%r shape[%i] invalid type %r in shape %r" % (name, i, type(d), shape))
+      dim_tags.append(d)
+    if with_batch_dim:
+      dim_tags.insert(0, BatchDim)
+    return Data(name=name, dim_tags=dim_tags, dtype=dtype)
 
   def sanity_check(self, ignore_placeholder=False, assume_complete=True):
     """
