Skip to content

Commit 61c8444

Browse files
committed
dim tag get_for_batch_ctx small fixes
1 parent 3266cbb commit 61c8444

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

returnn/tf/layers/rec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,7 +2518,7 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
25182518
self.parent_rec_layer, input_beam, output_beam,
25192519
self.parent_rec_layer.sources, self.parent_rec_layer.target))
25202520
assert output_template.output.batch.beam == output_beam
2521-
time_dim_tag = time_dim_tag.get_for_batch(
2521+
time_dim_tag = time_dim_tag.get_for_batch_ctx(
25222522
batch=output_template.output.batch, ctx=self.net.control_flow_ctx)
25232523
assert time_dim_tag.dyn_size is not None
25242524
seq_len = time_dim_tag.dyn_size
@@ -2778,7 +2778,7 @@ def get_choice_seq(choice_base):
27782778
latest_batch = (
27792779
latest_layer_choice.output.batch
27802780
or self.parent_rec_layer.output.batch.copy_set_beam(latest_layer_choice.output.beam))
2781-
tag = tag.get_for_batch(batch=latest_batch, ctx=self.net.control_flow_ctx)
2781+
tag = tag.get_for_batch_ctx(batch=latest_batch, ctx=self.net.control_flow_ctx)
27822782
assert tag.dyn_size is not None
27832783
assert tag.batch == latest_batch and tag.batch.beam == latest_layer_choice.output.beam
27842784
seq_len = tag.dyn_size

0 commit comments

Comments
 (0)