Skip to content

Commit c1435a5

Browse files
committed
test_DotLayer_linear_square_matrix
1 parent 74981f1 commit c1435a5

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4293,6 +4293,41 @@ def test_DotLayer2():
42934293
assert_equal(out.shape, (S1, S2, B, V))
42944294

42954295

4296+
def test_DotLayer_linear_square_matrix():
4297+
from returnn.tf.util.data import batch_dim
4298+
time_dim = SpatialDim("time")
4299+
feat_dim = FeatureDim("feature", dimension=3)
4300+
config = Config({
4301+
"extern_data": {
4302+
"data": {"dim_tags": [batch_dim, time_dim, feat_dim]},
4303+
"matrix_ambiguous": {"dim_tags": [feat_dim, feat_dim], "available_for_inference": True},
4304+
"matrix_non_ambiguous": {
4305+
"dim_tags": [feat_dim.copy(match_priority=1), feat_dim], "available_for_inference": True},
4306+
},
4307+
})
4308+
with make_scope() as session:
4309+
net = TFNetwork(config=config)
4310+
try:
4311+
net.construct_from_dict({
4312+
"output": {
4313+
"class": "dot", "from": ["data:data", "data:matrix_ambiguous"], "reduce": feat_dim
4314+
},
4315+
})
4316+
except Exception as exc:
4317+
print("Expected exception: %r" % exc)
4318+
assert "must be unique" in str(exc)
4319+
else:
4320+
raise Exception("Expected exception but constructed layer: %s" % net.get_default_output_layer())
4321+
net.construct_from_dict({
4322+
"output": {
4323+
"class": "dot", "from": ["data:data", "data:matrix_non_ambiguous"], "reduce": feat_dim
4324+
},
4325+
})
4326+
out = net.get_default_output_layer().output
4327+
assert out.dim_tags == (batch_dim, time_dim, feat_dim)
4328+
session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data))
4329+
4330+
42964331
def test_DotLayer_mask_dyn_seq():
42974332
batch = Dim(kind=Dim.Types.Batch, description="batch")
42984333
time = SpatialDim("time")

0 commit comments

Comments
 (0)