12 changes: 11 additions & 1 deletion returnn/tf/layers/base.py
@@ -298,7 +298,7 @@ def _base_get_out_data_from_opts(cls, network, name,
                                      out_type=None, out_dim=None, n_out=NotSpecified,
                                      out_shape=None,
                                      target=None, _target_layers=None, size_target=None,
-                                     sources=(), loss=None,
+                                     sources=(), in_dim=None, loss=None,
                                      **kwargs):
     """
     Called via BaseLayer.get_out_data_from_opts().
@@ -313,6 +313,7 @@ def _base_get_out_data_from_opts(cls, network, name,
     :param dict[str,LayerBase]|None _target_layers: if target.startswith("layer:"), then this is target -> layer
     :param str|None size_target:
+    :param DimensionTag|None in_dim:
     :param Loss|None loss:
     :param kwargs: remaining kwargs of self.__init__(), ignored here
     :return: Data template (placeholder not set)
@@ -338,6 +339,15 @@ def _base_get_out_data_from_opts(cls, network, name,
     if n_out is not NotSpecified:
       assert out_type["dim"] == n_out
     sources_data_list = [src.output for src in sources if src]
+    if in_dim:
+      assert len(sources_data_list) == 1
+      if sources_data_list[0].feature_dim_or_sparse_dim != in_dim:
+        # Allow specifying an in_dim which is not the feature dim.
+        # However, the follow-up code expects it to be the feature dim, so reassign it if possible.
+        assert in_dim in sources_data_list[0].dim_tags
+        axis = sources_data_list[0].get_axis_from_description(in_dim)
+        sources_data_list = [sources_data_list[0].copy()]
+        sources_data_list[0].feature_dim_axis = axis
     allow_broadcast_all_sources = NotSpecified
     if "shape" in out_type or "dim_tags" in out_type or out_shape is not None:
       allow_broadcast_all_sources = True
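To make the new branch above concrete, here is a minimal sketch (an illustration only, not part of the PR) of what the reassignment does to the single source's Data template. It only uses the Data/DimensionTag API that already appears in this diff and in the test below; the dim names are made up.

# Illustration only (hypothetical dim names), assuming the Data/DimensionTag API used in this PR.
from returnn.tf.util.data import Data, DimensionTag, BatchDim

time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
inner_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="inner", dimension=3)
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="feat", dimension=5)

src = Data(name="src", dim_tags=[BatchDim, time_dim, inner_dim, feat_dim])  # [B,T,3,5]
assert src.feature_dim_or_sparse_dim == feat_dim  # by default, the last axis is the feature dim

in_dim = inner_dim  # present in the data, but not the current feature dim
assert in_dim in src.dim_tags
axis = src.get_axis_from_description(in_dim)
src = src.copy()
src.feature_dim_axis = axis
assert src.feature_dim_or_sparse_dim == in_dim  # follow-up code now treats the [3] axis as the feature dim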
12 changes: 11 additions & 1 deletion returnn/tf/layers/basic.py
@@ -94,7 +94,17 @@ def concat_sources(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpec
   if len(src_layers) == 1:
     data = src_layers[0].output.copy()
     if out_dim:
-      assert out_dim == data.feature_dim_or_sparse_dim
+      if out_dim == data.feature_dim_or_sparse_dim:
+        pass  # good
+      elif out_dim in data.dim_tags:
+        # We found out_dim in the input, but it is not marked as the feature dim.
+        # This is explicitly allowed. Follow-up code will expect this to be the feature dim though,
+        # so we mark it accordingly.
+        assert not data.sparse
+        axis = data.get_axis_from_description(out_dim)
+        data.feature_dim_axis = axis
+      else:
+        raise Exception("%s not found in %s" % (out_dim, data))
     return data
   network = src_layers[0].network
   cache_key = (tuple(src_layers), out_dim, 0.0, None)
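As a quick reference (again an illustration, not code from the PR), the three branches of the new out_dim check in concat_sources for a single source correspond to the conditions below; the dim names are hypothetical and only the API already used above is assumed.

# Illustration only: which branch a given out_dim hits in the single-source case above.
from returnn.tf.util.data import Data, DimensionTag, BatchDim

time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
inner_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="inner", dimension=3)
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="feat", dimension=5)
other_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="other", dimension=7)

data = Data(name="src", dim_tags=[BatchDim, time_dim, inner_dim, feat_dim])

# Branch 1: out_dim is already the feature dim -> nothing to do.
assert feat_dim == data.feature_dim_or_sparse_dim
# Branch 2: out_dim is present but not the feature dim -> it is reassigned as the feature dim
# (this is the case that previously failed the old assert).
assert inner_dim in data.dim_tags and inner_dim != data.feature_dim_or_sparse_dim
# Branch 3: out_dim does not occur in the data at all -> the "%s not found in %s" exception.
assert other_dim not in data.dim_tags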
26 changes: 26 additions & 0 deletions tests/test_TFNetworkLayer.py
@@ -136,6 +136,32 @@ def test_LinearLayer():
     session.run(net.get_default_output_layer().output.placeholder, feed_dict=make_feed_dict(net.extern_data))
 
 
+def test_LinearLayer_in_dim_spatial():
+  from returnn.tf.util.data import BatchDim
+  time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
+  static_spatial_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="static-spatial", dimension=3)
+  feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="in-feature", dimension=5)
+  out_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="out-feature", dimension=7)
+  config = Config({
+    "extern_data": {
+      "data": {"dim_tags": [BatchDim, time_dim, static_spatial_dim, feat_dim]}  # [B,T,D1,D2]
+    }
+  })
+  for _ in range(2):
+    with make_scope() as session:
+      net = TFNetwork(config=config)
+      net.construct_from_dict({
+        "output": {"class": "linear", "from": "data", "in_dim": static_spatial_dim, "out_dim": out_dim}})
+      layer = net.get_default_output_layer()
+      print("Output:", layer.output)
+      assert layer.output.dim_tags_set_implicit == {BatchDim, time_dim, out_dim, feat_dim}
+      param = layer.params["W"]
+      assert isinstance(param, tf.Variable)
+      assert param.shape.as_list() == [static_spatial_dim.dimension, out_dim.dimension]
+      session.run(tf_compat.v1.global_variables_initializer())
+      session.run(layer.output.placeholder, feed_dict=make_feed_dict(net.extern_data))
+
+
 def test_LinearLayer_two_time_dims_allow_broadcast_all_sources():
   from returnn.tf.util.data import BatchDim
   with make_scope() as session: